catboost 1.25.1 → 1.26.0
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- package/DEPLOYMENT.md +22 -15
- package/README.md +37 -27
- package/binding.gyp +5 -7
- package/build_scripts/bootstrap.js +2 -1
- package/build_scripts/out/build.js +46 -68
- package/build_scripts/out/build_model.js +1 -1
- package/build_scripts/out/{build_ya.js → build_native.js} +1 -1
- package/build_scripts/out/ci.js +5 -5
- package/build_scripts/out/config.js +32 -18
- package/build_scripts/out/install.js +5 -3
- package/build_scripts/out/package_prepublish.js +1 -1
- package/build_scripts/out/packaging.js +1 -19
- package/build_scripts/out/run_tests.js +1 -1
- package/build_scripts/out/test.js +8 -3
- package/config.json +18 -11
- package/inc/catboost/libs/model_interface/c_api.h +349 -3
- package/lib/catboost.d.ts +65 -21
- package/package.json +4 -4
- package/src/api_helpers.cpp +100 -24
- package/src/api_helpers.h +8 -7
- package/src/api_module.cpp +1 -2
- package/src/model.cpp +483 -83
- package/src/model.h +24 -9
- package/inc/contrib/libs/cxxsupp/system_stl/include/stlfwd +0 -14
- package/inc/util/charset/recode_result.h +0 -9
- package/inc/util/charset/unicode_table.h +0 -123
- package/inc/util/charset/unidata.h +0 -421
- package/inc/util/charset/utf8.h +0 -384
- package/inc/util/charset/wide.h +0 -843
- package/inc/util/charset/wide_specific.h +0 -22
- package/inc/util/datetime/base.h +0 -669
- package/inc/util/datetime/constants.h +0 -7
- package/inc/util/datetime/cputimer.h +0 -124
- package/inc/util/datetime/parser.h +0 -292
- package/inc/util/datetime/systime.h +0 -47
- package/inc/util/datetime/uptime.h +0 -8
- package/inc/util/digest/city.h +0 -88
- package/inc/util/digest/fnv.h +0 -73
- package/inc/util/digest/multi.h +0 -14
- package/inc/util/digest/murmur.h +0 -57
- package/inc/util/digest/numeric.h +0 -86
- package/inc/util/digest/sequence.h +0 -48
- package/inc/util/draft/date.h +0 -129
- package/inc/util/draft/datetime.h +0 -184
- package/inc/util/draft/enum.h +0 -136
- package/inc/util/draft/holder_vector.h +0 -102
- package/inc/util/draft/ip.h +0 -131
- package/inc/util/draft/matrix.h +0 -108
- package/inc/util/draft/memory.h +0 -40
- package/inc/util/folder/dirent_win.h +0 -46
- package/inc/util/folder/dirut.h +0 -121
- package/inc/util/folder/filelist.h +0 -81
- package/inc/util/folder/fts.h +0 -108
- package/inc/util/folder/iterator.h +0 -109
- package/inc/util/folder/lstat_win.h +0 -20
- package/inc/util/folder/path.h +0 -225
- package/inc/util/folder/pathsplit.h +0 -113
- package/inc/util/folder/tempdir.h +0 -42
- package/inc/util/generic/adaptor.h +0 -134
- package/inc/util/generic/algorithm.h +0 -765
- package/inc/util/generic/array_ref.h +0 -282
- package/inc/util/generic/array_size.h +0 -24
- package/inc/util/generic/benchmark/vector_count_ctor/f.h +0 -9
- package/inc/util/generic/bitmap.h +0 -1115
- package/inc/util/generic/bitops.h +0 -459
- package/inc/util/generic/bt_exception.h +0 -24
- package/inc/util/generic/buffer.h +0 -232
- package/inc/util/generic/cast.h +0 -176
- package/inc/util/generic/deque.h +0 -24
- package/inc/util/generic/explicit_type.h +0 -42
- package/inc/util/generic/fastqueue.h +0 -55
- package/inc/util/generic/flags.h +0 -244
- package/inc/util/generic/function.h +0 -103
- package/inc/util/generic/fwd.h +0 -171
- package/inc/util/generic/guid.h +0 -61
- package/inc/util/generic/hash.h +0 -2032
- package/inc/util/generic/hash_primes.h +0 -140
- package/inc/util/generic/hash_set.h +0 -490
- package/inc/util/generic/hide_ptr.h +0 -3
- package/inc/util/generic/intrlist.h +0 -876
- package/inc/util/generic/is_in.h +0 -53
- package/inc/util/generic/iterator.h +0 -137
- package/inc/util/generic/iterator_range.h +0 -105
- package/inc/util/generic/lazy_value.h +0 -66
- package/inc/util/generic/list.h +0 -22
- package/inc/util/generic/map.h +0 -44
- package/inc/util/generic/mapfindptr.h +0 -60
- package/inc/util/generic/maybe.h +0 -713
- package/inc/util/generic/maybe_traits.h +0 -164
- package/inc/util/generic/mem_copy.h +0 -55
- package/inc/util/generic/noncopyable.h +0 -38
- package/inc/util/generic/object_counter.h +0 -53
- package/inc/util/generic/ptr.h +0 -1113
- package/inc/util/generic/queue.h +0 -57
- package/inc/util/generic/refcount.h +0 -162
- package/inc/util/generic/reserve.h +0 -11
- package/inc/util/generic/scope.h +0 -65
- package/inc/util/generic/serialized_enum.h +0 -406
- package/inc/util/generic/set.h +0 -42
- package/inc/util/generic/singleton.h +0 -136
- package/inc/util/generic/size_literals.h +0 -65
- package/inc/util/generic/stack.h +0 -18
- package/inc/util/generic/store_policy.h +0 -120
- package/inc/util/generic/strbase.h +0 -612
- package/inc/util/generic/strbuf.h +0 -552
- package/inc/util/generic/strfcpy.h +0 -17
- package/inc/util/generic/string.h +0 -1572
- package/inc/util/generic/string_hash.h +0 -21
- package/inc/util/generic/string_ut.h +0 -1175
- package/inc/util/generic/type_name.h +0 -34
- package/inc/util/generic/typelist.h +0 -114
- package/inc/util/generic/typetraits.h +0 -325
- package/inc/util/generic/utility.h +0 -132
- package/inc/util/generic/va_args.h +0 -400
- package/inc/util/generic/variant.h +0 -631
- package/inc/util/generic/variant_traits.h +0 -171
- package/inc/util/generic/vector.h +0 -119
- package/inc/util/generic/xrange.h +0 -258
- package/inc/util/generic/yexception.h +0 -212
- package/inc/util/generic/yexception_ut.h +0 -14
- package/inc/util/generic/ylimits.h +0 -92
- package/inc/util/generic/ymath.h +0 -206
- package/inc/util/memory/addstorage.h +0 -93
- package/inc/util/memory/alloc.h +0 -27
- package/inc/util/memory/blob.h +0 -296
- package/inc/util/memory/mmapalloc.h +0 -8
- package/inc/util/memory/pool.h +0 -432
- package/inc/util/memory/segmented_string_pool.h +0 -194
- package/inc/util/memory/segpool_alloc.h +0 -118
- package/inc/util/memory/smallobj.h +0 -141
- package/inc/util/memory/tempbuf.h +0 -111
- package/inc/util/network/address.h +0 -136
- package/inc/util/network/endpoint.h +0 -61
- package/inc/util/network/hostip.h +0 -16
- package/inc/util/network/init.h +0 -60
- package/inc/util/network/interface.h +0 -17
- package/inc/util/network/iovec.h +0 -65
- package/inc/util/network/ip.h +0 -116
- package/inc/util/network/nonblock.h +0 -8
- package/inc/util/network/pair.h +0 -9
- package/inc/util/network/poller.h +0 -58
- package/inc/util/network/pollerimpl.h +0 -707
- package/inc/util/network/sock.h +0 -608
- package/inc/util/network/socket.h +0 -421
- package/inc/util/random/common_ops.h +0 -130
- package/inc/util/random/easy.h +0 -47
- package/inc/util/random/entropy.h +0 -21
- package/inc/util/random/fast.h +0 -101
- package/inc/util/random/init_atfork.h +0 -3
- package/inc/util/random/lcg_engine.h +0 -66
- package/inc/util/random/mersenne.h +0 -46
- package/inc/util/random/mersenne32.h +0 -50
- package/inc/util/random/mersenne64.h +0 -50
- package/inc/util/random/normal.h +0 -38
- package/inc/util/random/random.h +0 -30
- package/inc/util/random/shuffle.h +0 -39
- package/inc/util/str_stl.h +0 -266
- package/inc/util/stream/aligned.h +0 -99
- package/inc/util/stream/buffer.h +0 -119
- package/inc/util/stream/buffered.h +0 -225
- package/inc/util/stream/debug.h +0 -53
- package/inc/util/stream/direct_io.h +0 -43
- package/inc/util/stream/file.h +0 -108
- package/inc/util/stream/format.h +0 -444
- package/inc/util/stream/fwd.h +0 -100
- package/inc/util/stream/hex.h +0 -8
- package/inc/util/stream/holder.h +0 -44
- package/inc/util/stream/input.h +0 -273
- package/inc/util/stream/labeled.h +0 -19
- package/inc/util/stream/length.h +0 -100
- package/inc/util/stream/mem.h +0 -255
- package/inc/util/stream/multi.h +0 -32
- package/inc/util/stream/null.h +0 -61
- package/inc/util/stream/output.h +0 -304
- package/inc/util/stream/pipe.h +0 -112
- package/inc/util/stream/printf.h +0 -25
- package/inc/util/stream/str.h +0 -207
- package/inc/util/stream/tee.h +0 -28
- package/inc/util/stream/tempbuf.h +0 -21
- package/inc/util/stream/tokenizer.h +0 -214
- package/inc/util/stream/trace.h +0 -60
- package/inc/util/stream/walk.h +0 -35
- package/inc/util/stream/zerocopy.h +0 -91
- package/inc/util/stream/zerocopy_output.h +0 -57
- package/inc/util/stream/zlib.h +0 -173
- package/inc/util/string/ascii.h +0 -236
- package/inc/util/string/builder.h +0 -39
- package/inc/util/string/cast.h +0 -347
- package/inc/util/string/cstriter.h +0 -14
- package/inc/util/string/escape.h +0 -70
- package/inc/util/string/hex.h +0 -59
- package/inc/util/string/join.h +0 -194
- package/inc/util/string/printf.h +0 -13
- package/inc/util/string/reverse.h +0 -16
- package/inc/util/string/split.h +0 -1080
- package/inc/util/string/strip.h +0 -257
- package/inc/util/string/strspn.h +0 -65
- package/inc/util/string/subst.h +0 -56
- package/inc/util/string/type.h +0 -50
- package/inc/util/string/util.h +0 -195
- package/inc/util/string/vector.h +0 -132
- package/inc/util/system/align.h +0 -50
- package/inc/util/system/atexit.h +0 -22
- package/inc/util/system/atomic.h +0 -51
- package/inc/util/system/atomic_gcc.h +0 -90
- package/inc/util/system/atomic_ops.h +0 -189
- package/inc/util/system/atomic_win.h +0 -114
- package/inc/util/system/backtrace.h +0 -39
- package/inc/util/system/byteorder.h +0 -186
- package/inc/util/system/compat.h +0 -84
- package/inc/util/system/compiler.h +0 -620
- package/inc/util/system/condvar.h +0 -71
- package/inc/util/system/context.h +0 -181
- package/inc/util/system/context_aarch64.h +0 -8
- package/inc/util/system/context_i686.h +0 -9
- package/inc/util/system/context_x86.h +0 -12
- package/inc/util/system/context_x86_64.h +0 -7
- package/inc/util/system/cpu_id.h +0 -159
- package/inc/util/system/daemon.h +0 -28
- package/inc/util/system/datetime.h +0 -98
- package/inc/util/system/defaults.h +0 -149
- package/inc/util/system/demangle.h +0 -5
- package/inc/util/system/demangle_impl.h +0 -23
- package/inc/util/system/direct_io.h +0 -71
- package/inc/util/system/dynlib.h +0 -119
- package/inc/util/system/env.h +0 -32
- package/inc/util/system/error.h +0 -95
- package/inc/util/system/event.h +0 -122
- package/inc/util/system/execpath.h +0 -17
- package/inc/util/system/fasttime.h +0 -6
- package/inc/util/system/fhandle.h +0 -27
- package/inc/util/system/file.h +0 -210
- package/inc/util/system/file_lock.h +0 -34
- package/inc/util/system/filemap.h +0 -383
- package/inc/util/system/flock.h +0 -35
- package/inc/util/system/fs.h +0 -156
- package/inc/util/system/fs_win.h +0 -29
- package/inc/util/system/fstat.h +0 -46
- package/inc/util/system/getpid.h +0 -12
- package/inc/util/system/guard.h +0 -179
- package/inc/util/system/hi_lo.h +0 -139
- package/inc/util/system/hostname.h +0 -10
- package/inc/util/system/hp_timer.h +0 -36
- package/inc/util/system/info.h +0 -12
- package/inc/util/system/interrupt_signals.h +0 -22
- package/inc/util/system/madvise.h +0 -30
- package/inc/util/system/maxlen.h +0 -32
- package/inc/util/system/mem_info.h +0 -18
- package/inc/util/system/mincore.h +0 -38
- package/inc/util/system/mktemp.h +0 -11
- package/inc/util/system/mlock.h +0 -43
- package/inc/util/system/mutex.h +0 -67
- package/inc/util/system/nice.h +0 -3
- package/inc/util/system/pipe.h +0 -90
- package/inc/util/system/platform.h +0 -246
- package/inc/util/system/progname.h +0 -13
- package/inc/util/system/protect.h +0 -25
- package/inc/util/system/rusage.h +0 -26
- package/inc/util/system/rwlock.h +0 -78
- package/inc/util/system/sanitizers.h +0 -122
- package/inc/util/system/sem.h +0 -41
- package/inc/util/system/shellcommand.h +0 -472
- package/inc/util/system/shmat.h +0 -32
- package/inc/util/system/sigset.h +0 -78
- package/inc/util/system/spin_wait.h +0 -10
- package/inc/util/system/spinlock.h +0 -121
- package/inc/util/system/src_location.h +0 -25
- package/inc/util/system/src_root.h +0 -68
- package/inc/util/system/sys_alloc.h +0 -43
- package/inc/util/system/sysstat.h +0 -52
- package/inc/util/system/tempfile.h +0 -34
- package/inc/util/system/thread.h +0 -167
- package/inc/util/system/tls.h +0 -307
- package/inc/util/system/types.h +0 -119
- package/inc/util/system/unaligned_mem.h +0 -67
- package/inc/util/system/user.h +0 -5
- package/inc/util/system/utime.h +0 -6
- package/inc/util/system/valgrind.h +0 -48
- package/inc/util/system/winint.h +0 -43
- package/inc/util/system/yassert.h +0 -121
- package/inc/util/system/yield.h +0 -4
- package/inc/util/thread/factory.h +0 -65
- package/inc/util/thread/fwd.h +0 -30
- package/inc/util/thread/lfqueue.h +0 -406
- package/inc/util/thread/lfstack.h +0 -188
- package/inc/util/thread/pool.h +0 -388
- package/inc/util/thread/singleton.h +0 -42
- package/inc/util/ysafeptr.h +0 -427
- package/inc/util/ysaveload.h +0 -700
package/src/model.cpp
CHANGED
|
@@ -2,12 +2,15 @@
|
|
|
2
2
|
|
|
3
3
|
#include "api_helpers.h"
|
|
4
4
|
|
|
5
|
+
#include <vector>
|
|
6
|
+
|
|
7
|
+
|
|
5
8
|
namespace {
|
|
6
9
|
|
|
7
10
|
// Collect pointers to matrix rows into a vector.
|
|
8
|
-
template <typename T, typename V = const T*, typename C = const
|
|
9
|
-
|
|
10
|
-
|
|
11
|
+
template <typename T, typename V = const T*, typename C = const std::vector<T>>
|
|
12
|
+
std::vector<V> CollectMatrixRowPointers(C& matrix, uint32_t rowLength) {
|
|
13
|
+
std::vector<V> pointers;
|
|
11
14
|
for (uint32_t i = 0; i < matrix.size(); i += rowLength) {
|
|
12
15
|
pointers.push_back(matrix.data() + i);
|
|
13
16
|
}
|
|
@@ -30,10 +33,12 @@ TModel::TModel(const Napi::CallbackInfo& info): Napi::ObjectWrap<TModel>(info) {
|
|
|
30
33
|
}
|
|
31
34
|
|
|
32
35
|
NHelper::Check(env, info[0].IsString(), "File name argument should be a string");
|
|
33
|
-
const bool status = LoadFullModelFromFile(
|
|
34
|
-
|
|
36
|
+
const bool status = LoadFullModelFromFile(
|
|
37
|
+
this->Handle,
|
|
38
|
+
info[0].As<Napi::String>().Utf8Value().c_str()
|
|
39
|
+
);
|
|
35
40
|
// Even if it fails, this check schedules NodeJS exception, not C++ one.
|
|
36
|
-
// The C++ object is considered to be successfully created and will be
|
|
41
|
+
// The C++ object is considered to be successfully created and will be destroyed by Node runtime
|
|
37
42
|
// later as usual.
|
|
38
43
|
NHelper::CheckStatus(env, status);
|
|
39
44
|
if (status) {
|
|
@@ -52,10 +57,14 @@ Napi::Function TModel::GetClass(Napi::Env env) {
|
|
|
52
57
|
TModel::InstanceMethod("loadModel", &TModel::LoadFullFromFile),
|
|
53
58
|
TModel::InstanceMethod("predict", &TModel::CalcPrediction),
|
|
54
59
|
TModel::InstanceMethod("enableGPUEvaluation", &TModel::EvaluateOnGPU),
|
|
60
|
+
TModel::InstanceMethod("setPredictionType", &TModel::SetPredictionType),
|
|
55
61
|
TModel::InstanceMethod("getFloatFeaturesCount", &TModel::GetModelFloatFeaturesCount),
|
|
56
62
|
TModel::InstanceMethod("getCatFeaturesCount", &TModel::GetModelCatFeaturesCount),
|
|
63
|
+
TModel::InstanceMethod("getTextFeaturesCount", &TModel::GetModelTextFeaturesCount),
|
|
64
|
+
TModel::InstanceMethod("getEmbeddingFeaturesCount", &TModel::GetModelEmbeddingFeaturesCount),
|
|
57
65
|
TModel::InstanceMethod("getTreeCount", &TModel::GetModelTreeCount),
|
|
58
66
|
TModel::InstanceMethod("getDimensionsCount", &TModel::GetModelDimensionsCount),
|
|
67
|
+
TModel::InstanceMethod("getPredictionDimensionsCount", &TModel::GetPredictionDimensionsCount),
|
|
59
68
|
});
|
|
60
69
|
}
|
|
61
70
|
|
|
@@ -64,70 +73,178 @@ void TModel::LoadFullFromFile(const Napi::CallbackInfo& info) {
|
|
|
64
73
|
Napi::Env env = info.Env();
|
|
65
74
|
|
|
66
75
|
if (!NHelper::Check(env, info.Length() >= 1, "Wrong number of arguments") ||
|
|
67
|
-
!NHelper::Check(env, info[0].IsString(), "File name string is required"))
|
|
76
|
+
!NHelper::Check(env, info[0].IsString(), "File name string is required"))
|
|
77
|
+
{
|
|
68
78
|
return;
|
|
69
79
|
}
|
|
70
80
|
|
|
71
81
|
NHelper::CheckNotNullHandle(env, this->Handle);
|
|
72
|
-
const bool status = LoadFullModelFromFile(
|
|
73
|
-
|
|
82
|
+
const bool status = LoadFullModelFromFile(
|
|
83
|
+
this->Handle,
|
|
84
|
+
info[0].As<Napi::String>().Utf8Value().c_str()
|
|
85
|
+
);
|
|
74
86
|
NHelper::CheckStatus(env, status);
|
|
75
87
|
if (status) {
|
|
76
88
|
this->ModelLoaded = true;
|
|
77
89
|
}
|
|
78
90
|
}
|
|
79
91
|
|
|
92
|
+
// Set model predictions postprocessing type.
|
|
93
|
+
void TModel::SetPredictionType(const Napi::CallbackInfo& info) {
|
|
94
|
+
Napi::Env env = info.Env();
|
|
95
|
+
|
|
96
|
+
if (!NHelper::Check(env, info.Length() >= 1, "Wrong number of arguments") ||
|
|
97
|
+
!NHelper::Check(env, info[0].IsString(), "predictionType argument must have string type"))
|
|
98
|
+
{
|
|
99
|
+
return;
|
|
100
|
+
}
|
|
101
|
+
|
|
102
|
+
NHelper::CheckNotNullHandle(env, this->Handle);
|
|
103
|
+
const bool status = SetPredictionTypeString(
|
|
104
|
+
this->Handle,
|
|
105
|
+
info[0].As<Napi::String>().Utf8Value().c_str()
|
|
106
|
+
);
|
|
107
|
+
NHelper::CheckStatus(env, status);
|
|
108
|
+
}
|
|
109
|
+
|
|
80
110
|
Napi::Value TModel::CalcPrediction(const Napi::CallbackInfo& info) {
|
|
81
111
|
Napi::Env env = info.Env();
|
|
112
|
+
if (!NHelper::Check(env, this->ModelLoaded, "Trying to predict from the empty model")) {
|
|
113
|
+
return env.Undefined();
|
|
114
|
+
}
|
|
115
|
+
|
|
116
|
+
if (!NHelper::Check(env, info.Length() >= 1, "Wrong number of arguments - expected at least 1")) {
|
|
117
|
+
return env.Undefined();
|
|
118
|
+
}
|
|
119
|
+
|
|
120
|
+
|
|
121
|
+
// Numerical features
|
|
82
122
|
|
|
83
|
-
if (!NHelper::
|
|
84
|
-
|
|
85
|
-
|
|
86
|
-
|
|
123
|
+
if (!NHelper::CheckIsMatrix(
|
|
124
|
+
env,
|
|
125
|
+
info[0],
|
|
126
|
+
NHelper::NAT_NUMBER,
|
|
127
|
+
"Expected the first argument to be a matrix of floats - "
|
|
128
|
+
))
|
|
129
|
+
{
|
|
87
130
|
return env.Undefined();
|
|
88
131
|
}
|
|
89
132
|
|
|
90
133
|
const Napi::Array floatFeatures = info[0].As<Napi::Array>();
|
|
91
|
-
const uint32_t
|
|
92
|
-
if (
|
|
134
|
+
const uint32_t sampleCount = floatFeatures.Length();
|
|
135
|
+
if (sampleCount == 0) {
|
|
93
136
|
return Napi::Array::New(env);
|
|
94
137
|
}
|
|
95
138
|
|
|
96
|
-
const uint32_t floatFeaturesSize = floatFeatures[0u].As<Napi::Array>().Length();
|
|
97
139
|
|
|
98
|
-
|
|
99
|
-
|
|
100
|
-
|
|
101
|
-
|
|
102
|
-
|
|
103
|
-
|
|
104
|
-
|
|
140
|
+
// Categorical features
|
|
141
|
+
Napi::Value catFeatures;
|
|
142
|
+
bool catFeaturesAreHashes = false;
|
|
143
|
+
|
|
144
|
+
if (info.Length() >= 2) {
|
|
145
|
+
catFeatures = info[1];
|
|
146
|
+
if (!NHelper::CheckIsMatrix(
|
|
147
|
+
env,
|
|
148
|
+
catFeatures,
|
|
149
|
+
NHelper::NAT_NUMBER_OR_STRING,
|
|
150
|
+
"Expected the second argument to be a matrix of strings or numbers - "
|
|
151
|
+
))
|
|
152
|
+
{
|
|
153
|
+
return env.Undefined();
|
|
154
|
+
}
|
|
155
|
+
const Napi::Array catFeaturesArray = catFeatures.As<Napi::Array>();
|
|
156
|
+
|
|
157
|
+
if (!NHelper::Check(
|
|
158
|
+
env,
|
|
159
|
+
catFeaturesArray.Length() == sampleCount,
|
|
160
|
+
"Expected the number of samples to be the same for both float and categorical features"
|
|
161
|
+
))
|
|
162
|
+
{
|
|
163
|
+
return env.Undefined();
|
|
164
|
+
}
|
|
165
|
+
if (sampleCount) {
|
|
166
|
+
const Napi::Array catRow = catFeaturesArray[0u].As<Napi::Array>();
|
|
167
|
+
if (catRow.Length()) {
|
|
168
|
+
catFeaturesAreHashes = catRow[0u].IsNumber();
|
|
169
|
+
}
|
|
105
170
|
}
|
|
106
171
|
}
|
|
107
172
|
|
|
108
|
-
|
|
109
|
-
|
|
110
|
-
|
|
111
|
-
|
|
173
|
+
|
|
174
|
+
// Text features
|
|
175
|
+
Napi::Value textFeatures;
|
|
176
|
+
if (info.Length() >= 3) {
|
|
177
|
+
textFeatures = info[2];
|
|
178
|
+
if (!NHelper::CheckIsMatrix(
|
|
179
|
+
env,
|
|
180
|
+
textFeatures,
|
|
181
|
+
NHelper::NAT_STRING,
|
|
182
|
+
"Expected the third argument to be a matrix of strings - "
|
|
183
|
+
))
|
|
184
|
+
{
|
|
185
|
+
return env.Undefined();
|
|
186
|
+
}
|
|
187
|
+
|
|
188
|
+
if (!NHelper::Check(
|
|
189
|
+
env,
|
|
190
|
+
textFeatures.As<Napi::Array>().Length() == sampleCount,
|
|
191
|
+
"Expected the number of samples to be the same for both float and text features"
|
|
192
|
+
))
|
|
193
|
+
{
|
|
194
|
+
return env.Undefined();
|
|
195
|
+
}
|
|
112
196
|
}
|
|
113
|
-
const Napi::Array catFeatures = info[1].As<Napi::Array>();
|
|
114
197
|
|
|
115
|
-
|
|
116
|
-
|
|
117
|
-
|
|
198
|
+
// Embedding features
|
|
199
|
+
Napi::Value embeddingFeatures;
|
|
200
|
+
if (info.Length() == 4) {
|
|
201
|
+
embeddingFeatures = info[3];
|
|
202
|
+
if (!NHelper::CheckIsMatrix(
|
|
203
|
+
env,
|
|
204
|
+
embeddingFeatures,
|
|
205
|
+
NHelper::NAT_ARRAY_OR_NUMBERS,
|
|
206
|
+
"Expected the fourth argument to be a matrix of arrays of numbers - "
|
|
207
|
+
))
|
|
208
|
+
{
|
|
209
|
+
return env.Undefined();
|
|
210
|
+
}
|
|
211
|
+
|
|
212
|
+
if (!NHelper::Check(
|
|
213
|
+
env,
|
|
214
|
+
embeddingFeatures.As<Napi::Array>().Length() == sampleCount,
|
|
215
|
+
"Expected the number of samples to be the same for both float and embedding features"
|
|
216
|
+
))
|
|
217
|
+
{
|
|
218
|
+
return env.Undefined();
|
|
219
|
+
}
|
|
118
220
|
}
|
|
119
|
-
|
|
120
|
-
|
|
121
|
-
|
|
221
|
+
|
|
222
|
+
|
|
223
|
+
if (catFeaturesAreHashes) {
|
|
224
|
+
return CalcPredictionWithCatFeaturesAsHashes(
|
|
225
|
+
env,
|
|
226
|
+
sampleCount,
|
|
227
|
+
floatFeatures,
|
|
228
|
+
catFeatures,
|
|
229
|
+
textFeatures,
|
|
230
|
+
embeddingFeatures
|
|
231
|
+
);
|
|
122
232
|
}
|
|
123
|
-
return
|
|
233
|
+
return CalcPredictionWithCatFeaturesAsStrings(
|
|
234
|
+
env,
|
|
235
|
+
sampleCount,
|
|
236
|
+
floatFeatures,
|
|
237
|
+
catFeatures,
|
|
238
|
+
textFeatures,
|
|
239
|
+
embeddingFeatures
|
|
240
|
+
);
|
|
124
241
|
}
|
|
125
242
|
|
|
126
243
|
void TModel::EvaluateOnGPU(const Napi::CallbackInfo& info) {
|
|
127
244
|
Napi::Env env = info.Env();
|
|
128
245
|
if (!NHelper::Check(env, info.Length() >= 1, "Wrong number of arguments - expected 1") ||
|
|
129
|
-
!NHelper::Check(env, info[0].IsNumber(),
|
|
130
|
-
|
|
246
|
+
!NHelper::Check(env, info[0].IsNumber(), "Expected the first argument to be a numeric deviceId"))
|
|
247
|
+
{
|
|
131
248
|
return;
|
|
132
249
|
}
|
|
133
250
|
|
|
@@ -150,6 +267,20 @@ Napi::Value TModel::GetModelCatFeaturesCount(const Napi::CallbackInfo& info) {
|
|
|
150
267
|
return Napi::Number::New(env, count);
|
|
151
268
|
}
|
|
152
269
|
|
|
270
|
+
Napi::Value TModel::GetModelTextFeaturesCount(const Napi::CallbackInfo& info) {
|
|
271
|
+
Napi::Env env = info.Env();
|
|
272
|
+
const size_t count = GetTextFeaturesCount(this->Handle);
|
|
273
|
+
|
|
274
|
+
return Napi::Number::New(env, count);
|
|
275
|
+
}
|
|
276
|
+
|
|
277
|
+
Napi::Value TModel::GetModelEmbeddingFeaturesCount(const Napi::CallbackInfo& info) {
|
|
278
|
+
Napi::Env env = info.Env();
|
|
279
|
+
const size_t count = GetEmbeddingFeaturesCount(this->Handle);
|
|
280
|
+
|
|
281
|
+
return Napi::Number::New(env, count);
|
|
282
|
+
}
|
|
283
|
+
|
|
153
284
|
Napi::Value TModel::GetModelTreeCount(const Napi::CallbackInfo& info) {
|
|
154
285
|
Napi::Env env = info.Env();
|
|
155
286
|
const size_t count = GetTreeCount(this->Handle);
|
|
@@ -164,68 +295,337 @@ Napi::Value TModel::GetModelDimensionsCount(const Napi::CallbackInfo& info) {
|
|
|
164
295
|
return Napi::Number::New(env, count);
|
|
165
296
|
}
|
|
166
297
|
|
|
167
|
-
Napi::
|
|
168
|
-
|
|
169
|
-
|
|
170
|
-
|
|
171
|
-
|
|
172
|
-
|
|
298
|
+
Napi::Value TModel::GetPredictionDimensionsCount(const Napi::CallbackInfo& info) {
|
|
299
|
+
Napi::Env env = info.Env();
|
|
300
|
+
const size_t count = ::GetPredictionDimensionsCount(this->Handle);
|
|
301
|
+
|
|
302
|
+
return Napi::Number::New(env, count);
|
|
303
|
+
}
|
|
304
|
+
|
|
305
|
+
static void GetNumericFeaturesData(
|
|
306
|
+
const uint32_t sampleCount,
|
|
307
|
+
const Napi::Array& floatFeatures,
|
|
308
|
+
uint32_t* floatFeatureCount,
|
|
309
|
+
std::vector<float>* storage,
|
|
310
|
+
std::vector<const float*>* sampleDataPtrs
|
|
311
|
+
) {
|
|
312
|
+
const uint32_t floatFeatureCountLocal = floatFeatures[0u].As<Napi::Array>().Length();
|
|
313
|
+
*floatFeatureCount = floatFeatureCountLocal;
|
|
314
|
+
|
|
315
|
+
storage->clear();
|
|
316
|
+
storage->reserve(floatFeatureCountLocal * sampleCount);
|
|
317
|
+
|
|
318
|
+
for (uint32_t i = 0; i < sampleCount; ++i) {
|
|
319
|
+
const Napi::Array row = floatFeatures[i].As<Napi::Array>();
|
|
320
|
+
for (uint32_t j = 0; j < floatFeatureCountLocal; ++j) {
|
|
321
|
+
storage->push_back(row[j].As<Napi::Number>().FloatValue());
|
|
322
|
+
}
|
|
323
|
+
}
|
|
173
324
|
|
|
174
|
-
|
|
175
|
-
|
|
325
|
+
*sampleDataPtrs = CollectMatrixRowPointers<float>(*storage, floatFeatureCountLocal);
|
|
326
|
+
}
|
|
176
327
|
|
|
177
|
-
|
|
178
|
-
|
|
179
|
-
|
|
180
|
-
|
|
328
|
+
static void GetTextFeaturesData(
|
|
329
|
+
const uint32_t sampleCount,
|
|
330
|
+
const Napi::Value& textFeatures, // array or empty
|
|
331
|
+
uint32_t* textFeatureCount,
|
|
332
|
+
std::vector<std::string>* storage,
|
|
333
|
+
std::vector<const char*>* dataPtrsStorage,
|
|
334
|
+
std::vector<const char**>* sampleDataPtrs
|
|
335
|
+
) {
|
|
336
|
+
storage->clear();
|
|
337
|
+
dataPtrsStorage->clear();
|
|
338
|
+
sampleDataPtrs->clear();
|
|
339
|
+
if (textFeatures.IsEmpty()) {
|
|
340
|
+
*textFeatureCount = 0;
|
|
341
|
+
} else {
|
|
342
|
+
const Napi::Array textFeaturesArray = textFeatures.As<Napi::Array>();
|
|
343
|
+
const uint32_t textFeatureCountLocal = textFeaturesArray[0u].As<Napi::Array>().Length();
|
|
344
|
+
*textFeatureCount = textFeatureCountLocal;
|
|
345
|
+
|
|
346
|
+
storage->reserve(textFeatureCountLocal * sampleCount);
|
|
347
|
+
dataPtrsStorage->reserve(textFeatureCountLocal * sampleCount);
|
|
348
|
+
|
|
349
|
+
for (uint32_t i = 0; i < sampleCount; ++i) {
|
|
350
|
+
const Napi::Array row = textFeaturesArray[i].As<Napi::Array>();
|
|
351
|
+
for (uint32_t j = 0; j < textFeatureCountLocal; ++j) {
|
|
352
|
+
storage->push_back(row[j].As<Napi::String>().Utf8Value());
|
|
353
|
+
dataPtrsStorage->push_back(storage->back().c_str());
|
|
354
|
+
}
|
|
181
355
|
}
|
|
356
|
+
|
|
357
|
+
*sampleDataPtrs = CollectMatrixRowPointers<const char*, const char**>(
|
|
358
|
+
*dataPtrsStorage,
|
|
359
|
+
textFeatureCountLocal
|
|
360
|
+
);
|
|
182
361
|
}
|
|
362
|
+
}
|
|
183
363
|
|
|
184
|
-
|
|
185
|
-
|
|
364
|
+
static bool GetEmbeddingFeaturesData(
|
|
365
|
+
Napi::Env env,
|
|
366
|
+
const uint32_t sampleCount,
|
|
367
|
+
const Napi::Value& embeddingFeatures, // array or empty
|
|
368
|
+
uint32_t* embeddingFeatureCount,
|
|
369
|
+
std::vector<size_t>* embeddingDimensions,
|
|
370
|
+
std::vector<float>* storage,
|
|
371
|
+
std::vector<const float*>* dataPtrsStorage,
|
|
372
|
+
std::vector<const float**>* sampleDataPtrs
|
|
373
|
+
) {
|
|
374
|
+
embeddingDimensions->clear();
|
|
375
|
+
storage->clear();
|
|
376
|
+
dataPtrsStorage->clear();
|
|
377
|
+
sampleDataPtrs->clear();
|
|
378
|
+
if (embeddingFeatures.IsEmpty()) {
|
|
379
|
+
*embeddingFeatureCount = 0;
|
|
380
|
+
} else {
|
|
381
|
+
const Napi::Array embeddingsFeaturesArray = embeddingFeatures.As<Napi::Array>();
|
|
382
|
+
const uint32_t embeddingFeatureCountLocal = embeddingsFeaturesArray[0u].As<Napi::Array>().Length();
|
|
383
|
+
*embeddingFeatureCount = embeddingFeatureCountLocal;
|
|
384
|
+
|
|
385
|
+
embeddingDimensions->reserve(embeddingFeatureCountLocal);
|
|
386
|
+
// this is a lower bound, final allocation is delayed until the first sample is processed and
|
|
387
|
+
// embedding dimensions become known
|
|
388
|
+
storage->reserve(embeddingFeatureCountLocal * sampleCount);
|
|
389
|
+
dataPtrsStorage->reserve(embeddingFeatureCountLocal * sampleCount);
|
|
390
|
+
|
|
391
|
+
size_t perSampleValuesSize = 0;
|
|
392
|
+
|
|
393
|
+
for (uint32_t i = 0; i < sampleCount; ++i) {
|
|
394
|
+
const Napi::Array row = embeddingsFeaturesArray[i].As<Napi::Array>();
|
|
395
|
+
for (uint32_t j = 0; j < embeddingFeatureCountLocal; ++j) {
|
|
396
|
+
const Napi::Array embeddingValues = row[j].As<Napi::Array>();
|
|
397
|
+
auto embeddingSize = embeddingValues.Length();
|
|
398
|
+
if (i == 0) {
|
|
399
|
+
embeddingDimensions->push_back(embeddingSize);
|
|
400
|
+
} else {
|
|
401
|
+
if (!NHelper::Check(
|
|
402
|
+
env,
|
|
403
|
+
(*embeddingDimensions)[j] == embeddingSize,
|
|
404
|
+
"Embedding values arrays have different lengths"
|
|
405
|
+
))
|
|
406
|
+
{
|
|
407
|
+
return false;
|
|
408
|
+
}
|
|
409
|
+
}
|
|
410
|
+
|
|
411
|
+
for (uint32_t k = 0; k < embeddingSize; ++k) {
|
|
412
|
+
storage->push_back(embeddingValues[k].As<Napi::Number>().FloatValue());
|
|
413
|
+
}
|
|
414
|
+
// can't update dataPtrsStorage just yet as it is not reserved to final size
|
|
415
|
+
}
|
|
416
|
+
if (i == 0) {
|
|
417
|
+
perSampleValuesSize = storage->size();
|
|
418
|
+
storage->reserve(perSampleValuesSize * sampleCount);
|
|
419
|
+
}
|
|
420
|
+
const float* dataPtr = storage->data() + perSampleValuesSize * i;
|
|
421
|
+
for (uint32_t j = 0; j < embeddingFeatureCountLocal; ++j) {
|
|
422
|
+
dataPtrsStorage->push_back(dataPtr);
|
|
423
|
+
dataPtr += (*embeddingDimensions)[j];
|
|
424
|
+
}
|
|
425
|
+
}
|
|
186
426
|
|
|
187
|
-
|
|
188
|
-
|
|
189
|
-
|
|
190
|
-
|
|
191
|
-
|
|
192
|
-
catPtrs.data(), catFeaturesSize,
|
|
193
|
-
resultValues.data(), docsCount));
|
|
427
|
+
*sampleDataPtrs = CollectMatrixRowPointers<const float*, const float**>(
|
|
428
|
+
*dataPtrsStorage,
|
|
429
|
+
embeddingFeatureCountLocal
|
|
430
|
+
);
|
|
431
|
+
}
|
|
194
432
|
|
|
195
|
-
return
|
|
433
|
+
return true;
|
|
196
434
|
}
|
|
197
435
|
|
|
198
|
-
Napi::Array TModel::CalcPredictionString(Napi::Env env,
|
|
199
|
-
const TVector<float>& floatFeatures,
|
|
200
|
-
const Napi::Array& catFeatures) {
|
|
201
|
-
const uint32_t docsCount = catFeatures.Length();
|
|
202
|
-
const uint32_t catFeaturesSize = catFeatures[0u].As<Napi::Array>().Length();
|
|
203
|
-
const uint32_t floatFeaturesSize = floatFeatures.size() / docsCount;
|
|
204
436
|
|
|
205
|
-
|
|
206
|
-
|
|
207
|
-
|
|
208
|
-
|
|
437
|
+
/**
 * Runs the model on a batch of samples whose categorical features are given
 * as integer hashes.
 *
 * @param env               N-API environment used for error reporting and
 *                          for building the returned array.
 * @param sampleCount       Number of samples (rows) in the batch.
 * @param floatFeatures     Numeric features; unpacked by GetNumericFeaturesData.
 * @param catFeatures       Either an empty value (no categorical features) or
 *                          an array of rows, each row an array of numbers that
 *                          are read with Int32Value().
 * @param textFeatures      Text features; unpacked by GetTextFeaturesData.
 * @param embeddingFeatures Embedding features; unpacked by
 *                          GetEmbeddingFeaturesData.
 * @return NHelper::ConvertToArray(env, resultValues) on success, where
 *         resultValues holds sampleCount * GetPredictionDimensionsCount()
 *         doubles; an empty Napi::Array when the embedding data could not be
 *         collected or the prediction call reported a failure.
 */
Napi::Array TModel::CalcPredictionWithCatFeaturesAsHashes(
    Napi::Env env,
    const uint32_t sampleCount,
    const Napi::Array& floatFeatures,
    const Napi::Value& catFeatures,
    const Napi::Value& textFeatures,
    const Napi::Value& embeddingFeatures
) {
    // Numeric features.
    uint32_t floatFeaturesSize = 0;
    std::vector<float> floatFeaturesStorage;
    std::vector<const float*> floatPtrs;

    GetNumericFeaturesData(sampleCount, floatFeatures, &floatFeaturesSize, &floatFeaturesStorage, &floatPtrs);

    // Categorical features (already hashed by the caller).
    uint32_t catFeaturesSize = 0;
    std::vector<int> catHashValues;
    std::vector<const int*> catPtrs;

    if (!catFeatures.IsEmpty()) {
        const Napi::Array catFeaturesArray = catFeatures.As<Napi::Array>();
        // The width of the first row is used as the width of every row.
        catFeaturesSize = catFeaturesArray[0u].As<Napi::Array>().Length();

        // Widen before multiplying so the reservation size cannot wrap in
        // 32-bit arithmetic.
        catHashValues.reserve(static_cast<size_t>(catFeaturesSize) * sampleCount);

        // Flatten the rows one after another into a single buffer.
        for (uint32_t i = 0; i < sampleCount; ++i) {
            const Napi::Array row = catFeaturesArray[i].As<Napi::Array>();
            for (uint32_t j = 0; j < catFeaturesSize; ++j) {
                catHashValues.push_back(row[j].As<Napi::Number>().Int32Value());
            }
        }

        catPtrs = CollectMatrixRowPointers<int>(catHashValues, catFeaturesSize);
    }

    // Text features.
    uint32_t textFeaturesSize = 0;
    std::vector<std::string> textFeaturesStorage;
    std::vector<const char*> textFeaturesDataPtrsStorage;
    std::vector<const char**> textFeaturesSampleDataPtrs;

    GetTextFeaturesData(
        sampleCount,
        textFeatures,
        &textFeaturesSize,
        &textFeaturesStorage,
        &textFeaturesDataPtrsStorage,
        &textFeaturesSampleDataPtrs
    );

    // Embedding features.
    uint32_t embeddingFeaturesSize = 0;
    std::vector<size_t> embeddingDimensions;
    std::vector<float> embeddingFeaturesStorage;
    std::vector<const float*> embeddingFeaturesDataPtrsStorage;
    std::vector<const float**> embeddingFeaturesSampleDataPtrs;

    if (!NHelper::Check(
            env,
            GetEmbeddingFeaturesData(
                env,
                sampleCount,
                embeddingFeatures,
                &embeddingFeaturesSize,
                &embeddingDimensions,
                &embeddingFeaturesStorage,
                &embeddingFeaturesDataPtrsStorage,
                &embeddingFeaturesSampleDataPtrs
            ),
            "Failed to get embedding features data"
        ))
    {
        return Napi::Array::New(env);
    }

    // One output slot per (sample, prediction dimension) pair.
    const auto predictionDimensions = ::GetPredictionDimensionsCount(this->Handle);
    std::vector<double> resultValues;
    resultValues.resize(sampleCount * predictionDimensions);

    // Stop on a failed prediction instead of converting a buffer the model
    // did not fill; this mirrors CalcPredictionWithCatFeaturesAsStrings.
    // NOTE(review): presumably CheckStatus also raises the JS error - confirm
    // in api_helpers.cpp.
    if (!NHelper::CheckStatus(
            env,
            CalcModelPredictionWithHashedCatFeaturesAndTextAndEmbeddingFeatures(
                this->Handle,
                sampleCount,
                floatPtrs.data(), floatFeaturesSize,
                catPtrs.data(), catFeaturesSize,
                textFeaturesSampleDataPtrs.data(), textFeaturesSize,
                embeddingFeaturesSampleDataPtrs.data(), embeddingDimensions.data(), embeddingFeaturesSize,
                resultValues.data(), resultValues.size()
            )
        ))
    {
        return Napi::Array::New(env);
    }

    return NHelper::ConvertToArray(env, resultValues);
}
|
|
532
|
+
|
|
533
|
+
Napi::Array TModel::CalcPredictionWithCatFeaturesAsStrings(
|
|
534
|
+
Napi::Env env,
|
|
535
|
+
const uint32_t sampleCount,
|
|
536
|
+
const Napi::Array& floatFeatures,
|
|
537
|
+
const Napi::Value& catFeatures,
|
|
538
|
+
const Napi::Value& textFeatures,
|
|
539
|
+
const Napi::Value& embeddingFeatures
|
|
540
|
+
) {
|
|
541
|
+
uint32_t floatFeaturesSize = 0;
|
|
542
|
+
std::vector<float> floatFeaturesStorage;
|
|
543
|
+
std::vector<const float*> floatPtrs;
|
|
544
|
+
|
|
545
|
+
GetNumericFeaturesData(sampleCount, floatFeatures, &floatFeaturesSize, &floatFeaturesStorage, &floatPtrs);
|
|
546
|
+
|
|
547
|
+
uint32_t catFeaturesSize = 0;
|
|
548
|
+
std::vector<std::string> catStrings;
|
|
549
|
+
std::vector<const char*> catStringValues;
|
|
550
|
+
std::vector<const char**> catPtrs;
|
|
551
|
+
|
|
552
|
+
if (!catFeatures.IsEmpty()) {
|
|
553
|
+
const Napi::Array catFeaturesArray = catFeatures.As<Napi::Array>();
|
|
554
|
+
catFeaturesSize = catFeaturesArray[0u].As<Napi::Array>().Length();
|
|
555
|
+
|
|
556
|
+
catStrings.reserve(catFeaturesSize * sampleCount);
|
|
557
|
+
catStringValues.reserve(catFeaturesSize * sampleCount);
|
|
558
|
+
|
|
559
|
+
for (uint32_t i = 0; i < sampleCount; ++i) {
|
|
560
|
+
const Napi::Array row = catFeaturesArray[i].As<Napi::Array>();
|
|
561
|
+
for (uint32_t j = 0; j < catFeaturesSize; ++j) {
|
|
562
|
+
catStrings.push_back(row[j].As<Napi::String>().Utf8Value());
|
|
563
|
+
catStringValues.push_back(catStrings.back().c_str());
|
|
564
|
+
}
|
|
215
565
|
}
|
|
566
|
+
catPtrs = CollectMatrixRowPointers<const char*, const char**>(
|
|
567
|
+
catStringValues,
|
|
568
|
+
catFeaturesSize
|
|
569
|
+
);
|
|
216
570
|
}
|
|
217
571
|
|
|
218
|
-
|
|
219
|
-
|
|
572
|
+
uint32_t textFeaturesSize = 0;
|
|
573
|
+
std::vector<std::string> textFeaturesStorage;
|
|
574
|
+
std::vector<const char*> textFeaturesDataPtrsStorage;
|
|
575
|
+
std::vector<const char**> textFeaturesSampleDataPtrs;
|
|
576
|
+
|
|
577
|
+
GetTextFeaturesData(
|
|
578
|
+
sampleCount,
|
|
579
|
+
textFeatures,
|
|
580
|
+
&textFeaturesSize,
|
|
581
|
+
&textFeaturesStorage,
|
|
582
|
+
&textFeaturesDataPtrsStorage,
|
|
583
|
+
&textFeaturesSampleDataPtrs
|
|
584
|
+
);
|
|
585
|
+
|
|
586
|
+
|
|
587
|
+
uint32_t embeddingFeaturesSize = 0;
|
|
588
|
+
std::vector<size_t> embeddingDimensions;
|
|
589
|
+
std::vector<float> embeddingFeaturesStorage;
|
|
590
|
+
std::vector<const float*> embeddingFeaturesDataPtrsStorage;
|
|
591
|
+
std::vector<const float**> embeddingFeaturesSampleDataPtrs;
|
|
592
|
+
|
|
593
|
+
if (!NHelper::Check(
|
|
594
|
+
env,
|
|
595
|
+
GetEmbeddingFeaturesData(
|
|
596
|
+
env,
|
|
597
|
+
sampleCount,
|
|
598
|
+
embeddingFeatures,
|
|
599
|
+
&embeddingFeaturesSize,
|
|
600
|
+
&embeddingDimensions,
|
|
601
|
+
&embeddingFeaturesStorage,
|
|
602
|
+
&embeddingFeaturesDataPtrsStorage,
|
|
603
|
+
&embeddingFeaturesSampleDataPtrs
|
|
604
|
+
),
|
|
605
|
+
"Failed to get embedding features data"
|
|
606
|
+
))
|
|
607
|
+
{
|
|
608
|
+
return Napi::Array::New(env);
|
|
609
|
+
}
|
|
220
610
|
|
|
221
|
-
TVector<const float*> floatPtrs = CollectMatrixRowPointers<float>(floatFeatures, floatFeaturesSize);
|
|
222
|
-
TVector<const char**> catPtrs = CollectMatrixRowPointers<const char*, const char**>(catStringValues, catFeaturesSize);
|
|
223
611
|
|
|
224
|
-
|
|
225
|
-
|
|
226
|
-
|
|
227
|
-
|
|
228
|
-
|
|
612
|
+
const auto predictionDimensions = ::GetPredictionDimensionsCount(this->Handle);
|
|
613
|
+
std::vector<double> resultValues;
|
|
614
|
+
resultValues.resize(sampleCount * predictionDimensions);
|
|
615
|
+
|
|
616
|
+
if (!NHelper::CheckStatus(
|
|
617
|
+
env,
|
|
618
|
+
CalcModelPredictionTextAndEmbeddings(
|
|
619
|
+
this->Handle,
|
|
620
|
+
sampleCount,
|
|
621
|
+
floatPtrs.data(), floatFeaturesSize,
|
|
622
|
+
catPtrs.data(), catFeaturesSize,
|
|
623
|
+
textFeaturesSampleDataPtrs.data(), textFeaturesSize,
|
|
624
|
+
embeddingFeaturesSampleDataPtrs.data(), embeddingDimensions.data(), embeddingFeaturesSize,
|
|
625
|
+
resultValues.data(), resultValues.size()
|
|
626
|
+
)
|
|
627
|
+
))
|
|
628
|
+
{
|
|
229
629
|
return Napi::Array::New(env);
|
|
230
630
|
}
|
|
231
631
|
|