@rlabs-inc/sparse 0.1.0
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- package/CLAUDE.md +92 -0
- package/libsparse.dylib +0 -0
- package/native/Makefile +30 -0
- package/native/libsparse.dylib +0 -0
- package/native/sparse.h +180 -0
- package/native/sparse.m +734 -0
- package/native/sparse.metal +215 -0
- package/package.json +38 -0
- package/src/ffi.ts +156 -0
- package/src/gpu.ts +382 -0
- package/src/index.ts +7 -0
- package/src/test-debug-spikes.ts +70 -0
- package/src/test-limits.ts +140 -0
- package/src/test-scatter-loop.ts +226 -0
- package/src/test-stress.ts +160 -0
- package/src/test-webgpu.ts +31 -0
package/native/sparse.m
ADDED
|
@@ -0,0 +1,734 @@
|
|
|
1
|
+
// ============================================================================
|
|
2
|
+
// SPARSE - Metal Implementation
|
|
3
|
+
// Objective-C bridge between C API and Metal compute
|
|
4
|
+
// ============================================================================
|
|
5
|
+
|
|
6
|
+
#import <Foundation/Foundation.h>
|
|
7
|
+
#import <Metal/Metal.h>
|
|
8
|
+
#include "sparse.h"
|
|
9
|
+
#include <stdlib.h>
|
|
10
|
+
|
|
11
|
+
// ============================================================================
|
|
12
|
+
// INTERNAL STRUCTURES
|
|
13
|
+
// ============================================================================
|
|
14
|
+
|
|
15
|
+
// Opaque per-device state handed to callers as SparseContextRef.
// Allocated with calloc() in sparse_init() and released with free() in
// sparse_cleanup(). The Metal object fields are Objective-C object pointers;
// the file's comments assume ARC manages them.
struct SparseContext {
    id<MTLDevice> device;               // system default device (MTLCreateSystemDefaultDevice)
    id<MTLCommandQueue> commandQueue;   // the single queue every operation submits to
    id<MTLLibrary> library;             // compiled at runtime from sparse.metal source

    // Compiled kernel functions (one per kernel name in sparse.metal).
    // A field stays nil if createPipeline() could not build that kernel.
    id<MTLComputePipelineState> scatterAddPipeline;
    id<MTLComputePipelineState> addArraysPipeline;
    id<MTLComputePipelineState> addScalarPipeline;
    id<MTLComputePipelineState> multiplyArraysPipeline;
    id<MTLComputePipelineState> multiplyScalarPipeline;
    id<MTLComputePipelineState> squarePipeline;
    id<MTLComputePipelineState> greaterEqualPipeline;
    id<MTLComputePipelineState> whereSelectPipeline;
    id<MTLComputePipelineState> whereScalarPipeline;
    id<MTLComputePipelineState> gatherPipeline;
    id<MTLComputePipelineState> gatherBoolPipeline;
    id<MTLComputePipelineState> fillFloatPipeline;
    id<MTLComputePipelineState> fillZerosPipeline;
    // NOTE(review): built in sparse_init() but not used by any operation in
    // this file; sparse_sum() reduces on the CPU.
    id<MTLComputePipelineState> sumReducePipeline;

    // NUL-terminated copy of device.name, returned by sparse_device_name().
    char deviceName[256];
};
|
|
38
|
+
|
|
39
|
+
// Opaque handle returned to callers as SparseBufferRef: one MTLBuffer plus its
// element count and element type. Always allocated with calloc() by the
// constructors in this file and released with free() in sparse_buffer_free().
struct SparseBuffer {
    id<MTLBuffer> buffer;   // created with MTLResourceStorageModeShared, so .contents is CPU-visible
    uint32_t count;         // number of elements, not bytes
    SparseDataType dtype;   // element type; constructors here use SPARSE_FLOAT32 or SPARSE_UINT32
    SparseContextRef ctx;   // context that created the buffer; not owned or freed by the buffer
};
|
|
45
|
+
|
|
46
|
+
// ============================================================================
|
|
47
|
+
// HELPER FUNCTIONS
|
|
48
|
+
// ============================================================================
|
|
49
|
+
|
|
50
|
+
/// Builds a compute pipeline for the kernel named `functionName` in ctx->library.
///
/// @param ctx          Context whose library and device are used.
/// @param functionName Name of the kernel function in the compiled library.
/// @return The pipeline state, or nil (after logging) if the function is not in
///         the library or pipeline creation fails.
static id<MTLComputePipelineState> createPipeline(SparseContextRef ctx, NSString* functionName) {
    id<MTLFunction> function = [ctx->library newFunctionWithName:functionName];
    if (!function) {
        NSLog(@"Failed to find function: %@", functionName);
        return nil;
    }

    NSError* error = nil;
    id<MTLComputePipelineState> pipeline = [ctx->device newComputePipelineStateWithFunction:function error:&error];
    // Cocoa convention: the returned object signals failure. `error` is only
    // meaningful when the return value is nil, so test the pipeline itself.
    if (!pipeline) {
        NSLog(@"Failed to create pipeline for %@: %@", functionName, error);
        return nil;
    }

    return pipeline;
}
|
|
66
|
+
|
|
67
|
+
/// Runs `pipeline` over a 1-D grid of `count` threads on a fresh command buffer
/// and blocks until the GPU has finished.
///
/// No buffers or bytes are bound here. NOTE(review): not called anywhere in this
/// file — each operation below encodes its own dispatch.
__attribute__((unused))
static void dispatchCompute(SparseContextRef ctx, id<MTLComputePipelineState> pipeline, uint32_t count) {
    // Threadgroup width: whatever the pipeline allows, but never above 256.
    NSUInteger width = MIN(pipeline.maxTotalThreadsPerThreadgroup, (NSUInteger)256);

    id<MTLCommandBuffer> cmd = [ctx->commandQueue commandBuffer];
    id<MTLComputeCommandEncoder> enc = [cmd computeCommandEncoder];
    [enc setComputePipelineState:pipeline];
    [enc dispatchThreads:MTLSizeMake(count, 1, 1)
   threadsPerThreadgroup:MTLSizeMake(width, 1, 1)];
    [enc endEncoding];

    [cmd commit];
    [cmd waitUntilCompleted];
}
|
|
85
|
+
|
|
86
|
+
// ============================================================================
|
|
87
|
+
// CONTEXT MANAGEMENT
|
|
88
|
+
// ============================================================================
|
|
89
|
+
|
|
90
|
+
// Drops the Metal objects a partially-initialised context holds, then frees it.
// The struct comes from calloc(), so free() alone would not release the strong
// (ARC-managed) object references stored in it.
static void destroyPartialContext(SparseContextRef ctx) {
    ctx->library = nil;
    ctx->commandQueue = nil;
    ctx->device = nil;
    free(ctx);
}

