bun-scikit 0.1.1
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- package/LICENSE +21 -0
- package/README.md +187 -0
- package/binding.gyp +21 -0
- package/docs/README.md +7 -0
- package/docs/native-abi.md +53 -0
- package/index.ts +1 -0
- package/package.json +76 -0
- package/scripts/build-node-addon.ts +26 -0
- package/scripts/build-zig-kernels.ts +50 -0
- package/scripts/check-api-docs-coverage.ts +52 -0
- package/scripts/check-benchmark-health.ts +140 -0
- package/scripts/install-native.ts +160 -0
- package/scripts/package-native-artifacts.ts +62 -0
- package/scripts/sync-benchmark-readme.ts +181 -0
- package/scripts/update-benchmark-history.ts +91 -0
- package/src/ensemble/RandomForestClassifier.ts +136 -0
- package/src/ensemble/RandomForestRegressor.ts +136 -0
- package/src/index.ts +32 -0
- package/src/linear_model/LinearRegression.ts +136 -0
- package/src/linear_model/LogisticRegression.ts +260 -0
- package/src/linear_model/SGDClassifier.ts +161 -0
- package/src/linear_model/SGDRegressor.ts +104 -0
- package/src/metrics/classification.ts +294 -0
- package/src/metrics/regression.ts +51 -0
- package/src/model_selection/GridSearchCV.ts +244 -0
- package/src/model_selection/KFold.ts +82 -0
- package/src/model_selection/RepeatedKFold.ts +49 -0
- package/src/model_selection/RepeatedStratifiedKFold.ts +50 -0
- package/src/model_selection/StratifiedKFold.ts +112 -0
- package/src/model_selection/StratifiedShuffleSplit.ts +211 -0
- package/src/model_selection/crossValScore.ts +165 -0
- package/src/model_selection/trainTestSplit.ts +82 -0
- package/src/naive_bayes/GaussianNB.ts +148 -0
- package/src/native/node-addon/bun_scikit_addon.cpp +450 -0
- package/src/native/zigKernels.ts +576 -0
- package/src/neighbors/KNeighborsClassifier.ts +85 -0
- package/src/pipeline/ColumnTransformer.ts +203 -0
- package/src/pipeline/FeatureUnion.ts +123 -0
- package/src/pipeline/Pipeline.ts +168 -0
- package/src/preprocessing/MinMaxScaler.ts +113 -0
- package/src/preprocessing/OneHotEncoder.ts +91 -0
- package/src/preprocessing/PolynomialFeatures.ts +158 -0
- package/src/preprocessing/RobustScaler.ts +149 -0
- package/src/preprocessing/SimpleImputer.ts +150 -0
- package/src/preprocessing/StandardScaler.ts +92 -0
- package/src/svm/LinearSVC.ts +117 -0
- package/src/tree/DecisionTreeClassifier.ts +394 -0
- package/src/tree/DecisionTreeRegressor.ts +407 -0
- package/src/types.ts +18 -0
- package/src/utils/linalg.ts +209 -0
- package/src/utils/validation.ts +78 -0
- package/zig/kernels.zig +1327 -0
|
@@ -0,0 +1,450 @@
|
|
|
1
|
+
#include <napi.h>
|
|
2
|
+
|
|
3
|
+
#include <cstdint>
|
|
4
|
+
#include <string>
|
|
5
|
+
#include <utility>
|
|
6
|
+
|
|
7
|
+
#if defined(_WIN32)
|
|
8
|
+
#include <windows.h>
|
|
9
|
+
#else
|
|
10
|
+
#include <dlfcn.h>
|
|
11
|
+
#endif
|
|
12
|
+
|
|
13
|
+
namespace {
|
|
14
|
+
|
|
15
|
+
// Opaque native model handle. The addon hands it to JavaScript as a BigInt.
using NativeHandle = std::uintptr_t;

// ---- Kernel entry-point signatures (resolved by name at load time) ----

// bun_scikit_abi_version(): ABI version reported by the kernel library.
using AbiVersionFn = auto (*)() -> std::uint32_t;

// Linear model kernels.
using LinearModelCreateFn = auto (*)(std::size_t /*n_features*/, std::uint8_t /*fit_intercept*/) -> NativeHandle;
using LinearModelDestroyFn = auto (*)(NativeHandle) -> void;
using LinearModelFitFn = auto (*)(NativeHandle, const double* /*x*/, const double* /*y*/,
                                  std::size_t /*n_samples*/, double /*l2*/) -> std::uint8_t;
using LinearModelCopyCoefficientsFn = auto (*)(NativeHandle, double* /*out*/) -> std::uint8_t;
using LinearModelGetInterceptFn = auto (*)(NativeHandle) -> double;

// Logistic model kernels.
using LogisticModelCreateFn = auto (*)(std::size_t /*n_features*/, std::uint8_t /*fit_intercept*/) -> NativeHandle;
using LogisticModelDestroyFn = auto (*)(NativeHandle) -> void;
using LogisticModelFitFn = auto (*)(NativeHandle, const double* /*x*/, const double* /*y*/,
                                    std::size_t /*n_samples*/, double /*learning_rate*/, double /*l2*/,
                                    std::size_t /*max_iter*/, double /*tolerance*/) -> std::size_t;
using LogisticModelFitLbfgsFn = auto (*)(NativeHandle, const double* /*x*/, const double* /*y*/,
                                         std::size_t /*n_samples*/, std::size_t /*max_iter*/,
                                         double /*tolerance*/, double /*l2*/,
                                         std::size_t /*memory*/) -> std::size_t;
using LogisticModelCopyCoefficientsFn = auto (*)(NativeHandle, double* /*out*/) -> std::uint8_t;
using LogisticModelGetInterceptFn = auto (*)(NativeHandle) -> double;
|
|
30
|
+
|
|
31
|
+
struct KernelLibrary {
|
|
32
|
+
#if defined(_WIN32)
|
|
33
|
+
HMODULE handle{nullptr};
|
|
34
|
+
#else
|
|
35
|
+
void* handle{nullptr};
|
|
36
|
+
#endif
|
|
37
|
+
std::string path{};
|
|
38
|
+
AbiVersionFn abi_version{nullptr};
|
|
39
|
+
LinearModelCreateFn linear_model_create{nullptr};
|
|
40
|
+
LinearModelDestroyFn linear_model_destroy{nullptr};
|
|
41
|
+
LinearModelFitFn linear_model_fit{nullptr};
|
|
42
|
+
LinearModelCopyCoefficientsFn linear_model_copy_coefficients{nullptr};
|
|
43
|
+
LinearModelGetInterceptFn linear_model_get_intercept{nullptr};
|
|
44
|
+
LogisticModelCreateFn logistic_model_create{nullptr};
|
|
45
|
+
LogisticModelDestroyFn logistic_model_destroy{nullptr};
|
|
46
|
+
LogisticModelFitFn logistic_model_fit{nullptr};
|
|
47
|
+
LogisticModelFitLbfgsFn logistic_model_fit_lbfgs{nullptr};
|
|
48
|
+
LogisticModelCopyCoefficientsFn logistic_model_copy_coefficients{nullptr};
|
|
49
|
+
LogisticModelGetInterceptFn logistic_model_get_intercept{nullptr};
|
|
50
|
+
};
|
|
51
|
+
|
|
52
|
+
KernelLibrary g_library{};
|
|
53
|
+
|
|
54
|
+
/// Closes the loaded kernel library, if there is one, and resets every cached
/// field (path and symbol pointers) back to its empty state. It is a no-op
/// when nothing is loaded.
void unloadLibrary() {
  if (g_library.handle != nullptr) {
#if defined(_WIN32)
    FreeLibrary(g_library.handle);
#else
    dlclose(g_library.handle);
#endif
    g_library = KernelLibrary{};
  }
}
|
|
65
|
+
|
|
66
|
+
void* lookupSymbol(const char* name) {
|
|
67
|
+
if (!g_library.handle) {
|
|
68
|
+
return nullptr;
|
|
69
|
+
}
|
|
70
|
+
#if defined(_WIN32)
|
|
71
|
+
return reinterpret_cast<void*>(GetProcAddress(g_library.handle, name));
|
|
72
|
+
#else
|
|
73
|
+
return dlsym(g_library.handle, name);
|
|
74
|
+
#endif
|
|
75
|
+
}
|
|
76
|
+
|
|
77
|
+
template <typename T>
|
|
78
|
+
T loadSymbol(const char* name) {
|
|
79
|
+
return reinterpret_cast<T>(lookupSymbol(name));
|
|
80
|
+
}
|
|
81
|
+
|
|
82
|
+
/// Raises a JavaScript TypeError with `message` on `env`. Callers in this file
/// return a placeholder value immediately after calling it.
void throwTypeError(const Napi::Env& env, const char* message) {
  Napi::TypeError error = Napi::TypeError::New(env, message);
  error.ThrowAsJavaScriptException();
}
|
|
85
|
+
|
|
86
|
+
/// Raises a plain JavaScript Error with `message` on `env`. Callers in this
/// file return a placeholder value immediately after calling it.
void throwError(const Napi::Env& env, const char* message) {
  Napi::Error error = Napi::Error::New(env, message);
  error.ThrowAsJavaScriptException();
}
|
|
89
|
+
|
|
90
|
+
/// Converts a JS BigInt back into a native model handle.
///
/// On failure it raises a TypeError and returns 0, either because `value` is
/// not a BigInt or because it does not fit losslessly in uint64. Callers
/// detect this with env.IsExceptionPending().
NativeHandle handleFromBigInt(const Napi::Value& value, const Napi::Env& env) {
  if (value.IsBigInt()) {
    bool lossless = false;
    const std::uint64_t raw = value.As<Napi::BigInt>().Uint64Value(&lossless);
    if (lossless) {
      return static_cast<NativeHandle>(raw);
    }
    throwTypeError(env, "BigInt handle is not lossless as uint64.");
    return 0;
  }
  throwTypeError(env, "Expected a BigInt handle.");
  return 0;
}
|
|
103
|
+
|
|
104
|
+
/// JS: loadLibrary(path: string) -> boolean
///
/// Unloads any previously loaded kernel library, then opens the shared library
/// at `path` and resolves every known kernel entry point. Entry points that
/// are missing are left as null pointers.
///
/// @return false (without throwing) when the OS loader cannot open `path`,
///         true otherwise.
/// @throws TypeError when called with anything other than one string argument.
Napi::Value LoadNativeLibrary(const Napi::CallbackInfo& info) {
  const Napi::Env env = info.Env();
  if (info.Length() != 1 || !info[0].IsString()) {
    throwTypeError(env, "loadLibrary(path) expects a string path.");
    return env.Null();
  }

  // The previous library is released first, even if the new load then fails.
  unloadLibrary();
  const std::string path = info[0].As<Napi::String>().Utf8Value();

#if defined(_WIN32)
  // `path` is UTF-8. LoadLibraryA would reinterpret those bytes in the
  // process's ANSI code page and fail for non-ASCII paths, so convert to
  // UTF-16 and use the wide-character loader instead.
  const int wide_length = ::MultiByteToWideChar(CP_UTF8, 0, path.c_str(), -1, nullptr, 0);
  if (wide_length > 0) {
    std::wstring wide_path(static_cast<std::size_t>(wide_length), L'\0');
    if (::MultiByteToWideChar(CP_UTF8, 0, path.c_str(), -1, &wide_path[0], wide_length) > 0) {
      g_library.handle = ::LoadLibraryW(wide_path.c_str());
    }
  }
#else
  g_library.handle = dlopen(path.c_str(), RTLD_NOW | RTLD_LOCAL);
#endif
  if (!g_library.handle) {
    return Napi::Boolean::New(env, false);
  }

  g_library.path = path;
  g_library.abi_version = loadSymbol<AbiVersionFn>("bun_scikit_abi_version");
  g_library.linear_model_create = loadSymbol<LinearModelCreateFn>("linear_model_create");
  g_library.linear_model_destroy = loadSymbol<LinearModelDestroyFn>("linear_model_destroy");
  g_library.linear_model_fit = loadSymbol<LinearModelFitFn>("linear_model_fit");
  g_library.linear_model_copy_coefficients =
      loadSymbol<LinearModelCopyCoefficientsFn>("linear_model_copy_coefficients");
  g_library.linear_model_get_intercept =
      loadSymbol<LinearModelGetInterceptFn>("linear_model_get_intercept");
  g_library.logistic_model_create = loadSymbol<LogisticModelCreateFn>("logistic_model_create");
  g_library.logistic_model_destroy = loadSymbol<LogisticModelDestroyFn>("logistic_model_destroy");
  g_library.logistic_model_fit = loadSymbol<LogisticModelFitFn>("logistic_model_fit");
  g_library.logistic_model_fit_lbfgs =
      loadSymbol<LogisticModelFitLbfgsFn>("logistic_model_fit_lbfgs");
  g_library.logistic_model_copy_coefficients =
      loadSymbol<LogisticModelCopyCoefficientsFn>("logistic_model_copy_coefficients");
  g_library.logistic_model_get_intercept =
      loadSymbol<LogisticModelGetInterceptFn>("logistic_model_get_intercept");

  return Napi::Boolean::New(env, true);
}
|
|
144
|
+
|
|
145
|
+
/// JS: unloadLibrary() -> undefined
/// Releases the kernel library, if one is loaded. Extra arguments are ignored.
Napi::Value UnloadLibrary(const Napi::CallbackInfo& info) {
  const Napi::Env env = info.Env();
  unloadLibrary();
  return env.Undefined();
}
|
|
149
|
+
|
|
150
|
+
/// Guard used by every kernel wrapper. It returns true when a library handle
/// is present. Otherwise it raises "Native library is not loaded." on `env`
/// and returns false.
bool isLibraryLoaded(const Napi::Env& env) {
  const bool loaded = g_library.handle != nullptr;
  if (!loaded) {
    throwError(env, "Native library is not loaded.");
  }
  return loaded;
}
|
|
157
|
+
|
|
158
|
+
/// JS: abiVersion() -> number
/// Returns the value reported by the library's `bun_scikit_abi_version` export.
/// It throws if no library is loaded or the export is missing.
Napi::Value AbiVersion(const Napi::CallbackInfo& info) {
  const Napi::Env env = info.Env();
  if (!isLibraryLoaded(env)) {
    return env.Null();
  }
  const AbiVersionFn query = g_library.abi_version;
  if (query == nullptr) {
    throwError(env, "Symbol bun_scikit_abi_version is unavailable.");
    return env.Null();
  }
  const std::uint32_t version = query();
  return Napi::Number::New(env, version);
}
|
|
169
|
+
|
|
170
|
+
/// JS: loadedPath() -> string | null
/// Returns the path passed to the successful loadLibrary() call, or null when
/// nothing is loaded. It never throws.
Napi::Value LoadedPath(const Napi::CallbackInfo& info) {
  const Napi::Env env = info.Env();
  if (g_library.handle != nullptr) {
    return Napi::String::New(env, g_library.path);
  }
  return env.Null();
}
|
|
177
|
+
|
|
178
|
+
/// JS: linearModelCreate(nFeatures: number, fitIntercept: number) -> bigint
///
/// Calls the native `linear_model_create` and returns the resulting handle as
/// a BigInt. Both numbers are read with Uint32Value(). fitIntercept is then
/// narrowed to one byte.
Napi::Value LinearModelCreate(const Napi::CallbackInfo& info) {
  const Napi::Env env = info.Env();
  if (!isLibraryLoaded(env)) {
    return env.Null();
  }
  const LinearModelCreateFn create = g_library.linear_model_create;
  if (create == nullptr) {
    throwError(env, "Symbol linear_model_create is unavailable.");
    return env.Null();
  }
  const bool args_ok = info.Length() == 2 && info[0].IsNumber() && info[1].IsNumber();
  if (!args_ok) {
    throwTypeError(env, "linearModelCreate(nFeatures, fitIntercept) expects two numbers.");
    return env.Null();
  }
  const auto n_features = static_cast<std::size_t>(info[0].As<Napi::Number>().Uint32Value());
  const auto fit_intercept = static_cast<std::uint8_t>(info[1].As<Napi::Number>().Uint32Value());
  const NativeHandle model = create(n_features, fit_intercept);
  return Napi::BigInt::New(env, static_cast<std::uint64_t>(model));
}
|
|
196
|
+
|
|
197
|
+
/// JS: linearModelDestroy(handle: bigint) -> undefined
/// Passes the decoded handle to the native `linear_model_destroy`.
Napi::Value LinearModelDestroy(const Napi::CallbackInfo& info) {
  const Napi::Env env = info.Env();
  if (!isLibraryLoaded(env)) {
    return env.Null();
  }
  const LinearModelDestroyFn destroy = g_library.linear_model_destroy;
  if (destroy == nullptr) {
    throwError(env, "Symbol linear_model_destroy is unavailable.");
    return env.Null();
  }
  if (info.Length() != 1) {
    throwTypeError(env, "linearModelDestroy(handle) expects one BigInt.");
    return env.Null();
  }
  const NativeHandle model = handleFromBigInt(info[0], env);
  if (env.IsExceptionPending()) {
    // handleFromBigInt already raised a TypeError.
    return env.Null();
  }
  destroy(model);
  return env.Undefined();
}
|
|
217
|
+
|
|
218
|
+
/// JS: linearModelFit(handle: bigint, x: Float64Array, y: Float64Array,
///                    nSamples: number, l2: number) -> number
///
/// Forwards the training data to the native `linear_model_fit` kernel and
/// returns its raw status code.
///
/// `x` and `y` must be genuine Float64Arrays. Their backing storage is handed
/// to the kernel as `const double*`, so any other element type would be
/// misread, and could be read past its end. `y` must hold at least `nSamples`
/// values.
///
/// @throws TypeError on wrong argument count/types or a bad handle.
/// @throws Error when no library/symbol is loaded or `y` is too short.
Napi::Value LinearModelFit(const Napi::CallbackInfo& info) {
  const Napi::Env env = info.Env();
  if (!isLibraryLoaded(env)) {
    return env.Null();
  }
  if (!g_library.linear_model_fit) {
    throwError(env, "Symbol linear_model_fit is unavailable.");
    return env.Null();
  }
  // IsTypedArray() alone also accepts Float32Array, Uint8Array, etc. The
  // kernel reads doubles, so require the exact element type.
  const auto isFloat64Array = [](const Napi::Value& value) {
    return value.IsTypedArray() &&
           value.As<Napi::TypedArray>().TypedArrayType() == napi_float64_array;
  };
  if (info.Length() != 5 || !isFloat64Array(info[1]) || !isFloat64Array(info[2]) ||
      !info[3].IsNumber() || !info[4].IsNumber()) {
    throwTypeError(env, "linearModelFit(handle, x, y, nSamples, l2) expects (BigInt, Float64Array, Float64Array, number, number).");
    return env.Null();
  }
  const NativeHandle handle = handleFromBigInt(info[0], env);
  if (env.IsExceptionPending()) {
    return env.Null();
  }
  auto x = info[1].As<Napi::Float64Array>();
  auto y = info[2].As<Napi::Float64Array>();
  const std::size_t n_samples = static_cast<std::size_t>(info[3].As<Napi::Number>().Uint32Value());
  const double l2 = info[4].As<Napi::Number>().DoubleValue();
  // The kernel is only told nSamples, so it presumably reads that many
  // targets from y. Refuse to let it read past the end of the array.
  if (y.ElementLength() < n_samples) {
    throwError(env, "linearModelFit: y has fewer than nSamples elements.");
    return env.Null();
  }
  const std::uint8_t status = g_library.linear_model_fit(handle, x.Data(), y.Data(), n_samples, l2);
  return Napi::Number::New(env, status);
}
|
|
243
|
+
|
|
244
|
+
/// JS: linearModelCopyCoefficients(handle: bigint, out: Float64Array) -> number
///
/// Has the native `linear_model_copy_coefficients` kernel write the model's
/// coefficients into `out`, and returns the kernel's raw status code.
///
/// `out` must be a genuine Float64Array because its storage is handed over as
/// `double*`. A different element type would be overwritten with misaligned
/// doubles. The required length (presumably nFeatures) is owned by the native
/// model and is not checked here.
Napi::Value LinearModelCopyCoefficients(const Napi::CallbackInfo& info) {
  const Napi::Env env = info.Env();
  if (!isLibraryLoaded(env)) {
    return env.Null();
  }
  if (!g_library.linear_model_copy_coefficients) {
    throwError(env, "Symbol linear_model_copy_coefficients is unavailable.");
    return env.Null();
  }
  // IsTypedArray() alone also accepts non-double element types.
  const auto isFloat64Array = [](const Napi::Value& value) {
    return value.IsTypedArray() &&
           value.As<Napi::TypedArray>().TypedArrayType() == napi_float64_array;
  };
  if (info.Length() != 2 || !isFloat64Array(info[1])) {
    throwTypeError(env, "linearModelCopyCoefficients(handle, out) expects (BigInt, Float64Array).");
    return env.Null();
  }
  const NativeHandle handle = handleFromBigInt(info[0], env);
  if (env.IsExceptionPending()) {
    return env.Null();
  }
  auto out = info[1].As<Napi::Float64Array>();
  const std::uint8_t status = g_library.linear_model_copy_coefficients(handle, out.Data());
  return Napi::Number::New(env, status);
}
|
|
265
|
+
|
|
266
|
+
/// JS: linearModelGetIntercept(handle: bigint) -> number
/// Returns the value from the native `linear_model_get_intercept`.
Napi::Value LinearModelGetIntercept(const Napi::CallbackInfo& info) {
  const Napi::Env env = info.Env();
  if (!isLibraryLoaded(env)) {
    return env.Null();
  }
  const LinearModelGetInterceptFn getIntercept = g_library.linear_model_get_intercept;
  if (getIntercept == nullptr) {
    throwError(env, "Symbol linear_model_get_intercept is unavailable.");
    return env.Null();
  }
  if (info.Length() != 1) {
    throwTypeError(env, "linearModelGetIntercept(handle) expects one BigInt.");
    return env.Null();
  }
  const NativeHandle model = handleFromBigInt(info[0], env);
  if (env.IsExceptionPending()) {
    return env.Null();
  }
  const double intercept = getIntercept(model);
  return Napi::Number::New(env, intercept);
}
|
|
285
|
+
|
|
286
|
+
/// JS: logisticModelCreate(nFeatures: number, fitIntercept: number) -> bigint
///
/// Calls the native `logistic_model_create` and returns the resulting handle
/// as a BigInt. Both numbers are read with Uint32Value(). fitIntercept is then
/// narrowed to one byte.
Napi::Value LogisticModelCreate(const Napi::CallbackInfo& info) {
  const Napi::Env env = info.Env();
  if (!isLibraryLoaded(env)) {
    return env.Null();
  }
  const LogisticModelCreateFn create = g_library.logistic_model_create;
  if (create == nullptr) {
    throwError(env, "Symbol logistic_model_create is unavailable.");
    return env.Null();
  }
  const bool args_ok = info.Length() == 2 && info[0].IsNumber() && info[1].IsNumber();
  if (!args_ok) {
    throwTypeError(env, "logisticModelCreate(nFeatures, fitIntercept) expects two numbers.");
    return env.Null();
  }
  const auto n_features = static_cast<std::size_t>(info[0].As<Napi::Number>().Uint32Value());
  const auto fit_intercept = static_cast<std::uint8_t>(info[1].As<Napi::Number>().Uint32Value());
  const NativeHandle model = create(n_features, fit_intercept);
  return Napi::BigInt::New(env, static_cast<std::uint64_t>(model));
}
|
|
304
|
+
|
|
305
|
+
/// JS: logisticModelDestroy(handle: bigint) -> undefined
/// Passes the decoded handle to the native `logistic_model_destroy`.
Napi::Value LogisticModelDestroy(const Napi::CallbackInfo& info) {
  const Napi::Env env = info.Env();
  if (!isLibraryLoaded(env)) {
    return env.Null();
  }
  const LogisticModelDestroyFn destroy = g_library.logistic_model_destroy;
  if (destroy == nullptr) {
    throwError(env, "Symbol logistic_model_destroy is unavailable.");
    return env.Null();
  }
  if (info.Length() != 1) {
    throwTypeError(env, "logisticModelDestroy(handle) expects one BigInt.");
    return env.Null();
  }
  const NativeHandle model = handleFromBigInt(info[0], env);
  if (env.IsExceptionPending()) {
    // handleFromBigInt already raised a TypeError.
    return env.Null();
  }
  destroy(model);
  return env.Undefined();
}
|
|
325
|
+
|
|
326
|
+
/// JS: logisticModelFit(handle: bigint, x: Float64Array, y: Float64Array,
///                      nSamples: number, learningRate: number, l2: number,
///                      maxIter: number, tolerance: number) -> bigint
///
/// Runs the native `logistic_model_fit` kernel and returns the count it
/// reports (named `epochs` here) as a BigInt.
///
/// `x` and `y` must be genuine Float64Arrays because their storage is handed
/// over as `const double*`. The five trailing arguments must all be numbers.
/// `y` must hold at least `nSamples` values.
///
/// @throws TypeError on wrong argument count/types or a bad handle.
/// @throws Error when no library/symbol is loaded or `y` is too short.
Napi::Value LogisticModelFit(const Napi::CallbackInfo& info) {
  const Napi::Env env = info.Env();
  if (!isLibraryLoaded(env)) {
    return env.Null();
  }
  if (!g_library.logistic_model_fit) {
    throwError(env, "Symbol logistic_model_fit is unavailable.");
    return env.Null();
  }
  // IsTypedArray() alone also accepts non-double element types.
  const auto isFloat64Array = [](const Napi::Value& value) {
    return value.IsTypedArray() &&
           value.As<Napi::TypedArray>().TypedArrayType() == napi_float64_array;
  };
  bool valid_args = info.Length() == 8 && isFloat64Array(info[1]) && isFloat64Array(info[2]);
  // nSamples, learningRate, l2, maxIter and tolerance must be plain numbers.
  // Previously they reached the As<Napi::Number>() conversions below unchecked.
  for (std::size_t i = 3; valid_args && i < 8; ++i) {
    valid_args = info[i].IsNumber();
  }
  if (!valid_args) {
    throwTypeError(env, "logisticModelFit(handle, x, y, nSamples, learningRate, l2, maxIter, tolerance) has invalid arguments.");
    return env.Null();
  }
  const NativeHandle handle = handleFromBigInt(info[0], env);
  if (env.IsExceptionPending()) {
    return env.Null();
  }
  auto x = info[1].As<Napi::Float64Array>();
  auto y = info[2].As<Napi::Float64Array>();
  const std::size_t n_samples = static_cast<std::size_t>(info[3].As<Napi::Number>().Uint32Value());
  const double learning_rate = info[4].As<Napi::Number>().DoubleValue();
  const double l2 = info[5].As<Napi::Number>().DoubleValue();
  const std::size_t max_iter = static_cast<std::size_t>(info[6].As<Napi::Number>().Uint32Value());
  const double tolerance = info[7].As<Napi::Number>().DoubleValue();
  // The kernel is only told nSamples, so it presumably reads that many labels
  // from y. Refuse to let it read past the end of the array.
  if (y.ElementLength() < n_samples) {
    throwError(env, "logisticModelFit: y has fewer than nSamples elements.");
    return env.Null();
  }
  const std::size_t epochs = g_library.logistic_model_fit(
      handle, x.Data(), y.Data(), n_samples, learning_rate, l2, max_iter, tolerance);
  return Napi::BigInt::New(env, static_cast<std::uint64_t>(epochs));
}
|
|
354
|
+
|
|
355
|
+
/// JS: logisticModelFitLbfgs(handle: bigint, x: Float64Array, y: Float64Array,
///                           nSamples: number, maxIter: number, tolerance: number,
///                           l2: number, memory: number) -> bigint
///
/// Runs the native `logistic_model_fit_lbfgs` kernel and returns the count it
/// reports (named `epochs` here) as a BigInt.
///
/// `x` and `y` must be genuine Float64Arrays because their storage is handed
/// over as `const double*`. The five trailing arguments must all be numbers.
/// `y` must hold at least `nSamples` values.
///
/// @throws TypeError on wrong argument count/types or a bad handle.
/// @throws Error when no library/symbol is loaded or `y` is too short.
Napi::Value LogisticModelFitLbfgs(const Napi::CallbackInfo& info) {
  const Napi::Env env = info.Env();
  if (!isLibraryLoaded(env)) {
    return env.Null();
  }
  if (!g_library.logistic_model_fit_lbfgs) {
    throwError(env, "Symbol logistic_model_fit_lbfgs is unavailable.");
    return env.Null();
  }
  // IsTypedArray() alone also accepts non-double element types.
  const auto isFloat64Array = [](const Napi::Value& value) {
    return value.IsTypedArray() &&
           value.As<Napi::TypedArray>().TypedArrayType() == napi_float64_array;
  };
  bool valid_args = info.Length() == 8 && isFloat64Array(info[1]) && isFloat64Array(info[2]);
  // nSamples, maxIter, tolerance, l2 and memory must be plain numbers.
  // Previously they reached the As<Napi::Number>() conversions below unchecked.
  for (std::size_t i = 3; valid_args && i < 8; ++i) {
    valid_args = info[i].IsNumber();
  }
  if (!valid_args) {
    throwTypeError(env, "logisticModelFitLbfgs(handle, x, y, nSamples, maxIter, tolerance, l2, memory) has invalid arguments.");
    return env.Null();
  }
  const NativeHandle handle = handleFromBigInt(info[0], env);
  if (env.IsExceptionPending()) {
    return env.Null();
  }
  auto x = info[1].As<Napi::Float64Array>();
  auto y = info[2].As<Napi::Float64Array>();
  const std::size_t n_samples = static_cast<std::size_t>(info[3].As<Napi::Number>().Uint32Value());
  const std::size_t max_iter = static_cast<std::size_t>(info[4].As<Napi::Number>().Uint32Value());
  const double tolerance = info[5].As<Napi::Number>().DoubleValue();
  const double l2 = info[6].As<Napi::Number>().DoubleValue();
  const std::size_t memory = static_cast<std::size_t>(info[7].As<Napi::Number>().Uint32Value());
  // The kernel is only told nSamples, so it presumably reads that many labels
  // from y. Refuse to let it read past the end of the array.
  if (y.ElementLength() < n_samples) {
    throwError(env, "logisticModelFitLbfgs: y has fewer than nSamples elements.");
    return env.Null();
  }
  const std::size_t epochs = g_library.logistic_model_fit_lbfgs(
      handle, x.Data(), y.Data(), n_samples, max_iter, tolerance, l2, memory);
  return Napi::BigInt::New(env, static_cast<std::uint64_t>(epochs));
}
|
|
383
|
+
|
|
384
|
+
/// JS: logisticModelCopyCoefficients(handle: bigint, out: Float64Array) -> number
///
/// Has the native `logistic_model_copy_coefficients` kernel write the model's
/// coefficients into `out`, and returns the kernel's raw status code.
///
/// `out` must be a genuine Float64Array because its storage is handed over as
/// `double*`. The required length (presumably nFeatures) is owned by the
/// native model and is not checked here.
Napi::Value LogisticModelCopyCoefficients(const Napi::CallbackInfo& info) {
  const Napi::Env env = info.Env();
  if (!isLibraryLoaded(env)) {
    return env.Null();
  }
  if (!g_library.logistic_model_copy_coefficients) {
    throwError(env, "Symbol logistic_model_copy_coefficients is unavailable.");
    return env.Null();
  }
  // IsTypedArray() alone also accepts non-double element types.
  const auto isFloat64Array = [](const Napi::Value& value) {
    return value.IsTypedArray() &&
           value.As<Napi::TypedArray>().TypedArrayType() == napi_float64_array;
  };
  if (info.Length() != 2 || !isFloat64Array(info[1])) {
    throwTypeError(env, "logisticModelCopyCoefficients(handle, out) expects (BigInt, Float64Array).");
    return env.Null();
  }
  const NativeHandle handle = handleFromBigInt(info[0], env);
  if (env.IsExceptionPending()) {
    return env.Null();
  }
  auto out = info[1].As<Napi::Float64Array>();
  const std::uint8_t status = g_library.logistic_model_copy_coefficients(handle, out.Data());
  return Napi::Number::New(env, status);
}
|
|
405
|
+
|
|
406
|
+
/// JS: logisticModelGetIntercept(handle: bigint) -> number
/// Returns the value from the native `logistic_model_get_intercept`.
Napi::Value LogisticModelGetIntercept(const Napi::CallbackInfo& info) {
  const Napi::Env env = info.Env();
  if (!isLibraryLoaded(env)) {
    return env.Null();
  }
  const LogisticModelGetInterceptFn getIntercept = g_library.logistic_model_get_intercept;
  if (getIntercept == nullptr) {
    throwError(env, "Symbol logistic_model_get_intercept is unavailable.");
    return env.Null();
  }
  if (info.Length() != 1) {
    throwTypeError(env, "logisticModelGetIntercept(handle) expects one BigInt.");
    return env.Null();
  }
  const NativeHandle model = handleFromBigInt(info[0], env);
  if (env.IsExceptionPending()) {
    return env.Null();
  }
  const double intercept = getIntercept(model);
  return Napi::Number::New(env, intercept);
}
|
|
425
|
+
|
|
426
|
+
/// Module initializer. It attaches every JS-callable wrapper to `exports`
/// under its camelCase name, in the order listed below.
Napi::Object Init(Napi::Env env, Napi::Object exports) {
  using Binding = Napi::Value (*)(const Napi::CallbackInfo&);
  struct Export {
    const char* name;
    Binding fn;
  };
  const Export table[] = {
      // Library lifecycle.
      {"loadLibrary", LoadNativeLibrary},
      {"unloadLibrary", UnloadLibrary},
      {"loadedPath", LoadedPath},
      {"abiVersion", AbiVersion},
      // Linear model kernels.
      {"linearModelCreate", LinearModelCreate},
      {"linearModelDestroy", LinearModelDestroy},
      {"linearModelFit", LinearModelFit},
      {"linearModelCopyCoefficients", LinearModelCopyCoefficients},
      {"linearModelGetIntercept", LinearModelGetIntercept},
      // Logistic model kernels.
      {"logisticModelCreate", LogisticModelCreate},
      {"logisticModelDestroy", LogisticModelDestroy},
      {"logisticModelFit", LogisticModelFit},
      {"logisticModelFitLbfgs", LogisticModelFitLbfgs},
      {"logisticModelCopyCoefficients", LogisticModelCopyCoefficients},
      {"logisticModelGetIntercept", LogisticModelGetIntercept},
  };
  for (const Export& entry : table) {
    exports.Set(entry.name, Napi::Function::New(env, entry.fn));
  }
  return exports;
}
|
|
447
|
+
|
|
448
|
+
} // namespace
|
|
449
|
+
|
|
450
|
+
// Registers Init as this addon's module initializer (node-addon-api macro),
// using the module name bun_scikit_node_addon.
NODE_API_MODULE(bun_scikit_node_addon, Init)
|