whisper.rn 0.5.0-rc.9 → 0.5.0
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- package/cpp/ggml-alloc.c +1 -15
- package/cpp/ggml-backend-reg.cpp +17 -8
- package/cpp/ggml-backend.cpp +15 -22
- package/cpp/ggml-common.h +17 -0
- package/cpp/ggml-cpu/arch/arm/quants.c +132 -596
- package/cpp/ggml-cpu/arch/arm/repack.cpp +14 -286
- package/cpp/ggml-cpu/arch/x86/quants.c +184 -675
- package/cpp/ggml-cpu/arch/x86/repack.cpp +4679 -1657
- package/cpp/ggml-cpu/arch-fallback.h +34 -0
- package/cpp/ggml-cpu/ggml-cpu.c +22 -1
- package/cpp/ggml-cpu/ggml-cpu.cpp +21 -24
- package/cpp/ggml-cpu/ops.cpp +870 -211
- package/cpp/ggml-cpu/ops.h +3 -8
- package/cpp/ggml-cpu/quants.c +35 -0
- package/cpp/ggml-cpu/quants.h +8 -0
- package/cpp/ggml-cpu/repack.cpp +458 -47
- package/cpp/ggml-cpu/repack.h +22 -0
- package/cpp/ggml-cpu/simd-mappings.h +1 -1
- package/cpp/ggml-cpu/traits.cpp +2 -2
- package/cpp/ggml-cpu/traits.h +1 -1
- package/cpp/ggml-cpu/vec.cpp +12 -9
- package/cpp/ggml-cpu/vec.h +107 -13
- package/cpp/ggml-impl.h +77 -0
- package/cpp/ggml-metal-impl.h +51 -12
- package/cpp/ggml-metal.m +610 -115
- package/cpp/ggml-opt.cpp +97 -41
- package/cpp/ggml-opt.h +25 -6
- package/cpp/ggml-quants.c +110 -16
- package/cpp/ggml-quants.h +6 -0
- package/cpp/ggml-whisper-sim.metallib +0 -0
- package/cpp/ggml-whisper.metallib +0 -0
- package/cpp/ggml.c +314 -88
- package/cpp/ggml.h +137 -11
- package/cpp/gguf.cpp +8 -1
- package/cpp/jsi/RNWhisperJSI.cpp +23 -6
- package/cpp/whisper.cpp +15 -6
- package/ios/RNWhisper.mm +6 -6
- package/ios/RNWhisperContext.mm +2 -0
- package/ios/RNWhisperVadContext.mm +2 -0
- package/ios/rnwhisper.xcframework/ios-arm64/rnwhisper.framework/Headers/ggml-common.h +17 -0
- package/ios/rnwhisper.xcframework/ios-arm64/rnwhisper.framework/Headers/ggml-impl.h +77 -0
- package/ios/rnwhisper.xcframework/ios-arm64/rnwhisper.framework/Headers/ggml-metal-impl.h +51 -12
- package/ios/rnwhisper.xcframework/ios-arm64/rnwhisper.framework/Headers/ggml-opt.h +25 -6
- package/ios/rnwhisper.xcframework/ios-arm64/rnwhisper.framework/Headers/ggml-quants.h +6 -0
- package/ios/rnwhisper.xcframework/ios-arm64/rnwhisper.framework/Headers/ggml.h +137 -11
- package/ios/rnwhisper.xcframework/ios-arm64/rnwhisper.framework/ggml-whisper.metallib +0 -0
- package/ios/rnwhisper.xcframework/ios-arm64/rnwhisper.framework/rnwhisper +0 -0
- package/ios/rnwhisper.xcframework/ios-arm64_x86_64-simulator/rnwhisper.framework/Headers/ggml-common.h +17 -0
- package/ios/rnwhisper.xcframework/ios-arm64_x86_64-simulator/rnwhisper.framework/Headers/ggml-impl.h +77 -0
- package/ios/rnwhisper.xcframework/ios-arm64_x86_64-simulator/rnwhisper.framework/Headers/ggml-metal-impl.h +51 -12
- package/ios/rnwhisper.xcframework/ios-arm64_x86_64-simulator/rnwhisper.framework/Headers/ggml-opt.h +25 -6
- package/ios/rnwhisper.xcframework/ios-arm64_x86_64-simulator/rnwhisper.framework/Headers/ggml-quants.h +6 -0
- package/ios/rnwhisper.xcframework/ios-arm64_x86_64-simulator/rnwhisper.framework/Headers/ggml.h +137 -11
- package/ios/rnwhisper.xcframework/ios-arm64_x86_64-simulator/rnwhisper.framework/ggml-whisper-sim.metallib +0 -0
- package/ios/rnwhisper.xcframework/ios-arm64_x86_64-simulator/rnwhisper.framework/rnwhisper +0 -0
- package/ios/rnwhisper.xcframework/tvos-arm64/rnwhisper.framework/Headers/ggml-common.h +17 -0
- package/ios/rnwhisper.xcframework/tvos-arm64/rnwhisper.framework/Headers/ggml-impl.h +77 -0
- package/ios/rnwhisper.xcframework/tvos-arm64/rnwhisper.framework/Headers/ggml-metal-impl.h +51 -12
- package/ios/rnwhisper.xcframework/tvos-arm64/rnwhisper.framework/Headers/ggml-opt.h +25 -6
- package/ios/rnwhisper.xcframework/tvos-arm64/rnwhisper.framework/Headers/ggml-quants.h +6 -0
- package/ios/rnwhisper.xcframework/tvos-arm64/rnwhisper.framework/Headers/ggml.h +137 -11
- package/ios/rnwhisper.xcframework/tvos-arm64/rnwhisper.framework/ggml-whisper.metallib +0 -0
- package/ios/rnwhisper.xcframework/tvos-arm64/rnwhisper.framework/rnwhisper +0 -0
- package/ios/rnwhisper.xcframework/tvos-arm64_x86_64-simulator/rnwhisper.framework/Headers/ggml-common.h +17 -0
- package/ios/rnwhisper.xcframework/tvos-arm64_x86_64-simulator/rnwhisper.framework/Headers/ggml-impl.h +77 -0
- package/ios/rnwhisper.xcframework/tvos-arm64_x86_64-simulator/rnwhisper.framework/Headers/ggml-metal-impl.h +51 -12
- package/ios/rnwhisper.xcframework/tvos-arm64_x86_64-simulator/rnwhisper.framework/Headers/ggml-opt.h +25 -6
- package/ios/rnwhisper.xcframework/tvos-arm64_x86_64-simulator/rnwhisper.framework/Headers/ggml-quants.h +6 -0
- package/ios/rnwhisper.xcframework/tvos-arm64_x86_64-simulator/rnwhisper.framework/Headers/ggml.h +137 -11
- package/ios/rnwhisper.xcframework/tvos-arm64_x86_64-simulator/rnwhisper.framework/ggml-whisper-sim.metallib +0 -0
- package/ios/rnwhisper.xcframework/tvos-arm64_x86_64-simulator/rnwhisper.framework/rnwhisper +0 -0
- package/lib/commonjs/realtime-transcription/RealtimeTranscriber.js +13 -0
- package/lib/commonjs/realtime-transcription/RealtimeTranscriber.js.map +1 -1
- package/lib/module/realtime-transcription/RealtimeTranscriber.js +13 -0
- package/lib/module/realtime-transcription/RealtimeTranscriber.js.map +1 -1
- package/lib/typescript/realtime-transcription/RealtimeTranscriber.d.ts.map +1 -1
- package/lib/typescript/realtime-transcription/types.d.ts +6 -0
- package/lib/typescript/realtime-transcription/types.d.ts.map +1 -1
- package/package.json +1 -1
- package/src/realtime-transcription/RealtimeTranscriber.ts +17 -0
- package/src/realtime-transcription/types.ts +6 -0
package/cpp/ggml-metal-impl.h

@@ -23,6 +23,9 @@
 #define N_R0_Q8_0 4
 #define N_SG_Q8_0 2
 
+#define N_R0_MXFP4 2
+#define N_SG_MXFP4 2
+
 #define N_R0_Q2_K 4
 #define N_SG_Q2_K 2
 
@@ -126,8 +129,18 @@ typedef struct {
     uint64_t nb2;
     uint64_t nb3;
     uint64_t offs;
+    uint64_t o1[8];
 } wsp_ggml_metal_kargs_bin;
 
+typedef struct {
+    int64_t ne0;
+    int64_t ne1;
+    size_t  nb01;
+    size_t  nb02;
+    size_t  nb11;
+    size_t  nb21;
+} wsp_ggml_metal_kargs_add_id;
+
 typedef struct {
     int32_t ne00;
     int32_t ne01;
@@ -229,14 +242,18 @@ typedef struct {
     uint64_t nb21;
     uint64_t nb22;
     uint64_t nb23;
+    int32_t  ne32;
+    int32_t  ne33;
     uint64_t nb31;
+    uint64_t nb32;
+    uint64_t nb33;
     int32_t  ne1;
     int32_t  ne2;
     float    scale;
     float    max_bias;
     float    m0;
     float    m1;
-
+    int32_t  n_head_log2;
     float    logit_softcap;
 } wsp_ggml_metal_kargs_flash_attn_ext;
 
@@ -373,8 +390,16 @@ typedef struct {
 typedef struct {
     int32_t  ne00;
     int32_t  ne00_4;
-    uint64_t
+    uint64_t nb1;
+    uint64_t nb2;
+    uint64_t nb3;
     float    eps;
+    int32_t  nef1[3];
+    int32_t  nef2[3];
+    int32_t  nef3[3];
+    uint64_t nbf1[3];
+    uint64_t nbf2[3];
+    uint64_t nbf3[3];
 } wsp_ggml_metal_kargs_rms_norm;
 
 typedef struct {
@@ -431,6 +456,8 @@ typedef struct{
     uint64_t nb1;
     int32_t  i00;
     int32_t  i10;
+    float    alpha;
+    float    limit;
 } wsp_ggml_metal_kargs_glu;
 
 typedef struct {
@@ -461,14 +488,26 @@ typedef struct {
 } wsp_ggml_metal_kargs_sum_rows;
 
 typedef struct {
-
-
-
+    int32_t  ne00;
+    int32_t  ne01;
+    int32_t  ne02;
+    uint64_t nb01;
+    uint64_t nb02;
+    uint64_t nb03;
+    int32_t  ne11;
+    int32_t  ne12;
+    int32_t  ne13;
+    uint64_t nb11;
+    uint64_t nb12;
+    uint64_t nb13;
+    uint64_t nb1;
+    uint64_t nb2;
+    uint64_t nb3;
     float    scale;
     float    max_bias;
     float    m0;
     float    m1;
-
+    int32_t  n_head_log2;
 } wsp_ggml_metal_kargs_soft_max;
 
 typedef struct {
@@ -499,26 +538,26 @@ typedef struct {
 typedef struct {
     int64_t  d_state;
     int64_t  d_inner;
+    int64_t  n_head;
+    int64_t  n_group;
     int64_t  n_seq_tokens;
     int64_t  n_seqs;
-
+    int64_t  s_off;
     uint64_t nb01;
     uint64_t nb02;
-    uint64_t
+    uint64_t nb03;
     uint64_t nb11;
     uint64_t nb12;
     uint64_t nb13;
-    uint64_t nb20;
     uint64_t nb21;
     uint64_t nb22;
-    uint64_t nb30;
     uint64_t nb31;
-    uint64_t nb40;
     uint64_t nb41;
     uint64_t nb42;
-    uint64_t
+    uint64_t nb43;
     uint64_t nb51;
     uint64_t nb52;
+    uint64_t nb53;
 } wsp_ggml_metal_kargs_ssm_scan;
 
 typedef struct {
package/cpp/ggml-opt.h

@@ -74,16 +74,26 @@ extern "C" {
         WSP_GGML_OPT_BUILD_TYPE_OPT = 30,
     };
 
+    enum wsp_ggml_opt_optimizer_type {
+        WSP_GGML_OPT_OPTIMIZER_TYPE_ADAMW,
+        WSP_GGML_OPT_OPTIMIZER_TYPE_SGD,
+
+        WSP_GGML_OPT_OPTIMIZER_TYPE_COUNT
+    };
+
     // parameters that control which optimizer is used and how said optimizer tries to find the minimal loss
     struct wsp_ggml_opt_optimizer_params {
-        // AdamW optimizer parameters
         struct {
             float alpha; // learning rate
-            float beta1;
-            float beta2;
+            float beta1; // first AdamW momentum
+            float beta2; // second AdamW momentum
             float eps;   // epsilon for numerical stability
-            float wd;    // weight decay
+            float wd;    // weight decay - 0.0f to disable
         } adamw;
+        struct {
+            float alpha; // learning rate
+            float wd;    // weight decay
+        } sgd;
     };
 
     // callback to calculate optimizer parameters prior to a backward pass
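The new `sgd` branch sits alongside `adamw` in the params struct. As a hedged sketch of user code (not part of this package), a callback matching `wsp_ggml_opt_get_optimizer_params` could fill it like this; the assumption that the callback returns the struct by value follows the upstream ggml-opt API:

```c
// Sketch of a user-supplied optimizer-params callback (assumed to return
// the struct by value, matching upstream ggml-opt). Values are examples.
static struct wsp_ggml_opt_optimizer_params my_sgd_pars(void * userdata) {
    struct wsp_ggml_opt_optimizer_params p = {0};
    p.sgd.alpha = 1e-3f; // learning rate
    p.sgd.wd    = 0.0f;  // weight decay, 0.0f to disable
    (void) userdata;     // e.g. pointer to the current epoch
    return p;
}
```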
@@ -112,8 +122,11 @@ extern "C" {
 
         int32_t opt_period; // after how many gradient accumulation steps an optimizer step should be done
 
-        wsp_ggml_opt_get_optimizer_params get_opt_pars;
-        void *
+        wsp_ggml_opt_get_optimizer_params get_opt_pars; // callback for calculating optimizer parameters
+        void * get_opt_pars_ud;                         // userdata for calculating optimizer parameters
+
+        // only WSP_GGML_OPT_OPTIMIZER_TYPE_ADAMW needs m, v momenta per parameter tensor
+        enum wsp_ggml_opt_optimizer_type optimizer;
     };
 
     // get parameters for an optimization context with defaults set where possible
@@ -142,6 +155,10 @@ extern "C" {
     // get the gradient accumulator for a node from the forward graph
     WSP_GGML_API struct wsp_ggml_tensor * wsp_ggml_opt_grad_acc(wsp_ggml_opt_context_t opt_ctx, struct wsp_ggml_tensor * node);
 
+    WSP_GGML_API enum wsp_ggml_opt_optimizer_type wsp_ggml_opt_context_optimizer_type(wsp_ggml_opt_context_t); //TODO consistent naming scheme
+
+    WSP_GGML_API const char * wsp_ggml_opt_optimizer_name(enum wsp_ggml_opt_optimizer_type);
+
     // ====== Optimization Result ======
 
     WSP_GGML_API wsp_ggml_opt_result_t wsp_ggml_opt_result_init(void);
@@ -226,12 +243,14 @@ extern "C" {
             struct wsp_ggml_tensor           * outputs,        // output tensor, must have shape [ne_label, ndata_batch] if labels are used
             wsp_ggml_opt_dataset_t             dataset,        // dataset with data and optionally also labels
             enum wsp_ggml_opt_loss_type        loss_type,      // loss to minimize
+            enum wsp_ggml_opt_optimizer_type   optimizer,      // sgd or adamw
             wsp_ggml_opt_get_optimizer_params  get_opt_pars,   // callback to get optimizer params, userdata is pointer to epoch (of type int64_t)
             int64_t                            nepoch,         // how many times the dataset should be iterated over
             int64_t                            nbatch_logical, // datapoints optimizer step, must be a multiple of ndata_batch in inputs/outputs
             float                              val_split,      // fraction of the dataset to use for validation, must be in [0.0f, 1.0f)
             bool                               silent);        // whether or not info prints to stderr should be suppressed
 
+
 #ifdef __cplusplus
 }
 #endif
package/cpp/ggml-quants.h

@@ -21,6 +21,8 @@ WSP_GGML_API void wsp_quantize_row_q5_1_ref(const float * WSP_GGML_RESTRICT x, b
 WSP_GGML_API void wsp_quantize_row_q8_0_ref(const float * WSP_GGML_RESTRICT x, block_q8_0 * WSP_GGML_RESTRICT y, int64_t k);
 WSP_GGML_API void wsp_quantize_row_q8_1_ref(const float * WSP_GGML_RESTRICT x, block_q8_1 * WSP_GGML_RESTRICT y, int64_t k);
 
+WSP_GGML_API void wsp_quantize_row_mxfp4_ref(const float * WSP_GGML_RESTRICT x, block_mxfp4 * WSP_GGML_RESTRICT y, int64_t k);
+
 WSP_GGML_API void wsp_quantize_row_q2_K_ref(const float * WSP_GGML_RESTRICT x, block_q2_K * WSP_GGML_RESTRICT y, int64_t k);
 WSP_GGML_API void wsp_quantize_row_q3_K_ref(const float * WSP_GGML_RESTRICT x, block_q3_K * WSP_GGML_RESTRICT y, int64_t k);
 WSP_GGML_API void wsp_quantize_row_q4_K_ref(const float * WSP_GGML_RESTRICT x, block_q4_K * WSP_GGML_RESTRICT y, int64_t k);
@@ -45,6 +47,8 @@ WSP_GGML_API void wsp_dewsp_quantize_row_q5_1(const block_q5_1 * WSP_GGML_RESTRI
 WSP_GGML_API void wsp_dewsp_quantize_row_q8_0(const block_q8_0 * WSP_GGML_RESTRICT x, float * WSP_GGML_RESTRICT y, int64_t k);
 //WSP_GGML_API void wsp_dewsp_quantize_row_q8_1(const block_q8_1 * WSP_GGML_RESTRICT x, float * WSP_GGML_RESTRICT y, int64_t k);
 
+WSP_GGML_API void wsp_dewsp_quantize_row_mxfp4(const block_mxfp4 * WSP_GGML_RESTRICT x, float * WSP_GGML_RESTRICT y, int64_t k);
+
 WSP_GGML_API void wsp_dewsp_quantize_row_q2_K(const block_q2_K * WSP_GGML_RESTRICT x, float * WSP_GGML_RESTRICT y, int64_t k);
 WSP_GGML_API void wsp_dewsp_quantize_row_q3_K(const block_q3_K * WSP_GGML_RESTRICT x, float * WSP_GGML_RESTRICT y, int64_t k);
 WSP_GGML_API void wsp_dewsp_quantize_row_q4_K(const block_q4_K * WSP_GGML_RESTRICT x, float * WSP_GGML_RESTRICT y, int64_t k);
@@ -90,6 +94,8 @@ WSP_GGML_API size_t wsp_quantize_q5_0(const float * WSP_GGML_RESTRICT src, void
 WSP_GGML_API size_t wsp_quantize_q5_1(const float * WSP_GGML_RESTRICT src, void * WSP_GGML_RESTRICT dst, int64_t nrows, int64_t n_per_row, const float * imatrix);
 WSP_GGML_API size_t wsp_quantize_q8_0(const float * WSP_GGML_RESTRICT src, void * WSP_GGML_RESTRICT dst, int64_t nrows, int64_t n_per_row, const float * imatrix);
 
+WSP_GGML_API size_t wsp_quantize_mxfp4(const float * WSP_GGML_RESTRICT src, void * WSP_GGML_RESTRICT dst, int64_t nrows, int64_t n_per_row, const float * imatrix);
+
 WSP_GGML_API void wsp_iq2xs_init_impl(enum wsp_ggml_type type);
 WSP_GGML_API void wsp_iq2xs_free_impl(enum wsp_ggml_type type);
 WSP_GGML_API void wsp_iq3xs_init_impl(int grid_size);
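Together these give the usual quantize/dequantize surface for the new type. A hedged round-trip sketch; buffer sizing assumes QK_MXFP4 == 32, per the block_mxfp4 definition in ggml-common.h later in this diff:

```c
// Round-trip 64 floats through MXFP4 (sketch; assumes ggml-quants.h and
// ggml-common.h are included so block_mxfp4 and QK_MXFP4 are visible).
void mxfp4_roundtrip(const float src[64], float out[64]) {
    block_mxfp4 blocks[64 / QK_MXFP4];           // 2 blocks of 32 values
    wsp_quantize_row_mxfp4_ref(src, blocks, 64); // reference quantizer
    wsp_dewsp_quantize_row_mxfp4(blocks, out, 64);
}
```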
package/cpp/ggml.h

@@ -241,6 +241,8 @@
 #define WSP_GGML_ROPE_TYPE_MROPE 8
 #define WSP_GGML_ROPE_TYPE_VISION 24
 
+#define WSP_GGML_MROPE_SECTIONS 4
+
 #define WSP_GGML_UNUSED(x) (void)(x)
 
 #define WSP_GGML_PAD(x, n) (((x) + (n) - 1) & ~((n) - 1))
@@ -304,6 +306,16 @@
     WSP_GGML_TENSOR_LOCALS(int64_t, ne, dst, ne) \
     WSP_GGML_TENSOR_LOCALS(size_t, nb, dst, nb)
 
+#define WSP_GGML_TENSOR_TERNARY_OP_LOCALS \
+    WSP_GGML_TENSOR_LOCALS(int64_t, ne0, src0, ne) \
+    WSP_GGML_TENSOR_LOCALS(size_t, nb0, src0, nb) \
+    WSP_GGML_TENSOR_LOCALS(int64_t, ne1, src1, ne) \
+    WSP_GGML_TENSOR_LOCALS(size_t, nb1, src1, nb) \
+    WSP_GGML_TENSOR_LOCALS(int64_t, ne2, src2, ne) \
+    WSP_GGML_TENSOR_LOCALS(size_t, nb2, src2, nb) \
+    WSP_GGML_TENSOR_LOCALS(int64_t, ne, dst, ne) \
+    WSP_GGML_TENSOR_LOCALS(size_t, nb, dst, nb)
+
 #define WSP_GGML_TENSOR_BINARY_OP_LOCALS01 \
     WSP_GGML_TENSOR_LOCALS(int64_t, ne0, src0, ne) \
     WSP_GGML_TENSOR_LOCALS(size_t, nb0, src0, nb) \
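The new macro mirrors WSP_GGML_TENSOR_BINARY_OP_LOCALS with a third source operand. A minimal sketch of how a kernel body would use it; the function is hypothetical, but the macro does expect tensors named src0/src1/src2/dst in scope:

```c
// Hypothetical kernel skeleton: the macro pulls ne00..ne03/nb00..nb03
// (src0), ne10../nb10.. (src1), ne20../nb20.. (src2) and ne0..ne3/nb0..nb3
// (dst) into local variables named after the operands.
static void ternary_op_sketch(const struct wsp_ggml_tensor * src0,
                              const struct wsp_ggml_tensor * src1,
                              const struct wsp_ggml_tensor * src2,
                              struct wsp_ggml_tensor * dst) {
    WSP_GGML_TENSOR_TERNARY_OP_LOCALS
    // ... kernel loops over ne0..ne3 would go here ...
}
```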
@@ -314,6 +326,13 @@
 extern "C" {
 #endif
 
+    // Function type used in fatal error callbacks
+    typedef void (*wsp_ggml_abort_callback_t)(const char * error_message);
+
+    // Set the abort callback (passing null will restore original abort functionality: printing a message to stdout)
+    // Returns the old callback for chaining
+    WSP_GGML_API wsp_ggml_abort_callback_t wsp_ggml_set_abort_callback(wsp_ggml_abort_callback_t callback);
+
 WSP_GGML_NORETURN WSP_GGML_ATTRIBUTE_FORMAT(3, 4)
 WSP_GGML_API void wsp_ggml_abort(const char * file, int line, const char * fmt, ...);
 
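A short usage sketch of the new hook (user code); the chaining pattern follows the header comment above:

```c
#include <stdio.h>

static wsp_ggml_abort_callback_t prev_abort = NULL;

static void my_abort_handler(const char * error_message) {
    fprintf(stderr, "[whisper.rn] ggml abort: %s\n", error_message);
    if (prev_abort) {
        prev_abort(error_message); // chain to the previously installed hook
    }
}

void install_abort_hook(void) {
    // the setter returns the old callback, so it can be chained or restored
    prev_abort = wsp_ggml_set_abort_callback(my_abort_handler);
}
```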
@@ -388,7 +407,8 @@ extern "C" {
         // WSP_GGML_TYPE_IQ4_NL_4_4 = 36,
         // WSP_GGML_TYPE_IQ4_NL_4_8 = 37,
         // WSP_GGML_TYPE_IQ4_NL_8_8 = 38,
-
+        WSP_GGML_TYPE_MXFP4 = 39, // MXFP4 (1 block)
+        WSP_GGML_TYPE_COUNT = 40,
     };
 
     // precision
@@ -423,6 +443,7 @@ extern "C" {
         WSP_GGML_FTYPE_MOSTLY_IQ4_XS = 22, // except 1d tensors
         WSP_GGML_FTYPE_MOSTLY_IQ1_M  = 23, // except 1d tensors
         WSP_GGML_FTYPE_MOSTLY_BF16   = 24, // except 1d tensors
+        WSP_GGML_FTYPE_MOSTLY_MXFP4  = 25, // except 1d tensors
     };
 
     // available tensor operations:
@@ -431,6 +452,7 @@ extern "C" {
 
         WSP_GGML_OP_DUP,
         WSP_GGML_OP_ADD,
+        WSP_GGML_OP_ADD_ID,
         WSP_GGML_OP_ADD1,
         WSP_GGML_OP_ACC,
         WSP_GGML_OP_SUB,
@@ -488,7 +510,7 @@ extern "C" {
         WSP_GGML_OP_POOL_1D,
         WSP_GGML_OP_POOL_2D,
         WSP_GGML_OP_POOL_2D_BACK,
-        WSP_GGML_OP_UPSCALE,
+        WSP_GGML_OP_UPSCALE,
         WSP_GGML_OP_PAD,
         WSP_GGML_OP_PAD_REFLECT_1D,
         WSP_GGML_OP_ROLL,
@@ -520,6 +542,7 @@ extern "C" {
         WSP_GGML_OP_CROSS_ENTROPY_LOSS,
         WSP_GGML_OP_CROSS_ENTROPY_LOSS_BACK,
         WSP_GGML_OP_OPT_STEP_ADAMW,
+        WSP_GGML_OP_OPT_STEP_SGD,
 
         WSP_GGML_OP_GLU,
 
@@ -550,6 +573,9 @@ extern "C" {
         WSP_GGML_GLU_OP_REGLU,
         WSP_GGML_GLU_OP_GEGLU,
         WSP_GGML_GLU_OP_SWIGLU,
+        WSP_GGML_GLU_OP_SWIGLU_OAI,
+        WSP_GGML_GLU_OP_GEGLU_ERF,
+        WSP_GGML_GLU_OP_GEGLU_QUICK,
 
         WSP_GGML_GLU_OP_COUNT,
     };
@@ -639,6 +665,9 @@ extern "C" {
 
     // misc
 
+    WSP_GGML_API const char * wsp_ggml_version(void);
+    WSP_GGML_API const char * wsp_ggml_commit(void);
+
     WSP_GGML_API void    wsp_ggml_time_init(void); // call this once at the beginning of the program
     WSP_GGML_API int64_t wsp_ggml_time_ms(void);
     WSP_GGML_API int64_t wsp_ggml_time_us(void);
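Both accessors are plain string getters, e.g.:

```c
#include <stdio.h>

void print_ggml_build_info(void) {
    // e.g. for attaching the vendored ggml revision to bug reports
    printf("ggml %s (commit %s)\n", wsp_ggml_version(), wsp_ggml_commit());
}
```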
@@ -819,6 +848,13 @@ extern "C" {
             struct wsp_ggml_tensor  * b,
             enum   wsp_ggml_type      type);
 
+    // dst[i0, i1, i2] = a[i0, i1, i2] + b[i0, ids[i1, i2]]
+    WSP_GGML_API struct wsp_ggml_tensor * wsp_ggml_add_id(
+            struct wsp_ggml_context * ctx,
+            struct wsp_ggml_tensor  * a,
+            struct wsp_ggml_tensor  * b,
+            struct wsp_ggml_tensor  * ids);
+
     WSP_GGML_API struct wsp_ggml_tensor * wsp_ggml_add1(
             struct wsp_ggml_context * ctx,
             struct wsp_ggml_tensor  * a,
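The header comment pins down the semantics of the new op: a per-row bias selected by an id tensor. Written out as a plain C loop over contiguous float buffers, a hedged reference version (not the library kernel) looks like this:

```c
// Reference semantics of wsp_ggml_add_id on contiguous float data:
// for every row (i1, i2) of a, add row ids[i1, i2] of b.
void add_id_ref(float * dst, const float * a, const float * b,
                const int32_t * ids, int ne0, int ne1, int ne2) {
    for (int i2 = 0; i2 < ne2; i2++) {
        for (int i1 = 0; i1 < ne1; i1++) {
            const int32_t row = ids[i2*ne1 + i1]; // which row of b to add
            for (int i0 = 0; i0 < ne0; i0++) {
                dst[(i2*ne1 + i1)*ne0 + i0] =
                    a[(i2*ne1 + i1)*ne0 + i0] + b[row*ne0 + i0];
            }
        }
    }
}
```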
@@ -1137,6 +1173,22 @@ extern "C" {
             struct wsp_ggml_context * ctx,
             struct wsp_ggml_tensor  * a);
 
+    WSP_GGML_API struct wsp_ggml_tensor * wsp_ggml_geglu_erf(
+            struct wsp_ggml_context * ctx,
+            struct wsp_ggml_tensor  * a);
+
+    WSP_GGML_API struct wsp_ggml_tensor * wsp_ggml_geglu_erf_swapped(
+            struct wsp_ggml_context * ctx,
+            struct wsp_ggml_tensor  * a);
+
+    WSP_GGML_API struct wsp_ggml_tensor * wsp_ggml_geglu_quick(
+            struct wsp_ggml_context * ctx,
+            struct wsp_ggml_tensor  * a);
+
+    WSP_GGML_API struct wsp_ggml_tensor * wsp_ggml_geglu_quick_swapped(
+            struct wsp_ggml_context * ctx,
+            struct wsp_ggml_tensor  * a);
+
     // A: n columns, r rows,
     // B: n columns, r rows,
     WSP_GGML_API struct wsp_ggml_tensor * wsp_ggml_glu_split(
@@ -1160,6 +1212,23 @@ extern "C" {
             struct wsp_ggml_tensor  * a,
             struct wsp_ggml_tensor  * b);
 
+    WSP_GGML_API struct wsp_ggml_tensor * wsp_ggml_geglu_erf_split(
+            struct wsp_ggml_context * ctx,
+            struct wsp_ggml_tensor  * a,
+            struct wsp_ggml_tensor  * b);
+
+    WSP_GGML_API struct wsp_ggml_tensor * wsp_ggml_geglu_quick_split(
+            struct wsp_ggml_context * ctx,
+            struct wsp_ggml_tensor  * a,
+            struct wsp_ggml_tensor  * b);
+
+    WSP_GGML_API struct wsp_ggml_tensor * wsp_ggml_swiglu_oai(
+            struct wsp_ggml_context * ctx,
+            struct wsp_ggml_tensor  * a,
+            struct wsp_ggml_tensor  * b,
+            float                     alpha,
+            float                     limit);
+
     // normalize along rows
     WSP_GGML_API struct wsp_ggml_tensor * wsp_ggml_norm(
             struct wsp_ggml_context * ctx,
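wsp_ggml_swiglu_oai is the gated-linear-unit variant used by gpt-oss style models; alpha and limit are the same scalars that wsp_ggml_metal_kargs_glu gained above. A hedged scalar sketch of the gating; the clamping convention is an assumption based on the upstream gpt-oss style kernel, not code from this package:

```c
#include <math.h>

// Assumed per-element computation: clamp the value branch from above and
// the linear branch to +/-limit, apply an alpha-scaled sigmoid gate, and
// add 1 to the linear branch.
static float swiglu_oai_ref(float x, float g, float alpha, float limit) {
    x = fminf(x, limit);                 // value branch, clamped above
    g = fmaxf(fminf(g, limit), -limit);  // linear branch, clamped both ways
    const float gate = 1.0f / (1.0f + expf(-alpha * x));
    return (x * gate) * (g + 1.0f);
}
```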
@@ -1259,6 +1328,19 @@ extern "C" {
             struct wsp_ggml_tensor  * a,
             float                     s);
 
+    // x = s * a + b
+    WSP_GGML_API struct wsp_ggml_tensor * wsp_ggml_scale_bias(
+            struct wsp_ggml_context * ctx,
+            struct wsp_ggml_tensor  * a,
+            float                     s,
+            float                     b);
+
+    WSP_GGML_API struct wsp_ggml_tensor * wsp_ggml_scale_bias_inplace(
+            struct wsp_ggml_context * ctx,
+            struct wsp_ggml_tensor  * a,
+            float                     s,
+            float                     b);
+
     // b -> view(a,offset,nb1,nb2,3), return modified a
     WSP_GGML_API struct wsp_ggml_tensor * wsp_ggml_set(
             struct wsp_ggml_context * ctx,
@@ -1503,8 +1585,14 @@ extern "C" {
             struct wsp_ggml_context * ctx,
             struct wsp_ggml_tensor  * a);
 
+    // a    [ne0, ne01, ne02, ne03]
+    // mask [ne0, ne11, ne12, ne13] | ne11 >= ne01, F16 or F32, optional
+    //
+    // broadcast:
+    //   ne02 % ne12 == 0
+    //   ne03 % ne13 == 0
+    //
     // fused soft_max(a*scale + mask*(ALiBi slope))
-    // mask is optional
     // max_bias = 0.0f for no ALiBi
     WSP_GGML_API struct wsp_ggml_tensor * wsp_ggml_soft_max_ext(
             struct wsp_ggml_context * ctx,
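The new comment spells out how a 4-D mask broadcasts over the batch dims of a. In ggml's usual convention for broadcasting ops (an assumption carried over from its other operators), the mapping is by modulo:

```c
#include <stddef.h>
#include <stdint.h>

// Sketch: locate the mask row for element (i01, i02, i03) of a, given the
// mask's broadcast dims ne12/ne13 and byte strides nb11/nb12/nb13.
static const char * soft_max_mask_row(const char * mask, int64_t i01,
                                      int64_t i02, int64_t i03,
                                      int64_t ne12, int64_t ne13,
                                      size_t nb11, size_t nb12, size_t nb13) {
    return mask + i01*nb11 + (i02 % ne12)*nb12 + (i03 % ne13)*nb13;
}
```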
@@ -1513,6 +1601,10 @@ extern "C" {
             float                     scale,
             float                     max_bias);
 
+    WSP_GGML_API void wsp_ggml_soft_max_add_sinks(
+            struct wsp_ggml_tensor  * a,
+            struct wsp_ggml_tensor  * sinks);
+
     WSP_GGML_API struct wsp_ggml_tensor * wsp_ggml_soft_max_ext_back(
             struct wsp_ggml_context * ctx,
             struct wsp_ggml_tensor  * a,
@@ -1571,7 +1663,7 @@ extern "C" {
             struct wsp_ggml_tensor  * b,
             struct wsp_ggml_tensor  * c,
             int                       n_dims,
-            int                       sections[
+            int                       sections[WSP_GGML_MROPE_SECTIONS],
             int                       mode,
             int                       n_ctx_orig,
             float                     freq_base,
@@ -1597,6 +1689,22 @@ extern "C" {
             float                     beta_fast,
             float                     beta_slow);
 
+    WSP_GGML_API struct wsp_ggml_tensor * wsp_ggml_rope_multi_inplace(
+            struct wsp_ggml_context * ctx,
+            struct wsp_ggml_tensor  * a,
+            struct wsp_ggml_tensor  * b,
+            struct wsp_ggml_tensor  * c,
+            int                       n_dims,
+            int                       sections[WSP_GGML_MROPE_SECTIONS],
+            int                       mode,
+            int                       n_ctx_orig,
+            float                     freq_base,
+            float                     freq_scale,
+            float                     ext_factor,
+            float                     attn_factor,
+            float                     beta_fast,
+            float                     beta_slow);
+
     WSP_GGML_DEPRECATED(WSP_GGML_API struct wsp_ggml_tensor * wsp_ggml_rope_custom(
             struct wsp_ggml_context * ctx,
             struct wsp_ggml_tensor  * a,
@@ -1967,11 +2075,17 @@ extern "C" {
 
 #define WSP_GGML_KQ_MASK_PAD 64
 
-    // q:    [n_embd_k, n_batch,     n_head,
-    // k:    [n_embd_k, n_kv,        n_head_kv,
-    // v:    [n_embd_v, n_kv,        n_head_kv,
-    // mask: [n_kv,     n_batch_pad,
-    // res:  [n_embd_v, n_head,      n_batch,
+    // q:    [n_embd_k, n_batch,     n_head,    ne3 ]
+    // k:    [n_embd_k, n_kv,        n_head_kv, ne3 ]
+    // v:    [n_embd_v, n_kv,        n_head_kv, ne3 ] !! not transposed !!
+    // mask: [n_kv,     n_batch_pad, ne32, ne33] !! n_batch_pad = WSP_GGML_PAD(n_batch, WSP_GGML_KQ_MASK_PAD) !!
+    // res:  [n_embd_v, n_head,      n_batch,   ne3 ] !! permuted !!
+    //
+    // broadcast:
+    //   n_head % n_head_kv == 0
+    //   n_head % ne32      == 0
+    //   ne3    % ne33      == 0
+    //
     WSP_GGML_API struct wsp_ggml_tensor * wsp_ggml_flash_attn_ext(
             struct wsp_ggml_context * ctx,
             struct wsp_ggml_tensor  * q,
@@ -1989,6 +2103,10 @@ extern "C" {
     WSP_GGML_API enum wsp_ggml_prec wsp_ggml_flash_attn_ext_get_prec(
             const struct wsp_ggml_tensor * a);
 
+    WSP_GGML_API void wsp_ggml_flash_attn_ext_add_sinks(
+            struct wsp_ggml_tensor * a,
+            struct wsp_ggml_tensor * sinks);
+
     // TODO: needs to be adapted to wsp_ggml_flash_attn_ext
     WSP_GGML_API struct wsp_ggml_tensor * wsp_ggml_flash_attn_back(
             struct wsp_ggml_context * ctx,
@@ -2010,7 +2128,8 @@ extern "C" {
             struct wsp_ggml_tensor  * dt,
             struct wsp_ggml_tensor  * A,
             struct wsp_ggml_tensor  * B,
-            struct wsp_ggml_tensor  * C
+            struct wsp_ggml_tensor  * C,
+            struct wsp_ggml_tensor  * ids);
 
     // partition into non-overlapping windows with padding if needed
     // example:
@@ -2193,7 +2312,14 @@ extern "C" {
             struct wsp_ggml_tensor  * grad,
             struct wsp_ggml_tensor  * m,
             struct wsp_ggml_tensor  * v,
-            struct wsp_ggml_tensor  * adamw_params); // parameters such
+            struct wsp_ggml_tensor  * adamw_params); // parameters such as the learning rate
+
+    // stochastic gradient descent step (with weight decay)
+    WSP_GGML_API struct wsp_ggml_tensor * wsp_ggml_opt_step_sgd(
+            struct wsp_ggml_context * ctx,
+            struct wsp_ggml_tensor  * a,
+            struct wsp_ggml_tensor  * grad,
+            struct wsp_ggml_tensor  * sgd_params); // alpha, weight decay
 
     //
     // automatic differentiation
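A hedged scalar sketch of the update the new op performs; decoupled weight decay is an assumption, mirroring the existing AdamW step, with alpha and wd coming from sgd_params:

```c
#include <stdint.h>

// Assumed element-wise update: a <- a*(1 - alpha*wd) - alpha*grad
// (decoupled weight decay; a sketch, not the library kernel).
static void opt_step_sgd_ref(float * a, const float * grad, int64_t n,
                             float alpha, float wd) {
    for (int64_t i = 0; i < n; i++) {
        a[i] = a[i]*(1.0f - alpha*wd) - alpha*grad[i];
    }
}
```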
Binary file
Binary file
package/cpp/ggml-common.h

@@ -99,6 +99,9 @@ typedef sycl::half2 wsp_ggml_half2;
 #define QI4_1 (QK4_1 / (4 * QR4_1))
 #define QR4_1 2
 
+#define QI_MXFP4 (QK_MXFP4 / (4 * QR_MXFP4))
+#define QR_MXFP4 2
+
 #define QI5_0 (QK5_0 / (4 * QR5_0))
 #define QR5_0 2
 
@@ -184,6 +187,13 @@ typedef struct {
 } block_q4_1;
 static_assert(sizeof(block_q4_1) == 2 * sizeof(wsp_ggml_half) + QK4_1 / 2, "wrong q4_1 block size/padding");
 
+#define QK_MXFP4 32
+typedef struct {
+    uint8_t e; // E8M0
+    uint8_t qs[QK_MXFP4/2];
+} block_mxfp4;
+static_assert(sizeof(block_mxfp4) == sizeof(uint8_t) + QK_MXFP4/2, "wrong mxfp4 block size/padding");
+
 #define QK5_0 32
 typedef struct {
     wsp_ggml_half d; // delta
@@ -1074,10 +1084,17 @@ WSP_GGML_TABLE_BEGIN(uint32_t, iq3s_grid, 512)
     0x0f090307, 0x0f090501, 0x0f090b01, 0x0f0b0505, 0x0f0b0905, 0x0f0d0105, 0x0f0d0703, 0x0f0f0101,
 WSP_GGML_TABLE_END()
 
+// TODO: fix name to kvalues_iq4_nl
 WSP_GGML_TABLE_BEGIN(int8_t, kvalues_iq4nl, 16)
     -127, -104, -83, -65, -49, -35, -22, -10, 1, 13, 25, 38, 53, 69, 89, 113,
 WSP_GGML_TABLE_END()
 
+// e2m1 values (doubled)
+// ref: https://www.opencompute.org/documents/ocp-microscaling-formats-mx-v1-0-spec-final-pdf
+WSP_GGML_TABLE_BEGIN(int8_t, kvalues_mxfp4, 16)
+    0, 1, 2, 3, 4, 6, 8, 12, 0, -1, -2, -3, -4, -6, -8, -12,
+WSP_GGML_TABLE_END()
+
 #define NGRID_IQ1S 2048
 #define IQ1S_DELTA 0.125f
 #define IQ1M_DELTA 0.125f
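With the block layout from the earlier hunk (one E8M0 scale byte plus 16 packed nibbles) and this table, decoding a block is straightforward. A hedged sketch; the low-nibble-first layout and the halved scale (undoing the doubled table values) are assumptions matching the comments in this diff:

```c
// Decode one block_mxfp4 into 32 floats (sketch, not the library kernel).
// WSP_GGML_E8M0_TO_FP32_HALF is added in ggml-impl.h below; halving the
// E8M0 scale compensates for the doubled e2m1 table entries.
static void decode_block_mxfp4(const block_mxfp4 * b, float * y) {
    const float d = WSP_GGML_E8M0_TO_FP32_HALF(b->e); // block scale
    for (int j = 0; j < QK_MXFP4/2; j++) {
        y[j]              = d * kvalues_mxfp4[b->qs[j] & 0x0F]; // low nibble
        y[j + QK_MXFP4/2] = d * kvalues_mxfp4[b->qs[j] >> 4];   // high nibble
    }
}
```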
package/cpp/ggml-impl.h

@@ -73,6 +73,22 @@ static inline int wsp_ggml_up(int n, int m) {
     return (n + m - 1) & ~(m - 1);
 }
 
+// TODO: move to ggml.h?
+static bool wsp_ggml_are_same_layout(const struct wsp_ggml_tensor * a, const struct wsp_ggml_tensor * b) {
+    if (a->type != b->type) {
+        return false;
+    }
+    for (int i = 0; i < WSP_GGML_MAX_DIMS; i++) {
+        if (a->ne[i] != b->ne[i]) {
+            return false;
+        }
+        if (a->nb[i] != b->nb[i]) {
+            return false;
+        }
+    }
+    return true;
+}
+
 //
 // logging
 //
@@ -394,6 +410,67 @@ static inline wsp_ggml_fp16_t wsp_ggml_compute_fp32_to_fp16(float f) {
 #define WSP_GGML_FP16_TO_FP32(x) WSP_GGML_COMPUTE_FP16_TO_FP32(x)
 #define WSP_GGML_FP32_TO_FP16(x) WSP_GGML_COMPUTE_FP32_TO_FP16(x)
 
+static inline float wsp_ggml_e8m0_to_fp32(uint8_t x) {
+    uint32_t bits; // Stores the raw bit representation of the float
+
+    // Handle special case for minimum exponent (denormalized float)
+    if (x == 0) {
+        // Bit pattern for 2^(-127):
+        // - Sign bit: 0 (positive)
+        // - Exponent: 0 (denormalized number)
+        // - Mantissa: 0x400000 (0.5 in fractional form)
+        // Value = 0.5 * 2^(-126) = 2^(-127)
+        bits = 0x00400000;
+    }
+    // note: disabled as we don't need to handle NaNs
+    //// Handle special case for NaN (all bits set)
+    //else if (x == 0xFF) {
+    //    // Standard quiet NaN pattern:
+    //    // - Sign bit: 0
+    //    // - Exponent: all 1s (0xFF)
+    //    // - Mantissa: 0x400000 (quiet NaN flag)
+    //    bits = 0x7FC00000;
+    //}
+    // Normalized values (most common case)
+    else {
+        // Construct normalized float by shifting exponent into position:
+        // - Exponent field: 8 bits (positions 30-23)
+        // - Mantissa: 0 (implicit leading 1)
+        // Value = 2^(x - 127)
+        bits = (uint32_t) x << 23;
+    }
+
+    float result; // Final float value
+    // Safely reinterpret bit pattern as float without type-punning issues
+    memcpy(&result, &bits, sizeof(float));
+    return result;
+}
+
+// Equal to wsp_ggml_e8m0_to_fp32/2
+// Useful with MXFP4 quantization since the E0M2 values are doubled
+static inline float wsp_ggml_e8m0_to_fp32_half(uint8_t x) {
+    uint32_t bits;
+
+    // For x < 2: use precomputed denormal patterns
+    if (x < 2) {
+        // 0x00200000 = 2^(-128), 0x00400000 = 2^(-127)
+        bits = 0x00200000 << x;
+    }
+    // For x >= 2: normalized exponent adjustment
+    else {
+        // 0.5 * 2^(x-127) = 2^(x-128) = normalized with exponent (x-1)
+        bits = (uint32_t)(x - 1) << 23;
+    }
+    // Note: NaNs are not handled here
+
+    float result;
+    memcpy(&result, &bits, sizeof(float));
+    return result;
+}
+
+#define WSP_GGML_E8M0_TO_FP32(x)      wsp_ggml_e8m0_to_fp32(x)
+#define WSP_GGML_E8M0_TO_FP32_HALF(x) wsp_ggml_e8m0_to_fp32_half(x)
+
 /**
  * Converts brain16 to float32.
  *
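Spot checks of the conversion; these follow directly from value = 2^(x - 127) and the x == 0 denormal case in the code above:

```c
#include <assert.h>

void check_e8m0(void) {
    assert(wsp_ggml_e8m0_to_fp32(127) == 1.0f);      // 2^(127-127)
    assert(wsp_ggml_e8m0_to_fp32(130) == 8.0f);      // 2^3
    assert(wsp_ggml_e8m0_to_fp32(0)   == 0x1p-127f); // denormal special case
    assert(wsp_ggml_e8m0_to_fp32_half(130) == 4.0f); // halved variant
}
```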