// Reads the kernel source. The search order is the main bundle's "sparse.metal"
// resource, then "native/sparse.metal" and "sparse.metal" relative to the
// process's current directory. Returns nil when none of them can be read.
static NSString* loadShaderSource(void) {
    NSString* shaderPath = [[NSBundle mainBundle] pathForResource:@"sparse" ofType:@"metal"];
    if (shaderPath) {
        NSString* source = [NSString stringWithContentsOfFile:shaderPath encoding:NSUTF8StringEncoding error:nil];
        if (source) return source;
    }

    NSString* currentDir = [[NSFileManager defaultManager] currentDirectoryPath];
    for (NSString* relativePath in @[@"native/sparse.metal", @"sparse.metal"]) {
        NSString* localPath = [currentDir stringByAppendingPathComponent:relativePath];
        NSString* source = [NSString stringWithContentsOfFile:localPath encoding:NSUTF8StringEncoding error:nil];
        if (source) return source;
    }
    return nil;
}

/// Creates a SPARSE context on the system default Metal device.
///
/// Compiles sparse.metal from source at runtime (see loadShaderSource for the
/// search order) and builds one compute pipeline per kernel. A kernel that
/// cannot be built is logged by createPipeline() and its field is left nil.
///
/// @return A new context to be released with sparse_cleanup(), or NULL if no
///         Metal device/command queue is available or the shader source cannot
///         be loaded or compiled.
SparseContextRef sparse_init(void) {
    @autoreleasepool {
        SparseContextRef ctx = (SparseContextRef)calloc(1, sizeof(struct SparseContext));
        if (!ctx) return NULL;

        // Get the default Metal device
        ctx->device = MTLCreateSystemDefaultDevice();
        if (!ctx->device) {
            NSLog(@"Metal is not supported on this device");
            free(ctx);
            return NULL;
        }

        // calloc zero-filled deviceName and at most size-1 bytes are copied, so
        // the stored name is always NUL-terminated. UTF8String may return NULL.
        const char* name = [ctx->device.name UTF8String];
        if (name) {
            strncpy(ctx->deviceName, name, sizeof(ctx->deviceName) - 1);
        }

        ctx->commandQueue = [ctx->device newCommandQueue];
        if (!ctx->commandQueue) {
            NSLog(@"Failed to create Metal command queue");
            destroyPartialContext(ctx);
            return NULL;
        }

        NSString* shaderSource = loadShaderSource();
        if (!shaderSource) {
            NSLog(@"Could not load sparse.metal shader file");
            destroyPartialContext(ctx);
            return NULL;
        }

        NSError* error = nil;
        MTLCompileOptions* options = [[MTLCompileOptions alloc] init];
        ctx->library = [ctx->device newLibraryWithSource:shaderSource options:options error:&error];

        // Judge success by the returned library, not by `error`: the compiler can
        // report warnings through `error` while still producing a usable library.
        if (!ctx->library) {
            NSLog(@"Failed to compile Metal shaders: %@", error);
            destroyPartialContext(ctx);
            return NULL;
        }
        if (error) {
            NSLog(@"Metal shader compilation warnings: %@", error);
        }

        // Create all compute pipelines
        ctx->scatterAddPipeline = createPipeline(ctx, @"scatter_add");
        ctx->addArraysPipeline = createPipeline(ctx, @"add_arrays");
        ctx->addScalarPipeline = createPipeline(ctx, @"add_scalar");
        ctx->multiplyArraysPipeline = createPipeline(ctx, @"multiply_arrays");
        ctx->multiplyScalarPipeline = createPipeline(ctx, @"multiply_scalar");
        ctx->squarePipeline = createPipeline(ctx, @"square");
        ctx->greaterEqualPipeline = createPipeline(ctx, @"greater_equal");
        ctx->whereSelectPipeline = createPipeline(ctx, @"where_select");
        ctx->whereScalarPipeline = createPipeline(ctx, @"where_scalar");
        ctx->gatherPipeline = createPipeline(ctx, @"gather");
        ctx->gatherBoolPipeline = createPipeline(ctx, @"gather_bool");
        ctx->fillFloatPipeline = createPipeline(ctx, @"fill_float");
        ctx->fillZerosPipeline = createPipeline(ctx, @"fill_zeros");
        ctx->sumReducePipeline = createPipeline(ctx, @"sum_reduce");

        return ctx;
    }
}
|
|
166
|
+
|
|
167
|
+
/// Releases a context created by sparse_init(). NULL is ignored.
///
/// Buffers created from this context keep a raw pointer to it (buf->ctx) and
/// must not be used with it afterwards.
void sparse_cleanup(SparseContextRef ctx) {
    if (!ctx) return;

    // The context is calloc'd, so free() never runs ARC's release for the
    // object pointers stored in it. Nil each one first so the Metal objects
    // are actually released instead of leaking.
    ctx->scatterAddPipeline = nil;
    ctx->addArraysPipeline = nil;
    ctx->addScalarPipeline = nil;
    ctx->multiplyArraysPipeline = nil;
    ctx->multiplyScalarPipeline = nil;
    ctx->squarePipeline = nil;
    ctx->greaterEqualPipeline = nil;
    ctx->whereSelectPipeline = nil;
    ctx->whereScalarPipeline = nil;
    ctx->gatherPipeline = nil;
    ctx->gatherBoolPipeline = nil;
    ctx->fillFloatPipeline = nil;
    ctx->fillZerosPipeline = nil;
    ctx->sumReducePipeline = nil;
    ctx->library = nil;
    ctx->commandQueue = nil;
    ctx->device = nil;

    free(ctx);
}
|
|
172
|
+
|
|
173
|
+
/// Blocks until an empty command buffer committed to ctx's queue has completed.
/// NULL is ignored.
///
/// Every operation in this file already waits for its own command buffer, so
/// this is mainly a barrier for callers.
void sparse_sync(SparseContextRef ctx) {
    if (!ctx) return;

    // Scope the autoreleased command buffer so it is drained here even when the
    // calling thread has no autorelease pool (matches the other operations).
    @autoreleasepool {
        id<MTLCommandBuffer> commandBuffer = [ctx->commandQueue commandBuffer];
        [commandBuffer commit];
        [commandBuffer waitUntilCompleted];
    }
}
|
|
180
|
+
|
|
181
|
+
/// Name of the Metal device captured by sparse_init(), or "Unknown" when `ctx`
/// is NULL. The returned storage lives inside the context and remains valid
/// until sparse_cleanup() frees it.
const char* sparse_device_name(SparseContextRef ctx) {
    if (ctx == NULL) {
        return "Unknown";
    }
    return ctx->deviceName;
}
|
|
184
|
+
|
|
185
|
+
/// The device's recommendedMaxWorkingSetSize in bytes, or 0 when `ctx` is NULL.
uint64_t sparse_device_memory(SparseContextRef ctx) {
    uint64_t bytes = 0;
    if (ctx != NULL) {
        bytes = ctx->device.recommendedMaxWorkingSetSize;
    }
    return bytes;
}
|
|
189
|
+
|
|
190
|
+
// ============================================================================
|
|
191
|
+
// BUFFER MANAGEMENT
|
|
192
|
+
// ============================================================================
|
|
193
|
+
|
|
194
|
+
/// Allocates a zero-filled shared-storage buffer of `count` elements.
///
/// @param dtype SPARSE_FLOAT32 uses sizeof(float) per element; any other value
///              uses sizeof(uint32_t).
/// @return A new buffer (release with sparse_buffer_free), or NULL if `ctx` is
///         NULL, `count` is 0, or host/GPU allocation fails.
SparseBufferRef sparse_zeros(SparseContextRef ctx, uint32_t count, SparseDataType dtype) {
    if (!ctx || count == 0) return NULL;

    size_t elementSize = (dtype == SPARSE_FLOAT32) ? sizeof(float) : sizeof(uint32_t);
    size_t bufferSize = (size_t)count * elementSize;

    SparseBufferRef buf = (SparseBufferRef)calloc(1, sizeof(struct SparseBuffer));
    if (!buf) return NULL;

    buf->buffer = [ctx->device newBufferWithLength:bufferSize options:MTLResourceStorageModeShared];
    if (!buf->buffer) {
        // Allocation failed; .contents would be NULL and memset would crash.
        free(buf);
        return NULL;
    }
    buf->count = count;
    buf->dtype = dtype;
    buf->ctx = ctx;

    // Zero the buffer (shared storage, so .contents is CPU-writable).
    memset(buf->buffer.contents, 0, bufferSize);

    return buf;
}
|
|
211
|
+
|
|
212
|
+
/// Creates a FLOAT32 buffer of `count` elements, all set to `value`, using the
/// `fill_float` kernel.
///
/// @return A new buffer, or NULL if `ctx` is NULL, `count` is 0, the fill
///         pipeline is unavailable, or allocation fails.
SparseBufferRef sparse_full(SparseContextRef ctx, uint32_t count, float value) {
    if (!ctx || count == 0) return NULL;

    id<MTLComputePipelineState> pipeline = ctx->fillFloatPipeline;
    if (!pipeline) return NULL;  // binding a nil pipeline would fail in Metal

    SparseBufferRef buf = sparse_zeros(ctx, count, SPARSE_FLOAT32);
    if (!buf) return NULL;

    // Scope the autoreleased command buffer/encoder so they are drained here even
    // when called from a thread without its own pool (matches the other ops).
    @autoreleasepool {
        id<MTLCommandBuffer> commandBuffer = [ctx->commandQueue commandBuffer];
        id<MTLComputeCommandEncoder> encoder = [commandBuffer computeCommandEncoder];

        [encoder setComputePipelineState:pipeline];
        [encoder setBuffer:buf->buffer offset:0 atIndex:0];
        [encoder setBytes:&value length:sizeof(float) atIndex:1];
        [encoder setBytes:&count length:sizeof(uint32_t) atIndex:2];

        // Never request a wider threadgroup than this pipeline supports.
        NSUInteger threadGroupSize = MIN((NSUInteger)256, pipeline.maxTotalThreadsPerThreadgroup);
        [encoder dispatchThreads:MTLSizeMake(count, 1, 1)
           threadsPerThreadgroup:MTLSizeMake(threadGroupSize, 1, 1)];
        [encoder endEncoding];
        [commandBuffer commit];
        [commandBuffer waitUntilCompleted];
    }

    return buf;
}
|
|
238
|
+
|
|
239
|
+
/// Creates a FLOAT32 shared-storage buffer initialised with a copy of
/// `data[0 .. count)`.
///
/// @return A new buffer, or NULL on NULL arguments, `count == 0`, or allocation
///         failure.
SparseBufferRef sparse_from_float(SparseContextRef ctx, const float* data, uint32_t count) {
    if (!ctx || !data || count == 0) return NULL;

    SparseBufferRef buf = (SparseBufferRef)calloc(1, sizeof(struct SparseBuffer));
    if (!buf) return NULL;

    buf->buffer = [ctx->device newBufferWithBytes:data
                                           length:(size_t)count * sizeof(float)
                                          options:MTLResourceStorageModeShared];
    if (!buf->buffer) {
        free(buf);
        return NULL;
    }
    buf->count = count;
    buf->dtype = SPARSE_FLOAT32;
    buf->ctx = ctx;

    return buf;
}
|
|
252
|
+
|
|
253
|
+
/// Creates a UINT32 shared-storage buffer initialised with a copy of
/// `data[0 .. count)`.
///
/// @return A new buffer, or NULL on NULL arguments, `count == 0`, or allocation
///         failure.
SparseBufferRef sparse_from_uint(SparseContextRef ctx, const uint32_t* data, uint32_t count) {
    if (!ctx || !data || count == 0) return NULL;

    SparseBufferRef buf = (SparseBufferRef)calloc(1, sizeof(struct SparseBuffer));
    if (!buf) return NULL;

    buf->buffer = [ctx->device newBufferWithBytes:data
                                           length:(size_t)count * sizeof(uint32_t)
                                          options:MTLResourceStorageModeShared];
    if (!buf->buffer) {
        free(buf);
        return NULL;
    }
    buf->count = count;
    buf->dtype = SPARSE_UINT32;
    buf->ctx = ctx;

    return buf;
}
|
|
266
|
+
|
|
267
|
+
/// Copies up to `count` floats from the buffer's CPU-visible contents into
/// `out`. Copies at most buf->count elements; NULL arguments are ignored.
void sparse_to_float(SparseBufferRef buf, float* out, uint32_t count) {
    if (buf == NULL || out == NULL) {
        return;
    }
    // Never read past the end of the buffer, whatever the caller asked for.
    const uint32_t n = MIN(count, buf->count);
    memcpy(out, buf->buffer.contents, (size_t)n * sizeof(float));
}
|
|
272
|
+
|
|
273
|
+
/// Copies up to `count` uint32 values from the buffer's CPU-visible contents
/// into `out`. Copies at most buf->count elements; NULL arguments are ignored.
void sparse_to_uint(SparseBufferRef buf, uint32_t* out, uint32_t count) {
    if (buf == NULL || out == NULL) {
        return;
    }
    // Clamp to the buffer's own length.
    const uint32_t n = MIN(count, buf->count);
    memcpy(out, buf->buffer.contents, (size_t)n * sizeof(uint32_t));
}
|
|
278
|
+
|
|
279
|
+
/// Number of elements in `buf`, or 0 for NULL.
uint32_t sparse_buffer_count(SparseBufferRef buf) {
    if (buf == NULL) {
        return 0;
    }
    return buf->count;
}
|
|
282
|
+
|
|
283
|
+
/// Element type of `buf`; SPARSE_FLOAT32 is reported for NULL.
SparseDataType sparse_buffer_dtype(SparseBufferRef buf) {
    if (buf == NULL) {
        return SPARSE_FLOAT32;
    }
    return buf->dtype;
}
|
|
286
|
+
|
|
287
|
+
/// Releases a buffer returned by any constructor/operation in this file.
/// NULL is ignored.
void sparse_buffer_free(SparseBufferRef buf) {
    if (!buf) return;
    // buf was calloc'd, so free() never runs ARC's release on the MTLBuffer it
    // points to. Nil the field first so the GPU allocation is released.
    buf->buffer = nil;
    buf->ctx = NULL;
    free(buf);
}
|
|
292
|
+
|
|
293
|
+
// ============================================================================
|
|
294
|
+
// CORE OPERATIONS
|
|
295
|
+
// ============================================================================
|
|
296
|
+
|
|
297
|
+
/// Runs the `scatter_add` kernel over `count` (index, value) pairs.
///
/// Binding layout: target (0), indices (1), values (2), count (3).
/// NOTE(review): presumably accumulates values[i] into target[indices[i]]; the
/// kernel source (sparse.metal) is not visible here. Index values are not
/// range-checked against target->count on the CPU side.
///
/// @param count Number of pairs to process. It must not exceed indices->count
///              or values->count; otherwise the call logs and does nothing.
void sparse_scatter_add(
    SparseContextRef ctx,
    SparseBufferRef target,
    SparseBufferRef indices,
    SparseBufferRef values,
    uint32_t count
) {
    if (!ctx || !target || !indices || !values || count == 0) return;

    id<MTLComputePipelineState> pipeline = ctx->scatterAddPipeline;
    if (!pipeline) return;

    // The dispatch covers `count` threads over the whole indices/values buffers;
    // a count beyond either buffer's length would let the kernel read past its end.
    if (count > indices->count || count > values->count) {
        NSLog(@"sparse_scatter_add: count %u exceeds indices (%u) or values (%u)",
              count, indices->count, values->count);
        return;
    }

    @autoreleasepool {
        id<MTLCommandBuffer> commandBuffer = [ctx->commandQueue commandBuffer];
        id<MTLComputeCommandEncoder> encoder = [commandBuffer computeCommandEncoder];

        [encoder setComputePipelineState:pipeline];
        [encoder setBuffer:target->buffer offset:0 atIndex:0];
        [encoder setBuffer:indices->buffer offset:0 atIndex:1];
        [encoder setBuffer:values->buffer offset:0 atIndex:2];
        [encoder setBytes:&count length:sizeof(uint32_t) atIndex:3];

        // Never request a wider threadgroup than this pipeline supports.
        NSUInteger threadGroupSize = MIN((NSUInteger)256, pipeline.maxTotalThreadsPerThreadgroup);
        [encoder dispatchThreads:MTLSizeMake(count, 1, 1)
           threadsPerThreadgroup:MTLSizeMake(threadGroupSize, 1, 1)];
        [encoder endEncoding];
        [commandBuffer commit];
        [commandBuffer waitUntilCompleted];
    }
}
|
|
326
|
+
|
|
327
|
+
/// Runs the `gather` kernel and returns a new FLOAT32 buffer of `count` elements.
///
/// Binding layout: source (0), indices (1), result (2), count (3).
/// NOTE(review): presumably result[i] = source[indices[i]]; index values are not
/// range-checked against source->count on the CPU side.
///
/// @return The result buffer, or NULL on NULL arguments, `count == 0`,
///         `count > indices->count`, a missing pipeline, or allocation failure.
SparseBufferRef sparse_gather(
    SparseContextRef ctx,
    SparseBufferRef source,
    SparseBufferRef indices,
    uint32_t count
) {
    if (!ctx || !source || !indices || count == 0) return NULL;

    id<MTLComputePipelineState> pipeline = ctx->gatherPipeline;
    if (!pipeline) return NULL;

    // The dispatch covers `count` threads over the indices buffer; a larger count
    // would let the kernel read past its end.
    if (count > indices->count) {
        NSLog(@"sparse_gather: count %u exceeds indices length %u", count, indices->count);
        return NULL;
    }

    SparseBufferRef result = sparse_zeros(ctx, count, SPARSE_FLOAT32);
    if (!result) return NULL;

    @autoreleasepool {
        id<MTLCommandBuffer> commandBuffer = [ctx->commandQueue commandBuffer];
        id<MTLComputeCommandEncoder> encoder = [commandBuffer computeCommandEncoder];

        [encoder setComputePipelineState:pipeline];
        [encoder setBuffer:source->buffer offset:0 atIndex:0];
        [encoder setBuffer:indices->buffer offset:0 atIndex:1];
        [encoder setBuffer:result->buffer offset:0 atIndex:2];
        [encoder setBytes:&count length:sizeof(uint32_t) atIndex:3];

        // Never request a wider threadgroup than this pipeline supports.
        NSUInteger threadGroupSize = MIN((NSUInteger)256, pipeline.maxTotalThreadsPerThreadgroup);
        [encoder dispatchThreads:MTLSizeMake(count, 1, 1)
           threadsPerThreadgroup:MTLSizeMake(threadGroupSize, 1, 1)];
        [encoder endEncoding];
        [commandBuffer commit];
        [commandBuffer waitUntilCompleted];
    }

    return result;
}
|
|
359
|
+
|
|
360
|
+
/// Runs the `gather_bool` kernel and returns a new UINT32 buffer of `count`
/// elements.
///
/// Binding layout: source (0), indices (1), result (2), count (3).
/// NOTE(review): presumably result[i] = source[indices[i]] for a UINT32 flag
/// source; index values are not range-checked against source->count here.
///
/// @return The result buffer, or NULL on NULL arguments, `count == 0`,
///         `count > indices->count`, a missing pipeline, or allocation failure.
SparseBufferRef sparse_gather_bool(
    SparseContextRef ctx,
    SparseBufferRef source,
    SparseBufferRef indices,
    uint32_t count
) {
    if (!ctx || !source || !indices || count == 0) return NULL;

    id<MTLComputePipelineState> pipeline = ctx->gatherBoolPipeline;
    if (!pipeline) return NULL;

    // The dispatch covers `count` threads over the indices buffer; a larger count
    // would let the kernel read past its end.
    if (count > indices->count) {
        NSLog(@"sparse_gather_bool: count %u exceeds indices length %u", count, indices->count);
        return NULL;
    }

    SparseBufferRef result = sparse_zeros(ctx, count, SPARSE_UINT32);
    if (!result) return NULL;

    @autoreleasepool {
        id<MTLCommandBuffer> commandBuffer = [ctx->commandQueue commandBuffer];
        id<MTLComputeCommandEncoder> encoder = [commandBuffer computeCommandEncoder];

        [encoder setComputePipelineState:pipeline];
        [encoder setBuffer:source->buffer offset:0 atIndex:0];
        [encoder setBuffer:indices->buffer offset:0 atIndex:1];
        [encoder setBuffer:result->buffer offset:0 atIndex:2];
        [encoder setBytes:&count length:sizeof(uint32_t) atIndex:3];

        // Never request a wider threadgroup than this pipeline supports.
        NSUInteger threadGroupSize = MIN((NSUInteger)256, pipeline.maxTotalThreadsPerThreadgroup);
        [encoder dispatchThreads:MTLSizeMake(count, 1, 1)
           threadsPerThreadgroup:MTLSizeMake(threadGroupSize, 1, 1)];
        [encoder endEncoding];
        [commandBuffer commit];
        [commandBuffer waitUntilCompleted];
    }

    return result;
}
|
|
392
|
+
|
|
393
|
+
// ============================================================================
|
|
394
|
+
// ELEMENT-WISE OPERATIONS
|
|
395
|
+
// ============================================================================
|
|
396
|
+
|
|
397
|
+
/// Element-wise a + b via the `add_arrays` kernel.
///
/// Binding layout: a (0), b (1), result (2), count (3). The result has a->count
/// elements.
///
/// @return A new FLOAT32 buffer, or NULL on NULL arguments, b shorter than a,
///         a missing pipeline, or allocation failure.
SparseBufferRef sparse_add(SparseContextRef ctx, SparseBufferRef a, SparseBufferRef b) {
    if (!ctx || !a || !b) return NULL;

    id<MTLComputePipelineState> pipeline = ctx->addArraysPipeline;
    if (!pipeline) return NULL;

    // The dispatch is sized by a->count; a shorter b would be read past its end.
    if (b->count < a->count) {
        NSLog(@"sparse_add: operand length mismatch (%u vs %u)", a->count, b->count);
        return NULL;
    }
    uint32_t count = a->count;

    SparseBufferRef result = sparse_zeros(ctx, count, SPARSE_FLOAT32);
    if (!result) return NULL;

    @autoreleasepool {
        id<MTLCommandBuffer> commandBuffer = [ctx->commandQueue commandBuffer];
        id<MTLComputeCommandEncoder> encoder = [commandBuffer computeCommandEncoder];

        [encoder setComputePipelineState:pipeline];
        [encoder setBuffer:a->buffer offset:0 atIndex:0];
        [encoder setBuffer:b->buffer offset:0 atIndex:1];
        [encoder setBuffer:result->buffer offset:0 atIndex:2];
        [encoder setBytes:&count length:sizeof(uint32_t) atIndex:3];

        // Never request a wider threadgroup than this pipeline supports.
        NSUInteger threadGroupSize = MIN((NSUInteger)256, pipeline.maxTotalThreadsPerThreadgroup);
        [encoder dispatchThreads:MTLSizeMake(count, 1, 1)
           threadsPerThreadgroup:MTLSizeMake(threadGroupSize, 1, 1)];
        [encoder endEncoding];
        [commandBuffer commit];
        [commandBuffer waitUntilCompleted];
    }

    return result;
}
|
|
425
|
+
|
|
426
|
+
/// Element-wise a + scalar via the `add_scalar` kernel.
///
/// Binding layout: a (0), scalar (1), result (2), count (3). The result has
/// a->count elements.
///
/// @return A new FLOAT32 buffer, or NULL on NULL arguments, a missing pipeline,
///         or allocation failure.
SparseBufferRef sparse_add_scalar(SparseContextRef ctx, SparseBufferRef a, float scalar) {
    if (!ctx || !a) return NULL;

    id<MTLComputePipelineState> pipeline = ctx->addScalarPipeline;
    if (!pipeline) return NULL;

    uint32_t count = a->count;
    SparseBufferRef result = sparse_zeros(ctx, count, SPARSE_FLOAT32);
    if (!result) return NULL;

    @autoreleasepool {
        id<MTLCommandBuffer> commandBuffer = [ctx->commandQueue commandBuffer];
        id<MTLComputeCommandEncoder> encoder = [commandBuffer computeCommandEncoder];

        [encoder setComputePipelineState:pipeline];
        [encoder setBuffer:a->buffer offset:0 atIndex:0];
        [encoder setBytes:&scalar length:sizeof(float) atIndex:1];
        [encoder setBuffer:result->buffer offset:0 atIndex:2];
        [encoder setBytes:&count length:sizeof(uint32_t) atIndex:3];

        // Never request a wider threadgroup than this pipeline supports.
        NSUInteger threadGroupSize = MIN((NSUInteger)256, pipeline.maxTotalThreadsPerThreadgroup);
        [encoder dispatchThreads:MTLSizeMake(count, 1, 1)
           threadsPerThreadgroup:MTLSizeMake(threadGroupSize, 1, 1)];
        [encoder endEncoding];
        [commandBuffer commit];
        [commandBuffer waitUntilCompleted];
    }

    return result;
}
|
|
454
|
+
|
|
455
|
+
/// Element-wise a * b via the `multiply_arrays` kernel.
///
/// Binding layout: a (0), b (1), result (2), count (3). The result has a->count
/// elements.
///
/// @return A new FLOAT32 buffer, or NULL on NULL arguments, b shorter than a,
///         a missing pipeline, or allocation failure.
SparseBufferRef sparse_multiply(SparseContextRef ctx, SparseBufferRef a, SparseBufferRef b) {
    if (!ctx || !a || !b) return NULL;

    id<MTLComputePipelineState> pipeline = ctx->multiplyArraysPipeline;
    if (!pipeline) return NULL;

    // The dispatch is sized by a->count; a shorter b would be read past its end.
    if (b->count < a->count) {
        NSLog(@"sparse_multiply: operand length mismatch (%u vs %u)", a->count, b->count);
        return NULL;
    }
    uint32_t count = a->count;

    SparseBufferRef result = sparse_zeros(ctx, count, SPARSE_FLOAT32);
    if (!result) return NULL;

    @autoreleasepool {
        id<MTLCommandBuffer> commandBuffer = [ctx->commandQueue commandBuffer];
        id<MTLComputeCommandEncoder> encoder = [commandBuffer computeCommandEncoder];

        [encoder setComputePipelineState:pipeline];
        [encoder setBuffer:a->buffer offset:0 atIndex:0];
        [encoder setBuffer:b->buffer offset:0 atIndex:1];
        [encoder setBuffer:result->buffer offset:0 atIndex:2];
        [encoder setBytes:&count length:sizeof(uint32_t) atIndex:3];

        // Never request a wider threadgroup than this pipeline supports.
        NSUInteger threadGroupSize = MIN((NSUInteger)256, pipeline.maxTotalThreadsPerThreadgroup);
        [encoder dispatchThreads:MTLSizeMake(count, 1, 1)
           threadsPerThreadgroup:MTLSizeMake(threadGroupSize, 1, 1)];
        [encoder endEncoding];
        [commandBuffer commit];
        [commandBuffer waitUntilCompleted];
    }

    return result;
}
|
|
483
|
+
|
|
484
|
+
/// Element-wise a * scalar via the `multiply_scalar` kernel.
///
/// Binding layout: a (0), scalar (1), result (2), count (3). The result has
/// a->count elements.
///
/// @return A new FLOAT32 buffer, or NULL on NULL arguments, a missing pipeline,
///         or allocation failure.
SparseBufferRef sparse_multiply_scalar(SparseContextRef ctx, SparseBufferRef a, float scalar) {
    if (!ctx || !a) return NULL;

    id<MTLComputePipelineState> pipeline = ctx->multiplyScalarPipeline;
    if (!pipeline) return NULL;

    uint32_t count = a->count;
    SparseBufferRef result = sparse_zeros(ctx, count, SPARSE_FLOAT32);
    if (!result) return NULL;

    @autoreleasepool {
        id<MTLCommandBuffer> commandBuffer = [ctx->commandQueue commandBuffer];
        id<MTLComputeCommandEncoder> encoder = [commandBuffer computeCommandEncoder];

        [encoder setComputePipelineState:pipeline];
        [encoder setBuffer:a->buffer offset:0 atIndex:0];
        [encoder setBytes:&scalar length:sizeof(float) atIndex:1];
        [encoder setBuffer:result->buffer offset:0 atIndex:2];
        [encoder setBytes:&count length:sizeof(uint32_t) atIndex:3];

        // Never request a wider threadgroup than this pipeline supports.
        NSUInteger threadGroupSize = MIN((NSUInteger)256, pipeline.maxTotalThreadsPerThreadgroup);
        [encoder dispatchThreads:MTLSizeMake(count, 1, 1)
           threadsPerThreadgroup:MTLSizeMake(threadGroupSize, 1, 1)];
        [encoder endEncoding];
        [commandBuffer commit];
        [commandBuffer waitUntilCompleted];
    }

    return result;
}
|
|
512
|
+
|
|
513
|
+
/// Element-wise square of `a` via the `square` kernel.
///
/// Binding layout: a (0), result (1), count (2). The result has a->count
/// elements.
///
/// @return A new FLOAT32 buffer, or NULL on NULL arguments, a missing pipeline,
///         or allocation failure.
SparseBufferRef sparse_square(SparseContextRef ctx, SparseBufferRef a) {
    if (!ctx || !a) return NULL;

    id<MTLComputePipelineState> pipeline = ctx->squarePipeline;
    if (!pipeline) return NULL;

    uint32_t count = a->count;
    SparseBufferRef result = sparse_zeros(ctx, count, SPARSE_FLOAT32);
    if (!result) return NULL;

    @autoreleasepool {
        id<MTLCommandBuffer> commandBuffer = [ctx->commandQueue commandBuffer];
        id<MTLComputeCommandEncoder> encoder = [commandBuffer computeCommandEncoder];

        [encoder setComputePipelineState:pipeline];
        [encoder setBuffer:a->buffer offset:0 atIndex:0];
        [encoder setBuffer:result->buffer offset:0 atIndex:1];
        [encoder setBytes:&count length:sizeof(uint32_t) atIndex:2];

        // Never request a wider threadgroup than this pipeline supports.
        NSUInteger threadGroupSize = MIN((NSUInteger)256, pipeline.maxTotalThreadsPerThreadgroup);
        [encoder dispatchThreads:MTLSizeMake(count, 1, 1)
           threadsPerThreadgroup:MTLSizeMake(threadGroupSize, 1, 1)];
        [encoder endEncoding];
        [commandBuffer commit];
        [commandBuffer waitUntilCompleted];
    }

    return result;
}
|
|
540
|
+
|
|
541
|
+
// ============================================================================
|
|
542
|
+
// CONDITIONAL OPERATIONS
|
|
543
|
+
// ============================================================================
|
|
544
|
+
|
|
545
|
+
/// Compares each element of `a` against `threshold` via the `greater_equal`
/// kernel.
///
/// Binding layout: a (0), threshold (1), result (2), count (3). The result is a
/// UINT32 buffer of a->count elements (NOTE(review): presumably 1 where
/// a[i] >= threshold, else 0).
///
/// @return The result buffer, or NULL on NULL arguments, a missing pipeline, or
///         allocation failure.
SparseBufferRef sparse_greater_equal(SparseContextRef ctx, SparseBufferRef a, float threshold) {
    if (!ctx || !a) return NULL;

    id<MTLComputePipelineState> pipeline = ctx->greaterEqualPipeline;
    if (!pipeline) return NULL;

    uint32_t count = a->count;
    SparseBufferRef result = sparse_zeros(ctx, count, SPARSE_UINT32);
    if (!result) return NULL;

    @autoreleasepool {
        id<MTLCommandBuffer> commandBuffer = [ctx->commandQueue commandBuffer];
        id<MTLComputeCommandEncoder> encoder = [commandBuffer computeCommandEncoder];

        [encoder setComputePipelineState:pipeline];
        [encoder setBuffer:a->buffer offset:0 atIndex:0];
        [encoder setBytes:&threshold length:sizeof(float) atIndex:1];
        [encoder setBuffer:result->buffer offset:0 atIndex:2];
        [encoder setBytes:&count length:sizeof(uint32_t) atIndex:3];

        // Never request a wider threadgroup than this pipeline supports.
        NSUInteger threadGroupSize = MIN((NSUInteger)256, pipeline.maxTotalThreadsPerThreadgroup);
        [encoder dispatchThreads:MTLSizeMake(count, 1, 1)
           threadsPerThreadgroup:MTLSizeMake(threadGroupSize, 1, 1)];
        [encoder endEncoding];
        [commandBuffer commit];
        [commandBuffer waitUntilCompleted];
    }

    return result;
}
|
|
573
|
+
|
|
574
|
+
/// Element-wise select via the `where_select` kernel.
///
/// Binding layout: condition (0), if_true (1), if_false (2), result (3),
/// count (4). The result is FLOAT32 with condition->count elements.
/// NOTE(review): presumably result[i] = condition[i] ? if_true[i] : if_false[i].
///
/// @return The result buffer, or NULL on NULL arguments, if_true/if_false
///         shorter than condition, a missing pipeline, or allocation failure.
SparseBufferRef sparse_where(
    SparseContextRef ctx,
    SparseBufferRef condition,
    SparseBufferRef if_true,
    SparseBufferRef if_false
) {
    if (!ctx || !condition || !if_true || !if_false) return NULL;

    id<MTLComputePipelineState> pipeline = ctx->whereSelectPipeline;
    if (!pipeline) return NULL;

    uint32_t count = condition->count;
    // The dispatch is sized by condition->count; shorter operand buffers would
    // be read past their end.
    if (if_true->count < count || if_false->count < count) {
        NSLog(@"sparse_where: operand length mismatch (condition %u, if_true %u, if_false %u)",
              count, if_true->count, if_false->count);
        return NULL;
    }

    SparseBufferRef result = sparse_zeros(ctx, count, SPARSE_FLOAT32);
    if (!result) return NULL;

    @autoreleasepool {
        id<MTLCommandBuffer> commandBuffer = [ctx->commandQueue commandBuffer];
        id<MTLComputeCommandEncoder> encoder = [commandBuffer computeCommandEncoder];

        [encoder setComputePipelineState:pipeline];
        [encoder setBuffer:condition->buffer offset:0 atIndex:0];
        [encoder setBuffer:if_true->buffer offset:0 atIndex:1];
        [encoder setBuffer:if_false->buffer offset:0 atIndex:2];
        [encoder setBuffer:result->buffer offset:0 atIndex:3];
        [encoder setBytes:&count length:sizeof(uint32_t) atIndex:4];

        // Never request a wider threadgroup than this pipeline supports.
        NSUInteger threadGroupSize = MIN((NSUInteger)256, pipeline.maxTotalThreadsPerThreadgroup);
        [encoder dispatchThreads:MTLSizeMake(count, 1, 1)
           threadsPerThreadgroup:MTLSizeMake(threadGroupSize, 1, 1)];
        [encoder endEncoding];
        [commandBuffer commit];
        [commandBuffer waitUntilCompleted];
    }

    return result;
}
|
|
608
|
+
|
|
609
|
+
/// Element-wise select between two constants via the `where_scalar` kernel.
///
/// Binding layout: condition (0), if_true (1), if_false (2), result (3),
/// count (4). The result is FLOAT32 with condition->count elements.
/// NOTE(review): presumably result[i] = condition[i] ? if_true : if_false.
///
/// @return The result buffer, or NULL on NULL arguments, a missing pipeline, or
///         allocation failure.
SparseBufferRef sparse_where_scalar(
    SparseContextRef ctx,
    SparseBufferRef condition,
    float if_true,
    float if_false
) {
    if (!ctx || !condition) return NULL;

    id<MTLComputePipelineState> pipeline = ctx->whereScalarPipeline;
    if (!pipeline) return NULL;

    uint32_t count = condition->count;
    SparseBufferRef result = sparse_zeros(ctx, count, SPARSE_FLOAT32);
    if (!result) return NULL;

    @autoreleasepool {
        id<MTLCommandBuffer> commandBuffer = [ctx->commandQueue commandBuffer];
        id<MTLComputeCommandEncoder> encoder = [commandBuffer computeCommandEncoder];

        [encoder setComputePipelineState:pipeline];
        [encoder setBuffer:condition->buffer offset:0 atIndex:0];
        [encoder setBytes:&if_true length:sizeof(float) atIndex:1];
        [encoder setBytes:&if_false length:sizeof(float) atIndex:2];
        [encoder setBuffer:result->buffer offset:0 atIndex:3];
        [encoder setBytes:&count length:sizeof(uint32_t) atIndex:4];

        // Never request a wider threadgroup than this pipeline supports.
        NSUInteger threadGroupSize = MIN((NSUInteger)256, pipeline.maxTotalThreadsPerThreadgroup);
        [encoder dispatchThreads:MTLSizeMake(count, 1, 1)
           threadsPerThreadgroup:MTLSizeMake(threadGroupSize, 1, 1)];
        [encoder endEncoding];
        [commandBuffer commit];
        [commandBuffer waitUntilCompleted];
    }

    return result;
}
|
|
643
|
+
|
|
644
|
+
// ============================================================================
|
|
645
|
+
// REDUCTION OPERATIONS
|
|
646
|
+
// ============================================================================
|
|
647
|
+
|
|
648
|
+
/// Sum of all elements of `a`, interpreted as floats. Returns 0 for NULL input.
///
/// Reduces on the CPU. The buffers in this file use shared storage, so the
/// contents are read in place rather than copied into a temporary allocation.
/// TODO: Use GPU reduction for large arrays.
float sparse_sum(SparseContextRef ctx, SparseBufferRef a) {
    if (!ctx || !a) return 0.0f;

    const float* values = (const float*)a->buffer.contents;
    if (!values) return 0.0f;

    // Accumulate in double: a float accumulator stops absorbing small addends
    // once the running total is large (e.g. adding 1.0f stalls at 2^24).
    double total = 0.0;
    for (uint32_t i = 0; i < a->count; i++) {
        total += values[i];
    }
    return (float)total;
}
|
|
664
|
+
|
|
665
|
+
/// Sum of all elements of `a`, interpreted as uint32 (e.g. a count of set
/// flags). Returns 0 for NULL input. Overflow wraps modulo 2^32.
///
/// Reads the shared-storage contents in place on the CPU instead of copying
/// them into a temporary malloc'd array, which the original never NULL-checked.
uint32_t sparse_sum_bool(SparseContextRef ctx, SparseBufferRef a) {
    if (!ctx || !a) return 0;

    const uint32_t* values = (const uint32_t*)a->buffer.contents;
    if (!values) return 0;

    uint32_t sum = 0;
    for (uint32_t i = 0; i < a->count; i++) {
        sum += values[i];
    }
    return sum;
}
|
|
679
|
+
|
|
680
|
+
// ============================================================================
|
|
681
|
+
// RANDOM OPERATIONS
|
|
682
|
+
// ============================================================================
|
|
683
|
+
|
|
684
|
+
/// Creates a FLOAT32 buffer of `count` samples drawn uniformly from
/// [low, high), generated on the CPU with arc4random().
///
/// Final float rounding of `low + u * (high - low)` can still land on `high`
/// for some ranges.
///
/// @return A new buffer, or NULL if `ctx` is NULL, `count` is 0, or allocation
///         fails.
SparseBufferRef sparse_random_uniform(
    SparseContextRef ctx,
    uint32_t count,
    float low,
    float high
) {
    if (!ctx || count == 0) return NULL;

    // Generate on CPU (Metal random is complex)
    float* data = (float*)malloc((size_t)count * sizeof(float));
    if (!data) return NULL;

    const float range = high - low;
    for (uint32_t i = 0; i < count; i++) {
        // Keep the top 24 bits, which a float represents exactly, giving a unit
        // sample in [0, 1). The previous (float)arc4random() / UINT32_MAX could
        // produce exactly 1.0 because (float)UINT32_MAX rounds up to 2^32.
        float unit = (float)(arc4random() >> 8) * (1.0f / 16777216.0f);
        data[i] = low + unit * range;
    }

    SparseBufferRef result = sparse_from_float(ctx, data, count);
    free(data);

    return result;
}
|
|
705
|
+
|
|
706
|
+
/// Creates a FLOAT32 buffer of `count` normally-distributed samples with the
/// given mean and standard deviation, using the Box-Muller transform on the CPU.
///
/// @return A new buffer, or NULL if `ctx` is NULL, `count` is 0, or allocation
///         fails.
SparseBufferRef sparse_random_normal(
    SparseContextRef ctx,
    uint32_t count,
    float mean,
    float std
) {
    if (!ctx || count == 0) return NULL;

    float* data = (float*)malloc((size_t)count * sizeof(float));
    if (!data) return NULL;

    // Each Box-Muller step turns two uniforms into two independent normals.
    for (uint32_t i = 0; i < count; i += 2) {
        // +1 keeps u1 strictly positive so logf(u1) is finite.
        float u1 = ((float)arc4random() + 1.0f) / ((float)UINT32_MAX + 1.0f);
        float u2 = ((float)arc4random() + 1.0f) / ((float)UINT32_MAX + 1.0f);

        float radius = sqrtf(-2.0f * logf(u1));
        float theta = (float)(2.0f * M_PI * u2);

        data[i] = mean + radius * cosf(theta) * std;
        if (i + 1 < count) {
            data[i + 1] = mean + radius * sinf(theta) * std;
        }
    }

    SparseBufferRef result = sparse_from_float(ctx, data, count);
    free(data);

    return result;
}
|