@lgrammel/ds4-provider 0.0.1

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (46) hide show
  1. package/README.md +96 -0
  2. package/binding.gyp +75 -0
  3. package/dist/ds4-language-model.d.ts +71 -0
  4. package/dist/ds4-language-model.d.ts.map +1 -0
  5. package/dist/ds4-language-model.js +888 -0
  6. package/dist/ds4-language-model.js.map +1 -0
  7. package/dist/ds4-provider.d.ts +13 -0
  8. package/dist/ds4-provider.d.ts.map +1 -0
  9. package/dist/ds4-provider.js +20 -0
  10. package/dist/ds4-provider.js.map +1 -0
  11. package/dist/index.d.ts +4 -0
  12. package/dist/index.d.ts.map +1 -0
  13. package/dist/index.js +4 -0
  14. package/dist/index.js.map +1 -0
  15. package/dist/native-binding.d.ts +42 -0
  16. package/dist/native-binding.d.ts.map +1 -0
  17. package/dist/native-binding.js +157 -0
  18. package/dist/native-binding.js.map +1 -0
  19. package/ds4/LICENSE +22 -0
  20. package/ds4/ds4.c +18268 -0
  21. package/ds4/ds4.h +196 -0
  22. package/ds4/ds4_gpu.h +804 -0
  23. package/ds4/ds4_metal.m +14657 -0
  24. package/ds4/metal/argsort.metal +266 -0
  25. package/ds4/metal/bin.metal +192 -0
  26. package/ds4/metal/concat.metal +62 -0
  27. package/ds4/metal/cpy.metal +57 -0
  28. package/ds4/metal/dense.metal +1121 -0
  29. package/ds4/metal/dsv4_hc.metal +861 -0
  30. package/ds4/metal/dsv4_kv.metal +227 -0
  31. package/ds4/metal/dsv4_misc.metal +1088 -0
  32. package/ds4/metal/dsv4_rope.metal +155 -0
  33. package/ds4/metal/flash_attn.metal +1426 -0
  34. package/ds4/metal/get_rows.metal +54 -0
  35. package/ds4/metal/glu.metal +36 -0
  36. package/ds4/metal/moe.metal +1737 -0
  37. package/ds4/metal/norm.metal +153 -0
  38. package/ds4/metal/repeat.metal +52 -0
  39. package/ds4/metal/set_rows.metal +55 -0
  40. package/ds4/metal/softmax.metal +241 -0
  41. package/ds4/metal/sum_rows.metal +102 -0
  42. package/ds4/metal/unary.metal +312 -0
  43. package/native/binding.cpp +621 -0
  44. package/package.json +66 -0
  45. package/scripts/postinstall.cjs +13 -0
  46. package/scripts/vendor-ds4.cjs +67 -0
