@rlabs-inc/sparse 0.1.0

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -0,0 +1,215 @@
1
+ // ============================================================================
2
+ // SPARSE - Metal Compute Shaders for Biological Neural Simulation
3
+ // Sparsity-first GPU compute. No transformer bloat.
4
+ // ============================================================================
5
+
6
+ #include <metal_stdlib>
7
+ using namespace metal;
8
+
9
+ // ============================================================================
10
+ // SCATTER-ADD KERNEL
11
+ // The core operation for sparse synapse transmission.
12
+ // Thread-safe atomic add - indices are READ-ONLY, never corrupted.
13
+ // ============================================================================
14
+
15
kernel void scatter_add(
    // Declared as atomic_float so the add goes through Metal's atomic API
    // directly instead of casting a plain float* to atomic_float* (which is
    // not a sanctioned pattern). The host still binds an ordinary float
    // buffer at index 0 -- the in-memory layout is identical.
    // NOTE(review): device-memory atomic_float requires MSL 2.4+ -- confirm
    // the build's Metal language version.
    device atomic_float* target  [[buffer(0)]],  // Target array to add into
    device const uint*   indices [[buffer(1)]],  // Indices (READ-ONLY!)
    device const float*  values  [[buffer(2)]],  // Values to add
    constant uint&       count   [[buffer(3)]],  // Number of elements
    uint gid [[thread_position_in_grid]]
) {
    // Dispatches may be padded to threadgroup granularity; skip the overhang.
    if (gid >= count) return;

    const uint  idx = indices[gid];
    const float val = values[gid];

    // Atomic add: concurrent threads hitting the same target slot accumulate
    // safely. Relaxed ordering suffices -- only the summed value matters,
    // there is no cross-thread ordering dependency.
    // NOTE(review): indices are not bounds-checked here; presumably the host
    // guarantees idx < target length -- verify against callers.
    atomic_fetch_add_explicit(&target[idx], val, memory_order_relaxed);
}
34
+
35
+ // ============================================================================
36
+ // ELEMENT-WISE OPERATIONS
37
+ // ============================================================================
38
+
39
// Element-wise sum: result[i] = a[i] + b[i] for i in [0, count).
kernel void add_arrays(
    device const float* a      [[buffer(0)]],
    device const float* b      [[buffer(1)]],
    device float*       result [[buffer(2)]],
    constant uint&      count  [[buffer(3)]],
    uint gid [[thread_position_in_grid]]
) {
    // The grid may be padded past `count`; only in-range threads write.
    if (gid < count) {
        result[gid] = a[gid] + b[gid];
    }
}
49
+
50
// Broadcast add: result[i] = a[i] + scalar for i in [0, count).
kernel void add_scalar(
    device const float* a      [[buffer(0)]],
    constant float&     scalar [[buffer(1)]],
    device float*       result [[buffer(2)]],
    constant uint&      count  [[buffer(3)]],
    uint gid [[thread_position_in_grid]]
) {
    // The grid may be padded past `count`; only in-range threads write.
    if (gid < count) {
        result[gid] = a[gid] + scalar;
    }
}
60
+
61
// Element-wise product: result[i] = a[i] * b[i] for i in [0, count).
kernel void multiply_arrays(
    device const float* a      [[buffer(0)]],
    device const float* b      [[buffer(1)]],
    device float*       result [[buffer(2)]],
    constant uint&      count  [[buffer(3)]],
    uint gid [[thread_position_in_grid]]
) {
    // The grid may be padded past `count`; only in-range threads write.
    if (gid < count) {
        result[gid] = a[gid] * b[gid];
    }
}
71
+
72
// Broadcast multiply: result[i] = a[i] * scalar for i in [0, count).
kernel void multiply_scalar(
    device const float* a      [[buffer(0)]],
    constant float&     scalar [[buffer(1)]],
    device float*       result [[buffer(2)]],
    constant uint&      count  [[buffer(3)]],
    uint gid [[thread_position_in_grid]]
) {
    // The grid may be padded past `count`; only in-range threads write.
    if (gid < count) {
        result[gid] = a[gid] * scalar;
    }
}
82
+
83
// Element-wise square: result[i] = a[i]^2 for i in [0, count).
kernel void square(
    device const float* a      [[buffer(0)]],
    device float*       result [[buffer(1)]],
    constant uint&      count  [[buffer(2)]],
    uint gid [[thread_position_in_grid]]
) {
    // The grid may be padded past `count`; only in-range threads write.
    if (gid < count) {
        const float x = a[gid];
        result[gid] = x * x;
    }
}
93
+
94
+ // ============================================================================
95
+ // CONDITIONAL OPERATIONS
96
+ // ============================================================================
97
+
98
// Threshold test: result[i] = 1 when a[i] >= threshold, else 0.
// Booleans are stored as uint so the mask can feed where_select/gather_bool.
kernel void greater_equal(
    device const float* a         [[buffer(0)]],
    constant float&     threshold [[buffer(1)]],
    device uint*        result    [[buffer(2)]],  // Boolean as uint (0 or 1)
    constant uint&      count     [[buffer(3)]],
    uint gid [[thread_position_in_grid]]
) {
    // The grid may be padded past `count`; only in-range threads write.
    if (gid < count) {
        result[gid] = uint(a[gid] >= threshold);
    }
}
108
+
109
// Masked select: result[i] = if_true[i] where condition[i] is non-zero,
// otherwise if_false[i].
kernel void where_select(
    device const uint*  condition [[buffer(0)]],  // Boolean mask
    device const float* if_true   [[buffer(1)]],
    device const float* if_false  [[buffer(2)]],
    device float*       result    [[buffer(3)]],
    constant uint&      count     [[buffer(4)]],
    uint gid [[thread_position_in_grid]]
) {
    // The grid may be padded past `count`; only in-range threads write.
    if (gid < count) {
        const bool take_true = (condition[gid] != 0u);
        result[gid] = take_true ? if_true[gid] : if_false[gid];
    }
}
120
+
121
// Masked fill: result[i] = if_true where condition[i] is non-zero,
// otherwise if_false. Scalar variant of where_select.
kernel void where_scalar(
    device const uint* condition [[buffer(0)]],  // Boolean mask
    constant float&    if_true   [[buffer(1)]],
    constant float&    if_false  [[buffer(2)]],
    device float*      result    [[buffer(3)]],
    constant uint&     count     [[buffer(4)]],
    uint gid [[thread_position_in_grid]]
) {
    // The grid may be padded past `count`; only in-range threads write.
    if (gid < count) {
        const bool take_true = (condition[gid] != 0u);
        result[gid] = take_true ? if_true : if_false;
    }
}
132
+
133
+ // ============================================================================
134
+ // REDUCTION OPERATIONS
135
+ // ============================================================================
136
+
137
kernel void sum_reduce(
    device const float* input [[buffer(0)]],
    // Single-float accumulator, declared atomic so per-threadgroup partial
    // sums can be combined without a second dispatch.
    // NOTE(review): the host must zero this buffer before dispatch -- the
    // kernel only adds into it. Summation order across groups is
    // non-deterministic, so float results may vary in the last bits.
    device atomic_float* output [[buffer(1)]],
    constant uint&       count  [[buffer(2)]],
    threadgroup float*   shared [[threadgroup(0)]],
    uint gid        [[thread_position_in_grid]],
    uint lid        [[thread_position_in_threadgroup]],
    uint group_size [[threads_per_threadgroup]]
) {
    // Stage one element per thread into threadgroup memory; out-of-range
    // threads contribute the additive identity.
    shared[lid] = (gid < count) ? input[gid] : 0.0f;
    threadgroup_barrier(mem_flags::mem_threadgroup);

    // Tree reduction. Unlike the classic power-of-two halving loop (which
    // silently drops elements when group_size is odd at any level), each
    // round folds the top half onto the bottom half with the split point
    // rounded UP, so every element is accounted for at any group size.
    for (uint active = group_size; active > 1; ) {
        const uint half_size = (active + 1) / 2;
        if (lid + half_size < active) {
            shared[lid] += shared[lid + half_size];
        }
        threadgroup_barrier(mem_flags::mem_threadgroup);
        active = half_size;
    }

    // One thread per group publishes its partial sum. Relaxed ordering is
    // sufficient: the result is read only after the command buffer completes.
    if (lid == 0) {
        atomic_fetch_add_explicit(output, shared[0], memory_order_relaxed);
    }
}
167
+
168
+ // ============================================================================
169
+ // GATHER OPERATION (for sparse indexing)
170
+ // ============================================================================
171
+
172
// Sparse read: result[i] = source[indices[i]] for i in [0, count).
kernel void gather(
    device const float* source  [[buffer(0)]],
    device const uint*  indices [[buffer(1)]],
    device float*       result  [[buffer(2)]],
    constant uint&      count   [[buffer(3)]],
    uint gid [[thread_position_in_grid]]
) {
    // The grid may be padded past `count`; only in-range threads write.
    // NOTE(review): indices themselves are not bounds-checked -- presumably
    // the host guarantees they are valid for `source`.
    if (gid < count) {
        result[gid] = source[indices[gid]];
    }
}
182
+
183
// Sparse read for uint-encoded booleans: result[i] = source[indices[i]].
kernel void gather_bool(
    device const uint* source  [[buffer(0)]],  // Boolean array as uint
    device const uint* indices [[buffer(1)]],
    device uint*       result  [[buffer(2)]],
    constant uint&     count   [[buffer(3)]],
    uint gid [[thread_position_in_grid]]
) {
    // The grid may be padded past `count`; only in-range threads write.
    // NOTE(review): indices themselves are not bounds-checked -- presumably
    // the host guarantees they are valid for `source`.
    if (gid < count) {
        result[gid] = source[indices[gid]];
    }
}
193
+
194
+ // ============================================================================
195
+ // FILL OPERATIONS
196
+ // ============================================================================
197
+
198
// Fill every in-range slot with a constant value.
kernel void fill_float(
    device float*   arr   [[buffer(0)]],
    constant float& value [[buffer(1)]],
    constant uint&  count [[buffer(2)]],
    uint gid [[thread_position_in_grid]]
) {
    // The grid may be padded past `count`; only in-range threads write.
    if (gid < count) {
        arr[gid] = value;
    }
}
207
+
208
// Zero out every in-range slot (specialized fill, saves a constant binding).
kernel void fill_zeros(
    device float*  arr   [[buffer(0)]],
    constant uint& count [[buffer(1)]],
    uint gid [[thread_position_in_grid]]
) {
    // The grid may be padded past `count`; only in-range threads write.
    if (gid < count) {
        arr[gid] = 0.0f;
    }
}
package/package.json ADDED
@@ -0,0 +1,38 @@
1
+ {
2
+ "name": "@rlabs-inc/sparse",
3
+ "version": "0.1.0",
4
+ "description": "Sparsity-first GPU compute for biological neural simulation",
5
+ "type": "module",
6
+ "main": "src/index.ts",
7
+ "types": "src/index.ts",
8
+ "files": [
9
+ "src",
10
+ "native",
11
+ "libsparse.dylib",
12
+ "CLAUDE.md"
13
+ ],
14
+ "scripts": {
15
+ "test": "bun test",
16
+ "build": "bun build src/index.ts --outdir dist --target node",
17
+ "build:native": "cd native && make"
18
+ },
19
+ "os": ["darwin"],
20
+ "cpu": ["arm64", "x64"],
21
+ "keywords": [
22
+ "gpu",
23
+ "metal",
24
+ "sparse",
25
+ "neural",
26
+ "simulation",
27
+ "scatter-add",
28
+ "apple-silicon"
29
+ ],
30
+ "author": "RLabs Inc",
31
+ "license": "MIT",
32
+ "dependencies": {
33
+ "@rlabs-inc/signals": "^1.0.0"
34
+ },
35
+ "devDependencies": {
36
+ "@types/bun": "latest"
37
+ }
38
+ }
package/src/ffi.ts ADDED
@@ -0,0 +1,156 @@
1
+ // ============================================================================
2
+ // SPARSE - FFI Bindings
3
+ // TypeScript bindings to the native Metal library via Bun FFI
4
+ // ============================================================================
5
+
6
+ import { dlopen, FFIType, ptr, suffix, type Pointer } from 'bun:ffi'
7
+ import { join, dirname } from 'path'
8
+ import { fileURLToPath } from 'url'
9
+
10
+ // Get the directory of this file to find the native library
11
+ const __dirname = dirname(fileURLToPath(import.meta.url))
12
+ const libPath = join(__dirname, '..', `libsparse.${suffix}`)
13
+
14
+ // Data types
15
// Data types
// Numeric tags identifying an element type when crossing the FFI boundary
// (e.g. the i32 dtype argument of sparse_zeros). NOTE(review): these values
// presumably mirror an enum on the native side -- confirm against native/
// before adding entries.
export const DataType = {
  FLOAT32: 0,
  UINT32: 1,
} as const

// Union of the tag values (0 | 1), usable wherever a dtype is expected.
export type DataType = typeof DataType[keyof typeof DataType]
21
+
22
+ // Load the native library
23
+ const lib = dlopen(libPath, {
24
+ // Context management
25
+ sparse_init: {
26
+ returns: FFIType.ptr,
27
+ args: [],
28
+ },
29
+ sparse_cleanup: {
30
+ returns: FFIType.void,
31
+ args: [FFIType.ptr],
32
+ },
33
+ sparse_sync: {
34
+ returns: FFIType.void,
35
+ args: [FFIType.ptr],
36
+ },
37
+ sparse_device_name: {
38
+ returns: FFIType.cstring,
39
+ args: [FFIType.ptr],
40
+ },
41
+ sparse_device_memory: {
42
+ returns: FFIType.u64,
43
+ args: [FFIType.ptr],
44
+ },
45
+
46
+ // Buffer management
47
+ sparse_zeros: {
48
+ returns: FFIType.ptr,
49
+ args: [FFIType.ptr, FFIType.u32, FFIType.i32],
50
+ },
51
+ sparse_full: {
52
+ returns: FFIType.ptr,
53
+ args: [FFIType.ptr, FFIType.u32, FFIType.f32],
54
+ },
55
+ sparse_from_float: {
56
+ returns: FFIType.ptr,
57
+ args: [FFIType.ptr, FFIType.ptr, FFIType.u32],
58
+ },
59
+ sparse_from_uint: {
60
+ returns: FFIType.ptr,
61
+ args: [FFIType.ptr, FFIType.ptr, FFIType.u32],
62
+ },
63
+ sparse_to_float: {
64
+ returns: FFIType.void,
65
+ args: [FFIType.ptr, FFIType.ptr, FFIType.u32],
66
+ },
67
+ sparse_to_uint: {
68
+ returns: FFIType.void,
69
+ args: [FFIType.ptr, FFIType.ptr, FFIType.u32],
70
+ },
71
+ sparse_buffer_count: {
72
+ returns: FFIType.u32,
73
+ args: [FFIType.ptr],
74
+ },
75
+ sparse_buffer_dtype: {
76
+ returns: FFIType.i32,
77
+ args: [FFIType.ptr],
78
+ },
79
+ sparse_buffer_free: {
80
+ returns: FFIType.void,
81
+ args: [FFIType.ptr],
82
+ },
83
+
84
+ // Core operations
85
+ sparse_scatter_add: {
86
+ returns: FFIType.void,
87
+ args: [FFIType.ptr, FFIType.ptr, FFIType.ptr, FFIType.ptr, FFIType.u32],
88
+ },
89
+ sparse_gather: {
90
+ returns: FFIType.ptr,
91
+ args: [FFIType.ptr, FFIType.ptr, FFIType.ptr, FFIType.u32],
92
+ },
93
+ sparse_gather_bool: {
94
+ returns: FFIType.ptr,
95
+ args: [FFIType.ptr, FFIType.ptr, FFIType.ptr, FFIType.u32],
96
+ },
97
+
98
+ // Element-wise operations
99
+ sparse_add: {
100
+ returns: FFIType.ptr,
101
+ args: [FFIType.ptr, FFIType.ptr, FFIType.ptr],
102
+ },
103
+ sparse_add_scalar: {
104
+ returns: FFIType.ptr,
105
+ args: [FFIType.ptr, FFIType.ptr, FFIType.f32],
106
+ },
107
+ sparse_multiply: {
108
+ returns: FFIType.ptr,
109
+ args: [FFIType.ptr, FFIType.ptr, FFIType.ptr],
110
+ },
111
+ sparse_multiply_scalar: {
112
+ returns: FFIType.ptr,
113
+ args: [FFIType.ptr, FFIType.ptr, FFIType.f32],
114
+ },
115
+ sparse_square: {
116
+ returns: FFIType.ptr,
117
+ args: [FFIType.ptr, FFIType.ptr],
118
+ },
119
+
120
+ // Conditional operations
121
+ sparse_greater_equal: {
122
+ returns: FFIType.ptr,
123
+ args: [FFIType.ptr, FFIType.ptr, FFIType.f32],
124
+ },
125
+ sparse_where: {
126
+ returns: FFIType.ptr,
127
+ args: [FFIType.ptr, FFIType.ptr, FFIType.ptr, FFIType.ptr],
128
+ },
129
+ sparse_where_scalar: {
130
+ returns: FFIType.ptr,
131
+ args: [FFIType.ptr, FFIType.ptr, FFIType.f32, FFIType.f32],
132
+ },
133
+
134
+ // Reduction operations
135
+ sparse_sum: {
136
+ returns: FFIType.f32,
137
+ args: [FFIType.ptr, FFIType.ptr],
138
+ },
139
+ sparse_sum_bool: {
140
+ returns: FFIType.u32,
141
+ args: [FFIType.ptr, FFIType.ptr],
142
+ },
143
+
144
+ // Random operations
145
+ sparse_random_uniform: {
146
+ returns: FFIType.ptr,
147
+ args: [FFIType.ptr, FFIType.u32, FFIType.f32, FFIType.f32],
148
+ },
149
+ sparse_random_normal: {
150
+ returns: FFIType.ptr,
151
+ args: [FFIType.ptr, FFIType.u32, FFIType.f32, FFIType.f32],
152
+ },
153
+ })
154
+
155
// Raw FFI symbol table, exported for the rest of the package to call.
export const symbols = lib.symbols
// Type of the symbol table, for typing wrappers around these bindings.
export type Symbols = typeof symbols