@@ -0,0 +1,621 @@
1
+ #include <napi.h>
+
+ #include <fcntl.h>
+ #include <unistd.h>
+
+ #include <algorithm>
+ #include <atomic>
+ #include <chrono>
+ #include <cstdint>
+ #include <cstdio>
+ #include <cstdlib>
+ #include <cstring>
+ #include <functional>
+ #include <memory>
+ #include <mutex>
+ #include <string>
+ #include <unordered_map>
+ #include <vector>
+
+ extern "C" {
+ #include "ds4.h"
+ }
19
+
20
// One turn of a conversation: who spoke ("user", "assistant", "system", ...)
// and the text they produced.
struct ChatMessage {
  std::string role;     // message author role, passed through to the chat template
  std::string content;  // message body
};
24
+
25
+ struct GenerateParams {
26
+ std::vector<ChatMessage> messages;
27
+ int max_tokens = 2048;
28
+ float temperature = 1.0f;
29
+ float top_p = 1.0f;
30
+ int top_k = 0;
31
+ float min_p = 0.0f;
32
+ uint64_t seed = 0;
33
+ std::vector<std::string> stop_sequences;
34
+ ds4_think_mode think_mode = DS4_THINK_NONE;
35
+ };
36
+
37
// Outcome of one generation pass: produced text, token accounting, and how
// the run ended ("stop", "length", or "error").
struct GenerateResult {
  std::string text;                    // accumulated completion text
  int prompt_tokens = 0;               // tokens consumed by the prompt
  int completion_tokens = 0;           // tokens produced by sampling
  std::string finish_reason = "stop";  // "stop" | "length" | "error"
  std::string error_message;           // set only when finish_reason == "error"
};
44
+
45
+ struct ModelState {
46
+ ds4_engine *engine = nullptr;
47
+ ds4_session *session = nullptr;
48
+ int ctx_size = 32768;
49
+ std::mutex mutex;
50
+
51
+ ~ModelState() {
52
+ if (session != nullptr) {
53
+ ds4_session_free(session);
54
+ }
55
+ if (engine != nullptr) {
56
+ ds4_engine_close(engine);
57
+ }
58
+ }
59
+ };
60
+
61
+ class GenerateWorker;
62
+ class StreamGenerateWorker;
63
+
64
+ static std::unordered_map<int, std::unique_ptr<ModelState>> g_models;
65
+ static std::mutex g_models_mutex;
66
+ static std::atomic<int> g_next_handle{1};
67
+ static std::unordered_map<int, std::vector<GenerateWorker *>> g_generate_workers;
68
+ static std::unordered_map<int, std::vector<StreamGenerateWorker *>> g_stream_workers;
69
+ static std::mutex g_workers_mutex;
70
+
71
// RAII guard that redirects stderr to /dev/null for its lifetime, used to
// mute noisy native-library logging unless debug mode is on. If setup fails
// (dup/open error) stderr is left untouched rather than half-silenced.
//
// Fix: the class owns two file descriptors but was implicitly copyable; a
// copy would close the same descriptors twice on destruction. Copy and
// assignment are now deleted (Rule of Five).
class ScopedStderrSilencer {
 public:
  // When `enabled` is false the guard is a complete no-op.
  explicit ScopedStderrSilencer(bool enabled) : enabled_(enabled) {
    if (!enabled_) {
      return;
    }

    fflush(stderr);
    saved_stderr_fd_ = dup(STDERR_FILENO);
    null_fd_ = open("/dev/null", O_WRONLY);

    // Partial failure: undo whatever was acquired and keep stderr live.
    if (saved_stderr_fd_ == -1 || null_fd_ == -1) {
      Restore();
      return;
    }

    dup2(null_fd_, STDERR_FILENO);
  }

  ~ScopedStderrSilencer() { Restore(); }

  // Non-copyable: the destructor closes the saved descriptors, and a second
  // owner would close them again.
  ScopedStderrSilencer(const ScopedStderrSilencer &) = delete;
  ScopedStderrSilencer &operator=(const ScopedStderrSilencer &) = delete;

 private:
  // Re-attach the original stderr (if one was saved) and release both
  // descriptors. Idempotent, so it is safe from both ctor and dtor paths.
  void Restore() {
    if (saved_stderr_fd_ != -1) {
      fflush(stderr);
      dup2(saved_stderr_fd_, STDERR_FILENO);
      close(saved_stderr_fd_);
      saved_stderr_fd_ = -1;
    }

    if (null_fd_ != -1) {
      close(null_fd_);
      null_fd_ = -1;
    }
  }

  bool enabled_;
  int saved_stderr_fd_ = -1;
  int null_fd_ = -1;
};
111
+
112
// Erase `worker` from the per-handle registry, dropping the handle's entry
// entirely once its worker list becomes empty. Caller must hold the
// registry's mutex. Unknown handles are a no-op.
template <typename Worker>
void RemoveWorker(std::unordered_map<int, std::vector<Worker *>> &workers_by_handle, int handle,
                  Worker *worker) {
  const auto entry = workers_by_handle.find(handle);
  if (entry == workers_by_handle.end()) {
    return;
  }

  std::vector<Worker *> &list = entry->second;
  list.erase(std::remove(list.begin(), list.end(), worker), list.end());
  if (list.empty()) {
    workers_by_handle.erase(entry);
  }
}
126
+
127
+ static ds4_backend ParseBackend(const std::string &backend) {
128
+ if (backend == "cpu") {
129
+ return DS4_BACKEND_CPU;
130
+ }
131
+ if (backend == "cuda") {
132
+ return DS4_BACKEND_CUDA;
133
+ }
134
+ #ifdef __APPLE__
135
+ return DS4_BACKEND_METAL;
136
+ #else
137
+ return DS4_BACKEND_CPU;
138
+ #endif
139
+ }
140
+
141
// Derive a nondeterministic RNG seed from the current clock, used whenever
// the caller did not supply an explicit seed.
static uint64_t DefaultSeed() {
  const auto now = std::chrono::high_resolution_clock::now();
  return static_cast<uint64_t>(now.time_since_epoch().count());
}
145
+
146
// Check whether `text` currently ends with any of `stop_sequences`. On a
// match, *stop_start receives the index where the matched sequence begins so
// the caller can truncate it off. Empty sequences never match.
static bool EndsWithStopSequence(const std::string &text,
                                 const std::vector<std::string> &stop_sequences,
                                 size_t *stop_start) {
  for (const std::string &candidate : stop_sequences) {
    if (candidate.empty() || candidate.size() > text.size()) {
      continue;
    }
    const size_t offset = text.size() - candidate.size();
    if (text.compare(offset, candidate.size(), candidate) == 0) {
      *stop_start = offset;
      return true;
    }
  }
  return false;
}
160
+
161
+ static ds4_think_mode ParseThinkMode(const std::string &think_mode) {
162
+ if (think_mode == "high") {
163
+ return DS4_THINK_HIGH;
164
+ }
165
+ if (think_mode == "max") {
166
+ return DS4_THINK_MAX;
167
+ }
168
+ return DS4_THINK_NONE;
169
+ }
170
+
171
+ static ds4_tokens BuildPrompt(ModelState *model, const std::vector<ChatMessage> &messages,
172
+ ds4_think_mode think_mode) {
173
+ ds4_tokens prompt{};
174
+ ds4_chat_begin(model->engine, &prompt);
175
+
176
+ ds4_think_mode effective_think_mode = ds4_think_mode_for_context(think_mode, model->ctx_size);
177
+ if (effective_think_mode == DS4_THINK_MAX) {
178
+ ds4_chat_append_max_effort_prefix(model->engine, &prompt);
179
+ }
180
+
181
+ for (const auto &message : messages) {
182
+ ds4_chat_append_message(model->engine, &prompt, message.role.c_str(), message.content.c_str());
183
+ }
184
+
185
+ ds4_chat_append_assistant_prefix(model->engine, &prompt, effective_think_mode);
186
+ return prompt;
187
+ }
188
+
189
+ static GenerateResult RunGeneration(ModelState *model, const GenerateParams &params,
190
+ const std::atomic<bool> &cancelled,
191
+ const std::function<bool(const std::string &)> &on_token) {
192
+ std::lock_guard<std::mutex> model_lock(model->mutex);
193
+
194
+ GenerateResult result;
195
+ char err[256] = {0};
196
+ ds4_tokens prompt = BuildPrompt(model, params.messages, params.think_mode);
197
+ result.prompt_tokens = prompt.len;
198
+
199
+ if (ds4_session_sync(model->session, &prompt, err, sizeof(err)) != 0) {
200
+ result.finish_reason = "error";
201
+ result.error_message = err;
202
+ ds4_tokens_free(&prompt);
203
+ return result;
204
+ }
205
+
206
+ int max_tokens = params.max_tokens;
207
+ const int room = ds4_session_ctx(model->session) - ds4_session_pos(model->session);
208
+ if (room <= 1) {
209
+ max_tokens = 0;
210
+ } else if (max_tokens > room - 1) {
211
+ max_tokens = room - 1;
212
+ }
213
+
214
+ uint64_t rng = params.seed != 0 ? params.seed : DefaultSeed();
215
+ for (int generated = 0; generated < max_tokens && !cancelled.load(); generated++) {
216
+ const int token = ds4_session_sample(model->session, params.temperature, params.top_k,
217
+ params.top_p, params.min_p, &rng);
218
+ if (token == ds4_token_eos(model->engine)) {
219
+ result.finish_reason = "stop";
220
+ break;
221
+ }
222
+
223
+ if (ds4_session_eval(model->session, token, err, sizeof(err)) != 0) {
224
+ result.finish_reason = "error";
225
+ result.error_message = err;
226
+ break;
227
+ }
228
+
229
+ size_t piece_len = 0;
230
+ char *piece = ds4_token_text(model->engine, token, &piece_len);
231
+ if (piece != nullptr && piece_len > 0) {
232
+ result.text.append(piece, piece_len);
233
+ if (on_token && !on_token(std::string(piece, piece_len))) {
234
+ free(piece);
235
+ break;
236
+ }
237
+ }
238
+ free(piece);
239
+
240
+ result.completion_tokens++;
241
+ size_t stop_start = 0;
242
+ if (EndsWithStopSequence(result.text, params.stop_sequences, &stop_start)) {
243
+ result.text.resize(stop_start);
244
+ result.finish_reason = "stop";
245
+ break;
246
+ }
247
+
248
+ if (generated + 1 == max_tokens) {
249
+ result.finish_reason = "length";
250
+ }
251
+ }
252
+
253
+ ds4_tokens_free(&prompt);
254
+ return result;
255
+ }
256
+
257
+ static Napi::Object ToJSResult(Napi::Env env, const GenerateResult &result) {
258
+ Napi::Object object = Napi::Object::New(env);
259
+ object.Set("text", Napi::String::New(env, result.text));
260
+ object.Set("promptTokens", Napi::Number::New(env, result.prompt_tokens));
261
+ object.Set("completionTokens", Napi::Number::New(env, result.completion_tokens));
262
+ object.Set("finishReason", Napi::String::New(env, result.finish_reason));
263
+ if (!result.error_message.empty()) {
264
+ object.Set("errorMessage", Napi::String::New(env, result.error_message));
265
+ }
266
+ return object;
267
+ }
268
+
269
+ static std::vector<ChatMessage> ParseMessages(Napi::Array messages) {
270
+ std::vector<ChatMessage> result;
271
+ for (uint32_t i = 0; i < messages.Length(); i++) {
272
+ Napi::Object message = messages.Get(i).As<Napi::Object>();
273
+ result.push_back({
274
+ message.Get("role").As<Napi::String>().Utf8Value(),
275
+ message.Get("content").As<Napi::String>().Utf8Value(),
276
+ });
277
+ }
278
+ return result;
279
+ }
280
+
281
+ static GenerateParams ParseGenerateParams(Napi::Object options) {
282
+ GenerateParams params;
283
+ params.messages = ParseMessages(options.Get("messages").As<Napi::Array>());
284
+ params.max_tokens =
285
+ options.Has("maxTokens") ? options.Get("maxTokens").As<Napi::Number>().Int32Value() : 2048;
286
+ params.temperature =
287
+ options.Has("temperature") ? options.Get("temperature").As<Napi::Number>().FloatValue() : 1.0f;
288
+ params.top_p = options.Has("topP") ? options.Get("topP").As<Napi::Number>().FloatValue() : 1.0f;
289
+ params.top_k = options.Has("topK") ? options.Get("topK").As<Napi::Number>().Int32Value() : 0;
290
+ params.min_p = options.Has("minP") ? options.Get("minP").As<Napi::Number>().FloatValue() : 0.0f;
291
+ params.seed = options.Has("seed")
292
+ ? static_cast<uint64_t>(options.Get("seed").As<Napi::Number>().Int64Value())
293
+ : 0;
294
+
295
+ if (options.Has("stopSequences") && options.Get("stopSequences").IsArray()) {
296
+ Napi::Array stops = options.Get("stopSequences").As<Napi::Array>();
297
+ for (uint32_t i = 0; i < stops.Length(); i++) {
298
+ params.stop_sequences.push_back(stops.Get(i).As<Napi::String>().Utf8Value());
299
+ }
300
+ }
301
+ if (options.Has("thinkMode") && options.Get("thinkMode").IsString()) {
302
+ params.think_mode = ParseThinkMode(options.Get("thinkMode").As<Napi::String>().Utf8Value());
303
+ }
304
+
305
+ return params;
306
+ }
307
+
308
+ static ModelState *GetModel(int handle) {
309
+ std::lock_guard<std::mutex> lock(g_models_mutex);
310
+ auto it = g_models.find(handle);
311
+ return it == g_models.end() ? nullptr : it->second.get();
312
+ }
313
+
314
+ class LoadModelWorker : public Napi::AsyncWorker {
315
+ public:
316
+ LoadModelWorker(Napi::Function &callback, Napi::Object options)
317
+ : Napi::AsyncWorker(callback),
318
+ model_path_(options.Get("modelPath").As<Napi::String>().Utf8Value()),
319
+ mtp_path_(options.Has("mtpPath") ? options.Get("mtpPath").As<Napi::String>().Utf8Value()
320
+ : ""),
321
+ backend_(options.Has("backend") ? options.Get("backend").As<Napi::String>().Utf8Value()
322
+ : ""),
323
+ ctx_size_(options.Has("contextSize")
324
+ ? options.Get("contextSize").As<Napi::Number>().Int32Value()
325
+ : 32768),
326
+ threads_(options.Has("threads") ? options.Get("threads").As<Napi::Number>().Int32Value()
327
+ : 0),
328
+ mtp_draft_tokens_(options.Has("mtpDraftTokens")
329
+ ? options.Get("mtpDraftTokens").As<Napi::Number>().Int32Value()
330
+ : 0),
331
+ mtp_margin_(options.Has("mtpMargin") ? options.Get("mtpMargin").As<Napi::Number>().FloatValue()
332
+ : 0.0f),
333
+ warm_weights_(options.Has("warmWeights") &&
334
+ options.Get("warmWeights").As<Napi::Boolean>().Value()),
335
+ quality_(options.Has("quality") && options.Get("quality").As<Napi::Boolean>().Value()),
336
+ debug_(options.Has("debug") && options.Get("debug").As<Napi::Boolean>().Value()) {}
337
+
338
+ void Execute() override {
339
+ auto model = std::make_unique<ModelState>();
340
+ model->ctx_size = ctx_size_;
341
+
342
+ ds4_engine_options options{};
343
+ options.model_path = model_path_.c_str();
344
+ options.mtp_path = mtp_path_.empty() ? nullptr : mtp_path_.c_str();
345
+ options.backend = ParseBackend(backend_);
346
+ options.n_threads = threads_;
347
+ options.mtp_draft_tokens = mtp_draft_tokens_;
348
+ options.mtp_margin = mtp_margin_;
349
+ options.warm_weights = warm_weights_;
350
+ options.quality = quality_;
351
+
352
+ {
353
+ ScopedStderrSilencer silence_stderr(!debug_);
354
+ if (ds4_engine_open(&model->engine, &options) != 0) {
355
+ SetError("Failed to open DS4 model: " + model_path_);
356
+ return;
357
+ }
358
+ }
359
+ if (ds4_session_create(&model->session, model->engine, ctx_size_) != 0) {
360
+ SetError("Failed to create DS4 session");
361
+ return;
362
+ }
363
+
364
+ handle_ = g_next_handle++;
365
+ {
366
+ std::lock_guard<std::mutex> lock(g_models_mutex);
367
+ g_models[handle_] = std::move(model);
368
+ }
369
+ }
370
+
371
+ void OnOK() override {
372
+ Callback().Call({Env().Null(), Napi::Number::New(Env(), handle_)});
373
+ }
374
+
375
+ void OnError(const Napi::Error &error) override {
376
+ Callback().Call({Napi::String::New(Env(), error.Message()), Env().Null()});
377
+ }
378
+
379
+ private:
380
+ std::string model_path_;
381
+ std::string mtp_path_;
382
+ std::string backend_;
383
+ int ctx_size_;
384
+ int threads_;
385
+ int mtp_draft_tokens_;
386
+ float mtp_margin_;
387
+ bool warm_weights_;
388
+ bool quality_;
389
+ bool debug_;
390
+ int handle_ = -1;
391
+ };
392
+
393
+ class GenerateWorker : public Napi::AsyncWorker {
394
+ public:
395
+ GenerateWorker(Napi::Function &callback, int handle, GenerateParams params)
396
+ : Napi::AsyncWorker(callback), handle_(handle), params_(std::move(params)) {}
397
+
398
+ void Cancel() { cancelled_.store(true); }
399
+
400
+ void Execute() override {
401
+ ModelState *model = GetModel(handle_);
402
+ if (model == nullptr) {
403
+ SetError("Invalid DS4 model handle");
404
+ return;
405
+ }
406
+
407
+ result_ = RunGeneration(model, params_, cancelled_, nullptr);
408
+ }
409
+
410
+ void OnOK() override {
411
+ {
412
+ std::lock_guard<std::mutex> lock(g_workers_mutex);
413
+ RemoveWorker(g_generate_workers, handle_, this);
414
+ }
415
+ Callback().Call({Env().Null(), ToJSResult(Env(), result_)});
416
+ }
417
+
418
+ void OnError(const Napi::Error &error) override {
419
+ {
420
+ std::lock_guard<std::mutex> lock(g_workers_mutex);
421
+ RemoveWorker(g_generate_workers, handle_, this);
422
+ }
423
+ Callback().Call({Napi::String::New(Env(), error.Message()), Env().Null()});
424
+ }
425
+
426
+ private:
427
+ int handle_;
428
+ GenerateParams params_;
429
+ GenerateResult result_;
430
+ std::atomic<bool> cancelled_{false};
431
+ };
432
+
433
+ class StreamGenerateWorker : public Napi::AsyncWorker {
434
+ public:
435
+ StreamGenerateWorker(Napi::Function &callback, int handle, GenerateParams params,
436
+ Napi::ThreadSafeFunction tsfn)
437
+ : Napi::AsyncWorker(callback), handle_(handle), params_(std::move(params)), tsfn_(tsfn) {}
438
+
439
+ void Cancel() { cancelled_.store(true); }
440
+
441
+ void Execute() override {
442
+ ModelState *model = GetModel(handle_);
443
+ if (model == nullptr) {
444
+ SetError("Invalid DS4 model handle");
445
+ return;
446
+ }
447
+
448
+ result_ = RunGeneration(model, params_, cancelled_, [this](const std::string &token) {
449
+ if (cancelled_.load()) {
450
+ return false;
451
+ }
452
+
453
+ auto *copy = new std::string(token);
454
+ napi_status status = tsfn_.BlockingCall(
455
+ copy, [](Napi::Env env, Napi::Function callback, std::string *data) {
456
+ callback.Call({Napi::String::New(env, *data)});
457
+ delete data;
458
+ });
459
+ return status == napi_ok && !cancelled_.load();
460
+ });
461
+ }
462
+
463
+ void OnOK() override {
464
+ {
465
+ std::lock_guard<std::mutex> lock(g_workers_mutex);
466
+ RemoveWorker(g_stream_workers, handle_, this);
467
+ }
468
+ tsfn_.Release();
469
+ Callback().Call({Env().Null(), ToJSResult(Env(), result_)});
470
+ }
471
+
472
+ void OnError(const Napi::Error &error) override {
473
+ {
474
+ std::lock_guard<std::mutex> lock(g_workers_mutex);
475
+ RemoveWorker(g_stream_workers, handle_, this);
476
+ }
477
+ tsfn_.Release();
478
+ Callback().Call({Napi::String::New(Env(), error.Message()), Env().Null()});
479
+ }
480
+
481
+ private:
482
+ int handle_;
483
+ GenerateParams params_;
484
+ GenerateResult result_;
485
+ Napi::ThreadSafeFunction tsfn_;
486
+ std::atomic<bool> cancelled_{false};
487
+ };
488
+
489
+ Napi::Value LoadModel(const Napi::CallbackInfo &info) {
490
+ Napi::Env env = info.Env();
491
+ if (info.Length() < 2 || !info[0].IsObject() || !info[1].IsFunction()) {
492
+ Napi::TypeError::New(env, "Expected (options, callback)").ThrowAsJavaScriptException();
493
+ return env.Null();
494
+ }
495
+
496
+ auto callback = info[1].As<Napi::Function>();
497
+ auto worker = new LoadModelWorker(callback, info[0].As<Napi::Object>());
498
+ worker->Queue();
499
+ return env.Undefined();
500
+ }
501
+
502
+ Napi::Value UnloadModel(const Napi::CallbackInfo &info) {
503
+ Napi::Env env = info.Env();
504
+ if (info.Length() < 1 || !info[0].IsNumber()) {
505
+ Napi::TypeError::New(env, "Expected model handle").ThrowAsJavaScriptException();
506
+ return env.Null();
507
+ }
508
+
509
+ const int handle = info[0].As<Napi::Number>().Int32Value();
510
+ {
511
+ std::lock_guard<std::mutex> lock(g_models_mutex);
512
+ g_models.erase(handle);
513
+ }
514
+ return Napi::Boolean::New(env, true);
515
+ }
516
+
517
+ Napi::Value Generate(const Napi::CallbackInfo &info) {
518
+ Napi::Env env = info.Env();
519
+ if (info.Length() < 3 || !info[0].IsNumber() || !info[1].IsObject() || !info[2].IsFunction()) {
520
+ Napi::TypeError::New(env, "Expected (handle, options, callback)").ThrowAsJavaScriptException();
521
+ return env.Null();
522
+ }
523
+
524
+ Napi::Object options = info[1].As<Napi::Object>();
525
+ if (!options.Has("messages") || !options.Get("messages").IsArray()) {
526
+ Napi::TypeError::New(env, "Expected messages array in options").ThrowAsJavaScriptException();
527
+ return env.Null();
528
+ }
529
+
530
+ const int handle = info[0].As<Napi::Number>().Int32Value();
531
+ auto callback = info[2].As<Napi::Function>();
532
+ auto worker = new GenerateWorker(callback, handle, ParseGenerateParams(options));
533
+ {
534
+ std::lock_guard<std::mutex> lock(g_workers_mutex);
535
+ g_generate_workers[handle].push_back(worker);
536
+ }
537
+ worker->Queue();
538
+ return env.Undefined();
539
+ }
540
+
541
+ Napi::Value GenerateStream(const Napi::CallbackInfo &info) {
542
+ Napi::Env env = info.Env();
543
+ if (info.Length() < 4 || !info[0].IsNumber() || !info[1].IsObject() || !info[2].IsFunction() ||
544
+ !info[3].IsFunction()) {
545
+ Napi::TypeError::New(env, "Expected (handle, options, tokenCallback, doneCallback)")
546
+ .ThrowAsJavaScriptException();
547
+ return env.Null();
548
+ }
549
+
550
+ Napi::Object options = info[1].As<Napi::Object>();
551
+ if (!options.Has("messages") || !options.Get("messages").IsArray()) {
552
+ Napi::TypeError::New(env, "Expected messages array in options").ThrowAsJavaScriptException();
553
+ return env.Null();
554
+ }
555
+
556
+ const int handle = info[0].As<Napi::Number>().Int32Value();
557
+ auto token_callback = info[2].As<Napi::Function>();
558
+ auto done_callback = info[3].As<Napi::Function>();
559
+ Napi::ThreadSafeFunction tsfn =
560
+ Napi::ThreadSafeFunction::New(env, token_callback, "DS4TokenCallback", 0, 1);
561
+
562
+ auto worker = new StreamGenerateWorker(done_callback, handle, ParseGenerateParams(options), tsfn);
563
+ {
564
+ std::lock_guard<std::mutex> lock(g_workers_mutex);
565
+ g_stream_workers[handle].push_back(worker);
566
+ }
567
+ worker->Queue();
568
+ return env.Undefined();
569
+ }
570
+
571
+ Napi::Value CancelGeneration(const Napi::CallbackInfo &info) {
572
+ Napi::Env env = info.Env();
573
+ if (info.Length() < 1 || !info[0].IsNumber()) {
574
+ Napi::TypeError::New(env, "Expected model handle").ThrowAsJavaScriptException();
575
+ return env.Null();
576
+ }
577
+
578
+ const int handle = info[0].As<Napi::Number>().Int32Value();
579
+ bool cancelled = false;
580
+ {
581
+ std::lock_guard<std::mutex> lock(g_workers_mutex);
582
+ auto generate_it = g_generate_workers.find(handle);
583
+ if (generate_it != g_generate_workers.end()) {
584
+ for (GenerateWorker *worker : generate_it->second) {
585
+ worker->Cancel();
586
+ cancelled = true;
587
+ }
588
+ }
589
+ auto stream_it = g_stream_workers.find(handle);
590
+ if (stream_it != g_stream_workers.end()) {
591
+ for (StreamGenerateWorker *worker : stream_it->second) {
592
+ worker->Cancel();
593
+ cancelled = true;
594
+ }
595
+ }
596
+ }
597
+ return Napi::Boolean::New(env, cancelled);
598
+ }
599
+
600
+ Napi::Value IsModelLoaded(const Napi::CallbackInfo &info) {
601
+ Napi::Env env = info.Env();
602
+ if (info.Length() < 1 || !info[0].IsNumber()) {
603
+ Napi::TypeError::New(env, "Expected model handle").ThrowAsJavaScriptException();
604
+ return env.Null();
605
+ }
606
+
607
+ const int handle = info[0].As<Napi::Number>().Int32Value();
608
+ return Napi::Boolean::New(env, GetModel(handle) != nullptr);
609
+ }
610
+
611
+ Napi::Object Init(Napi::Env env, Napi::Object exports) {
612
+ exports.Set("loadModel", Napi::Function::New(env, LoadModel));
613
+ exports.Set("unloadModel", Napi::Function::New(env, UnloadModel));
614
+ exports.Set("generate", Napi::Function::New(env, Generate));
615
+ exports.Set("generateStream", Napi::Function::New(env, GenerateStream));
616
+ exports.Set("cancelGeneration", Napi::Function::New(env, CancelGeneration));
617
+ exports.Set("isModelLoaded", Napi::Function::New(env, IsModelLoaded));
618
+ return exports;
619
+ }
620
+
621
+ NODE_API_MODULE(ds4_binding, Init)
package/package.json ADDED
@@ -0,0 +1,66 @@
1
+ {
2
+ "name": "@lgrammel/ds4-provider",
3
+ "version": "0.0.1",
4
+ "description": "DS4 provider for the Vercel AI SDK",
5
+ "type": "module",
6
+ "main": "./dist/index.js",
7
+ "types": "./dist/index.d.ts",
8
+ "publishConfig": {
9
+ "access": "public"
10
+ },
11
+   "exports": {
12
+     ".": {
13
+       "types": "./dist/index.d.ts",
14
+       "import": "./dist/index.js"
15
+     }
16
+   },
17
+ "ds4": {
18
+ "repo": "https://github.com/antirez/ds4.git",
19
+ "commit": "613e9b2c9b78c09ac9f622bb8ece81c99989ead6"
20
+ },
21
+ "files": [
22
+ "dist",
23
+ "ds4/ds4.c",
24
+ "ds4/ds4.h",
25
+ "ds4/ds4_gpu.h",
26
+ "ds4/ds4_metal.m",
27
+ "ds4/LICENSE",
28
+ "ds4/metal",
29
+ "native",
30
+ "scripts",
31
+ "binding.gyp"
32
+ ],
33
+ "keywords": [
34
+ "ai",
35
+ "ai-sdk",
36
+ "ds4",
37
+ "deepseek",
38
+ "language-model",
39
+ "vercel"
40
+ ],
41
+ "author": "Lars Grammel",
42
+ "license": "MIT",
43
+ "engines": {
44
+ "node": ">=18.0.0"
45
+ },
46
+ "dependencies": {
47
+ "@ai-sdk/provider": "4.0.0-canary.16",
48
+ "node-addon-api": "^8.0.0",
49
+ "node-gyp": "^11.5.0"
50
+ },
51
+ "devDependencies": {
52
+ "tsx": "^4.21.0",
53
+ "@types/node": "^20.10.0",
54
+ "typescript": "^5.3.0"
55
+ },
56
+ "scripts": {
57
+ "postinstall": "node scripts/postinstall.cjs",
58
+ "vendor:ds4": "node scripts/vendor-ds4.cjs",
59
+ "build:native": "node-gyp rebuild",
60
+ "build:ts": "tsc",
61
+ "build": "pnpm run vendor:ds4 && pnpm run build:native && pnpm run build:ts",
62
+ "clean": "rm -rf dist build",
63
+ "test": "tsx test/ds4-language-model.test.ts",
64
+ "typecheck": "tsc --noEmit -p tsconfig.check.json"
65
+ }
66
+ }
@@ -0,0 +1,13 @@
1
const { execSync } = require("node:child_process");
const path = require("node:path");

// The package root is the parent of this scripts/ directory.
const packageRoot = path.join(__dirname, "..");

/**
 * Execute a shell command from the package root, inheriting stdio so the
 * installing user sees the build output and any errors directly.
 */
function run(command) {
  execSync(command, { cwd: packageRoot, stdio: "inherit" });
}

// Compile the native addon from the vendored sources at install time.
run("node-gyp rebuild");