catboost 1.25.1 → 1.26.0

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (289)
  1. package/DEPLOYMENT.md +22 -15
  2. package/README.md +37 -27
  3. package/binding.gyp +5 -7
  4. package/build_scripts/bootstrap.js +2 -1
  5. package/build_scripts/out/build.js +46 -68
  6. package/build_scripts/out/build_model.js +1 -1
  7. package/build_scripts/out/{build_ya.js → build_native.js} +1 -1
  8. package/build_scripts/out/ci.js +5 -5
  9. package/build_scripts/out/config.js +32 -18
  10. package/build_scripts/out/install.js +5 -3
  11. package/build_scripts/out/package_prepublish.js +1 -1
  12. package/build_scripts/out/packaging.js +1 -19
  13. package/build_scripts/out/run_tests.js +1 -1
  14. package/build_scripts/out/test.js +8 -3
  15. package/config.json +18 -11
  16. package/inc/catboost/libs/model_interface/c_api.h +349 -3
  17. package/lib/catboost.d.ts +65 -21
  18. package/package.json +4 -4
  19. package/src/api_helpers.cpp +100 -24
  20. package/src/api_helpers.h +8 -7
  21. package/src/api_module.cpp +1 -2
  22. package/src/model.cpp +483 -83
  23. package/src/model.h +24 -9
  24. package/inc/contrib/libs/cxxsupp/system_stl/include/stlfwd +0 -14
  25. package/inc/util/charset/recode_result.h +0 -9
  26. package/inc/util/charset/unicode_table.h +0 -123
  27. package/inc/util/charset/unidata.h +0 -421
  28. package/inc/util/charset/utf8.h +0 -384
  29. package/inc/util/charset/wide.h +0 -843
  30. package/inc/util/charset/wide_specific.h +0 -22
  31. package/inc/util/datetime/base.h +0 -669
  32. package/inc/util/datetime/constants.h +0 -7
  33. package/inc/util/datetime/cputimer.h +0 -124
  34. package/inc/util/datetime/parser.h +0 -292
  35. package/inc/util/datetime/systime.h +0 -47
  36. package/inc/util/datetime/uptime.h +0 -8
  37. package/inc/util/digest/city.h +0 -88
  38. package/inc/util/digest/fnv.h +0 -73
  39. package/inc/util/digest/multi.h +0 -14
  40. package/inc/util/digest/murmur.h +0 -57
  41. package/inc/util/digest/numeric.h +0 -86
  42. package/inc/util/digest/sequence.h +0 -48
  43. package/inc/util/draft/date.h +0 -129
  44. package/inc/util/draft/datetime.h +0 -184
  45. package/inc/util/draft/enum.h +0 -136
  46. package/inc/util/draft/holder_vector.h +0 -102
  47. package/inc/util/draft/ip.h +0 -131
  48. package/inc/util/draft/matrix.h +0 -108
  49. package/inc/util/draft/memory.h +0 -40
  50. package/inc/util/folder/dirent_win.h +0 -46
  51. package/inc/util/folder/dirut.h +0 -121
  52. package/inc/util/folder/filelist.h +0 -81
  53. package/inc/util/folder/fts.h +0 -108
  54. package/inc/util/folder/iterator.h +0 -109
  55. package/inc/util/folder/lstat_win.h +0 -20
  56. package/inc/util/folder/path.h +0 -225
  57. package/inc/util/folder/pathsplit.h +0 -113
  58. package/inc/util/folder/tempdir.h +0 -42
  59. package/inc/util/generic/adaptor.h +0 -134
  60. package/inc/util/generic/algorithm.h +0 -765
  61. package/inc/util/generic/array_ref.h +0 -282
  62. package/inc/util/generic/array_size.h +0 -24
  63. package/inc/util/generic/benchmark/vector_count_ctor/f.h +0 -9
  64. package/inc/util/generic/bitmap.h +0 -1115
  65. package/inc/util/generic/bitops.h +0 -459
  66. package/inc/util/generic/bt_exception.h +0 -24
  67. package/inc/util/generic/buffer.h +0 -232
  68. package/inc/util/generic/cast.h +0 -176
  69. package/inc/util/generic/deque.h +0 -24
  70. package/inc/util/generic/explicit_type.h +0 -42
  71. package/inc/util/generic/fastqueue.h +0 -55
  72. package/inc/util/generic/flags.h +0 -244
  73. package/inc/util/generic/function.h +0 -103
  74. package/inc/util/generic/fwd.h +0 -171
  75. package/inc/util/generic/guid.h +0 -61
  76. package/inc/util/generic/hash.h +0 -2032
  77. package/inc/util/generic/hash_primes.h +0 -140
  78. package/inc/util/generic/hash_set.h +0 -490
  79. package/inc/util/generic/hide_ptr.h +0 -3
  80. package/inc/util/generic/intrlist.h +0 -876
  81. package/inc/util/generic/is_in.h +0 -53
  82. package/inc/util/generic/iterator.h +0 -137
  83. package/inc/util/generic/iterator_range.h +0 -105
  84. package/inc/util/generic/lazy_value.h +0 -66
  85. package/inc/util/generic/list.h +0 -22
  86. package/inc/util/generic/map.h +0 -44
  87. package/inc/util/generic/mapfindptr.h +0 -60
  88. package/inc/util/generic/maybe.h +0 -713
  89. package/inc/util/generic/maybe_traits.h +0 -164
  90. package/inc/util/generic/mem_copy.h +0 -55
  91. package/inc/util/generic/noncopyable.h +0 -38
  92. package/inc/util/generic/object_counter.h +0 -53
  93. package/inc/util/generic/ptr.h +0 -1113
  94. package/inc/util/generic/queue.h +0 -57
  95. package/inc/util/generic/refcount.h +0 -162
  96. package/inc/util/generic/reserve.h +0 -11
  97. package/inc/util/generic/scope.h +0 -65
  98. package/inc/util/generic/serialized_enum.h +0 -406
  99. package/inc/util/generic/set.h +0 -42
  100. package/inc/util/generic/singleton.h +0 -136
  101. package/inc/util/generic/size_literals.h +0 -65
  102. package/inc/util/generic/stack.h +0 -18
  103. package/inc/util/generic/store_policy.h +0 -120
  104. package/inc/util/generic/strbase.h +0 -612
  105. package/inc/util/generic/strbuf.h +0 -552
  106. package/inc/util/generic/strfcpy.h +0 -17
  107. package/inc/util/generic/string.h +0 -1572
  108. package/inc/util/generic/string_hash.h +0 -21
  109. package/inc/util/generic/string_ut.h +0 -1175
  110. package/inc/util/generic/type_name.h +0 -34
  111. package/inc/util/generic/typelist.h +0 -114
  112. package/inc/util/generic/typetraits.h +0 -325
  113. package/inc/util/generic/utility.h +0 -132
  114. package/inc/util/generic/va_args.h +0 -400
  115. package/inc/util/generic/variant.h +0 -631
  116. package/inc/util/generic/variant_traits.h +0 -171
  117. package/inc/util/generic/vector.h +0 -119
  118. package/inc/util/generic/xrange.h +0 -258
  119. package/inc/util/generic/yexception.h +0 -212
  120. package/inc/util/generic/yexception_ut.h +0 -14
  121. package/inc/util/generic/ylimits.h +0 -92
  122. package/inc/util/generic/ymath.h +0 -206
  123. package/inc/util/memory/addstorage.h +0 -93
  124. package/inc/util/memory/alloc.h +0 -27
  125. package/inc/util/memory/blob.h +0 -296
  126. package/inc/util/memory/mmapalloc.h +0 -8
  127. package/inc/util/memory/pool.h +0 -432
  128. package/inc/util/memory/segmented_string_pool.h +0 -194
  129. package/inc/util/memory/segpool_alloc.h +0 -118
  130. package/inc/util/memory/smallobj.h +0 -141
  131. package/inc/util/memory/tempbuf.h +0 -111
  132. package/inc/util/network/address.h +0 -136
  133. package/inc/util/network/endpoint.h +0 -61
  134. package/inc/util/network/hostip.h +0 -16
  135. package/inc/util/network/init.h +0 -60
  136. package/inc/util/network/interface.h +0 -17
  137. package/inc/util/network/iovec.h +0 -65
  138. package/inc/util/network/ip.h +0 -116
  139. package/inc/util/network/nonblock.h +0 -8
  140. package/inc/util/network/pair.h +0 -9
  141. package/inc/util/network/poller.h +0 -58
  142. package/inc/util/network/pollerimpl.h +0 -707
  143. package/inc/util/network/sock.h +0 -608
  144. package/inc/util/network/socket.h +0 -421
  145. package/inc/util/random/common_ops.h +0 -130
  146. package/inc/util/random/easy.h +0 -47
  147. package/inc/util/random/entropy.h +0 -21
  148. package/inc/util/random/fast.h +0 -101
  149. package/inc/util/random/init_atfork.h +0 -3
  150. package/inc/util/random/lcg_engine.h +0 -66
  151. package/inc/util/random/mersenne.h +0 -46
  152. package/inc/util/random/mersenne32.h +0 -50
  153. package/inc/util/random/mersenne64.h +0 -50
  154. package/inc/util/random/normal.h +0 -38
  155. package/inc/util/random/random.h +0 -30
  156. package/inc/util/random/shuffle.h +0 -39
  157. package/inc/util/str_stl.h +0 -266
  158. package/inc/util/stream/aligned.h +0 -99
  159. package/inc/util/stream/buffer.h +0 -119
  160. package/inc/util/stream/buffered.h +0 -225
  161. package/inc/util/stream/debug.h +0 -53
  162. package/inc/util/stream/direct_io.h +0 -43
  163. package/inc/util/stream/file.h +0 -108
  164. package/inc/util/stream/format.h +0 -444
  165. package/inc/util/stream/fwd.h +0 -100
  166. package/inc/util/stream/hex.h +0 -8
  167. package/inc/util/stream/holder.h +0 -44
  168. package/inc/util/stream/input.h +0 -273
  169. package/inc/util/stream/labeled.h +0 -19
  170. package/inc/util/stream/length.h +0 -100
  171. package/inc/util/stream/mem.h +0 -255
  172. package/inc/util/stream/multi.h +0 -32
  173. package/inc/util/stream/null.h +0 -61
  174. package/inc/util/stream/output.h +0 -304
  175. package/inc/util/stream/pipe.h +0 -112
  176. package/inc/util/stream/printf.h +0 -25
  177. package/inc/util/stream/str.h +0 -207
  178. package/inc/util/stream/tee.h +0 -28
  179. package/inc/util/stream/tempbuf.h +0 -21
  180. package/inc/util/stream/tokenizer.h +0 -214
  181. package/inc/util/stream/trace.h +0 -60
  182. package/inc/util/stream/walk.h +0 -35
  183. package/inc/util/stream/zerocopy.h +0 -91
  184. package/inc/util/stream/zerocopy_output.h +0 -57
  185. package/inc/util/stream/zlib.h +0 -173
  186. package/inc/util/string/ascii.h +0 -236
  187. package/inc/util/string/builder.h +0 -39
  188. package/inc/util/string/cast.h +0 -347
  189. package/inc/util/string/cstriter.h +0 -14
  190. package/inc/util/string/escape.h +0 -70
  191. package/inc/util/string/hex.h +0 -59
  192. package/inc/util/string/join.h +0 -194
  193. package/inc/util/string/printf.h +0 -13
  194. package/inc/util/string/reverse.h +0 -16
  195. package/inc/util/string/split.h +0 -1080
  196. package/inc/util/string/strip.h +0 -257
  197. package/inc/util/string/strspn.h +0 -65
  198. package/inc/util/string/subst.h +0 -56
  199. package/inc/util/string/type.h +0 -50
  200. package/inc/util/string/util.h +0 -195
  201. package/inc/util/string/vector.h +0 -132
  202. package/inc/util/system/align.h +0 -50
  203. package/inc/util/system/atexit.h +0 -22
  204. package/inc/util/system/atomic.h +0 -51
  205. package/inc/util/system/atomic_gcc.h +0 -90
  206. package/inc/util/system/atomic_ops.h +0 -189
  207. package/inc/util/system/atomic_win.h +0 -114
  208. package/inc/util/system/backtrace.h +0 -39
  209. package/inc/util/system/byteorder.h +0 -186
  210. package/inc/util/system/compat.h +0 -84
  211. package/inc/util/system/compiler.h +0 -620
  212. package/inc/util/system/condvar.h +0 -71
  213. package/inc/util/system/context.h +0 -181
  214. package/inc/util/system/context_aarch64.h +0 -8
  215. package/inc/util/system/context_i686.h +0 -9
  216. package/inc/util/system/context_x86.h +0 -12
  217. package/inc/util/system/context_x86_64.h +0 -7
  218. package/inc/util/system/cpu_id.h +0 -159
  219. package/inc/util/system/daemon.h +0 -28
  220. package/inc/util/system/datetime.h +0 -98
  221. package/inc/util/system/defaults.h +0 -149
  222. package/inc/util/system/demangle.h +0 -5
  223. package/inc/util/system/demangle_impl.h +0 -23
  224. package/inc/util/system/direct_io.h +0 -71
  225. package/inc/util/system/dynlib.h +0 -119
  226. package/inc/util/system/env.h +0 -32
  227. package/inc/util/system/error.h +0 -95
  228. package/inc/util/system/event.h +0 -122
  229. package/inc/util/system/execpath.h +0 -17
  230. package/inc/util/system/fasttime.h +0 -6
  231. package/inc/util/system/fhandle.h +0 -27
  232. package/inc/util/system/file.h +0 -210
  233. package/inc/util/system/file_lock.h +0 -34
  234. package/inc/util/system/filemap.h +0 -383
  235. package/inc/util/system/flock.h +0 -35
  236. package/inc/util/system/fs.h +0 -156
  237. package/inc/util/system/fs_win.h +0 -29
  238. package/inc/util/system/fstat.h +0 -46
  239. package/inc/util/system/getpid.h +0 -12
  240. package/inc/util/system/guard.h +0 -179
  241. package/inc/util/system/hi_lo.h +0 -139
  242. package/inc/util/system/hostname.h +0 -10
  243. package/inc/util/system/hp_timer.h +0 -36
  244. package/inc/util/system/info.h +0 -12
  245. package/inc/util/system/interrupt_signals.h +0 -22
  246. package/inc/util/system/madvise.h +0 -30
  247. package/inc/util/system/maxlen.h +0 -32
  248. package/inc/util/system/mem_info.h +0 -18
  249. package/inc/util/system/mincore.h +0 -38
  250. package/inc/util/system/mktemp.h +0 -11
  251. package/inc/util/system/mlock.h +0 -43
  252. package/inc/util/system/mutex.h +0 -67
  253. package/inc/util/system/nice.h +0 -3
  254. package/inc/util/system/pipe.h +0 -90
  255. package/inc/util/system/platform.h +0 -246
  256. package/inc/util/system/progname.h +0 -13
  257. package/inc/util/system/protect.h +0 -25
  258. package/inc/util/system/rusage.h +0 -26
  259. package/inc/util/system/rwlock.h +0 -78
  260. package/inc/util/system/sanitizers.h +0 -122
  261. package/inc/util/system/sem.h +0 -41
  262. package/inc/util/system/shellcommand.h +0 -472
  263. package/inc/util/system/shmat.h +0 -32
  264. package/inc/util/system/sigset.h +0 -78
  265. package/inc/util/system/spin_wait.h +0 -10
  266. package/inc/util/system/spinlock.h +0 -121
  267. package/inc/util/system/src_location.h +0 -25
  268. package/inc/util/system/src_root.h +0 -68
  269. package/inc/util/system/sys_alloc.h +0 -43
  270. package/inc/util/system/sysstat.h +0 -52
  271. package/inc/util/system/tempfile.h +0 -34
  272. package/inc/util/system/thread.h +0 -167
  273. package/inc/util/system/tls.h +0 -307
  274. package/inc/util/system/types.h +0 -119
  275. package/inc/util/system/unaligned_mem.h +0 -67
  276. package/inc/util/system/user.h +0 -5
  277. package/inc/util/system/utime.h +0 -6
  278. package/inc/util/system/valgrind.h +0 -48
  279. package/inc/util/system/winint.h +0 -43
  280. package/inc/util/system/yassert.h +0 -121
  281. package/inc/util/system/yield.h +0 -4
  282. package/inc/util/thread/factory.h +0 -65
  283. package/inc/util/thread/fwd.h +0 -30
  284. package/inc/util/thread/lfqueue.h +0 -406
  285. package/inc/util/thread/lfstack.h +0 -188
  286. package/inc/util/thread/pool.h +0 -388
  287. package/inc/util/thread/singleton.h +0 -42
  288. package/inc/util/ysafeptr.h +0 -427
  289. package/inc/util/ysaveload.h +0 -700
package/src/model.cpp CHANGED
@@ -2,12 +2,15 @@
2
2
 
3
3
  #include "api_helpers.h"
4
4
 
5
+ #include <vector>
6
+
7
+
5
8
  namespace {
6
9
 
7
10
  // Collect pointers to matrix rows into a vector.
8
- template <typename T, typename V = const T*, typename C = const TVector<T>>
9
- TVector<V> CollectMatrixRowPointers(C& matrix, uint32_t rowLength) {
10
- TVector<V> pointers;
11
+ template <typename T, typename V = const T*, typename C = const std::vector<T>>
12
+ std::vector<V> CollectMatrixRowPointers(C& matrix, uint32_t rowLength) {
13
+ std::vector<V> pointers;
11
14
  for (uint32_t i = 0; i < matrix.size(); i += rowLength) {
12
15
  pointers.push_back(matrix.data() + i);
13
16
  }
@@ -30,10 +33,12 @@ TModel::TModel(const Napi::CallbackInfo& info): Napi::ObjectWrap<TModel>(info) {
30
33
  }
31
34
 
32
35
  NHelper::Check(env, info[0].IsString(), "File name argument should be a string");
33
- const bool status = LoadFullModelFromFile(this->Handle,
34
- info[0].As<Napi::String>().Utf8Value().c_str());
36
+ const bool status = LoadFullModelFromFile(
37
+ this->Handle,
38
+ info[0].As<Napi::String>().Utf8Value().c_str()
39
+ );
35
40
  // Even if it fails, this check schedules NodeJS exception, not C++ one.
36
- // The C++ object is considered to be successfully created and will be destoryed by Node runtime
41
+ // The C++ object is considered to be successfully created and will be destroyed by Node runtime
37
42
  // later as usual.
38
43
  NHelper::CheckStatus(env, status);
39
44
  if (status) {
@@ -52,10 +57,14 @@ Napi::Function TModel::GetClass(Napi::Env env) {
52
57
  TModel::InstanceMethod("loadModel", &TModel::LoadFullFromFile),
53
58
  TModel::InstanceMethod("predict", &TModel::CalcPrediction),
54
59
  TModel::InstanceMethod("enableGPUEvaluation", &TModel::EvaluateOnGPU),
60
+ TModel::InstanceMethod("setPredictionType", &TModel::SetPredictionType),
55
61
  TModel::InstanceMethod("getFloatFeaturesCount", &TModel::GetModelFloatFeaturesCount),
56
62
  TModel::InstanceMethod("getCatFeaturesCount", &TModel::GetModelCatFeaturesCount),
63
+ TModel::InstanceMethod("getTextFeaturesCount", &TModel::GetModelTextFeaturesCount),
64
+ TModel::InstanceMethod("getEmbeddingFeaturesCount", &TModel::GetModelEmbeddingFeaturesCount),
57
65
  TModel::InstanceMethod("getTreeCount", &TModel::GetModelTreeCount),
58
66
  TModel::InstanceMethod("getDimensionsCount", &TModel::GetModelDimensionsCount),
67
+ TModel::InstanceMethod("getPredictionDimensionsCount", &TModel::GetPredictionDimensionsCount),
59
68
  });
60
69
  }
61
70
 
@@ -64,70 +73,178 @@ void TModel::LoadFullFromFile(const Napi::CallbackInfo& info) {
64
73
  Napi::Env env = info.Env();
65
74
 
66
75
  if (!NHelper::Check(env, info.Length() >= 1, "Wrong number of arguments") ||
67
- !NHelper::Check(env, info[0].IsString(), "File name string is required")) {
76
+ !NHelper::Check(env, info[0].IsString(), "File name string is required"))
77
+ {
68
78
  return;
69
79
  }
70
80
 
71
81
  NHelper::CheckNotNullHandle(env, this->Handle);
72
- const bool status = LoadFullModelFromFile(this->Handle,
73
- info[0].As<Napi::String>().Utf8Value().c_str());
82
+ const bool status = LoadFullModelFromFile(
83
+ this->Handle,
84
+ info[0].As<Napi::String>().Utf8Value().c_str()
85
+ );
74
86
  NHelper::CheckStatus(env, status);
75
87
  if (status) {
76
88
  this->ModelLoaded = true;
77
89
  }
78
90
  }
79
91
 
92
+ // Set model predictions postprocessing type.
93
+ void TModel::SetPredictionType(const Napi::CallbackInfo& info) {
94
+ Napi::Env env = info.Env();
95
+
96
+ if (!NHelper::Check(env, info.Length() >= 1, "Wrong number of arguments") ||
97
+ !NHelper::Check(env, info[0].IsString(), "predictionType argument must have string type"))
98
+ {
99
+ return;
100
+ }
101
+
102
+ NHelper::CheckNotNullHandle(env, this->Handle);
103
+ const bool status = SetPredictionTypeString(
104
+ this->Handle,
105
+ info[0].As<Napi::String>().Utf8Value().c_str()
106
+ );
107
+ NHelper::CheckStatus(env, status);
108
+ }
109
+
80
110
  Napi::Value TModel::CalcPrediction(const Napi::CallbackInfo& info) {
81
111
  Napi::Env env = info.Env();
112
+ if (!NHelper::Check(env, this->ModelLoaded, "Trying to predict from the empty model")) {
113
+ return env.Undefined();
114
+ }
115
+
116
+ if (!NHelper::Check(env, info.Length() >= 1, "Wrong number of arguments - expected at least 1")) {
117
+ return env.Undefined();
118
+ }
119
+
120
+
121
+ // Numerical features
82
122
 
83
- if (!NHelper::Check(env, info.Length() >= 2, "Wrong number of arguments - expected 2") ||
84
- !NHelper::Check(env, NHelper::IsMatrix(info[0], NHelper::NAT_NUMBER),
85
- "Expected the first argument to be a matrix of floats") ||
86
- !NHelper::Check(env, this->ModelLoaded, "Trying to predict from the empty model")) {
123
+ if (!NHelper::CheckIsMatrix(
124
+ env,
125
+ info[0],
126
+ NHelper::NAT_NUMBER,
127
+ "Expected the first argument to be a matrix of floats - "
128
+ ))
129
+ {
87
130
  return env.Undefined();
88
131
  }
89
132
 
90
133
  const Napi::Array floatFeatures = info[0].As<Napi::Array>();
91
- const uint32_t docsCount = floatFeatures.Length();
92
- if (docsCount == 0) {
134
+ const uint32_t sampleCount = floatFeatures.Length();
135
+ if (sampleCount == 0) {
93
136
  return Napi::Array::New(env);
94
137
  }
95
138
 
96
- const uint32_t floatFeaturesSize = floatFeatures[0u].As<Napi::Array>().Length();
97
139
 
98
- TVector<float> floatFeatureValues;
99
- floatFeatureValues.reserve(floatFeaturesSize * docsCount);
100
-
101
- for (uint32_t i = 0; i < docsCount; ++i) {
102
- const Napi::Array row = floatFeatures[i].As<Napi::Array>();
103
- for (uint32_t j = 0; j < floatFeaturesSize; ++j) {
104
- floatFeatureValues.push_back(row[j].As<Napi::Number>().FloatValue());
140
+ // Categorical features
141
+ Napi::Value catFeatures;
142
+ bool catFeaturesAreHashes = false;
143
+
144
+ if (info.Length() >= 2) {
145
+ catFeatures = info[1];
146
+ if (!NHelper::CheckIsMatrix(
147
+ env,
148
+ catFeatures,
149
+ NHelper::NAT_NUMBER_OR_STRING,
150
+ "Expected the second argument to be a matrix of strings or numbers - "
151
+ ))
152
+ {
153
+ return env.Undefined();
154
+ }
155
+ const Napi::Array catFeaturesArray = catFeatures.As<Napi::Array>();
156
+
157
+ if (!NHelper::Check(
158
+ env,
159
+ catFeaturesArray.Length() == sampleCount,
160
+ "Expected the number of samples to be the same for both float and categorical features"
161
+ ))
162
+ {
163
+ return env.Undefined();
164
+ }
165
+ if (sampleCount) {
166
+ const Napi::Array catRow = catFeaturesArray[0u].As<Napi::Array>();
167
+ if (catRow.Length()) {
168
+ catFeaturesAreHashes = catRow[0u].IsNumber();
169
+ }
105
170
  }
106
171
  }
107
172
 
108
- if (!NHelper::Check(env, NHelper::IsMatrix(info[1], NHelper::NAT_NUMBER) ||
109
- NHelper::IsMatrix(info[1], NHelper::NAT_STRING),
110
- "Expected second argument to be a matrix of strings or numbers")) {
111
- return env.Undefined();
173
+
174
+ // Text features
175
+ Napi::Value textFeatures;
176
+ if (info.Length() >= 3) {
177
+ textFeatures = info[2];
178
+ if (!NHelper::CheckIsMatrix(
179
+ env,
180
+ textFeatures,
181
+ NHelper::NAT_STRING,
182
+ "Expected the third argument to be a matrix of strings - "
183
+ ))
184
+ {
185
+ return env.Undefined();
186
+ }
187
+
188
+ if (!NHelper::Check(
189
+ env,
190
+ textFeatures.As<Napi::Array>().Length() == sampleCount,
191
+ "Expected the number of samples to be the same for both float and text features"
192
+ ))
193
+ {
194
+ return env.Undefined();
195
+ }
112
196
  }
113
- const Napi::Array catFeatures = info[1].As<Napi::Array>();
114
197
 
115
- if (!NHelper::Check(env, catFeatures.Length() == docsCount,
116
- "Expected the number of docs to be the same for both float and categorial features")) {
117
- return env.Undefined();
198
+ // Embedding features
199
+ Napi::Value embeddingFeatures;
200
+ if (info.Length() == 4) {
201
+ embeddingFeatures = info[3];
202
+ if (!NHelper::CheckIsMatrix(
203
+ env,
204
+ embeddingFeatures,
205
+ NHelper::NAT_ARRAY_OR_NUMBERS,
206
+ "Expected the fourth argument to be a matrix of arrays of numbers - "
207
+ ))
208
+ {
209
+ return env.Undefined();
210
+ }
211
+
212
+ if (!NHelper::Check(
213
+ env,
214
+ embeddingFeatures.As<Napi::Array>().Length() == sampleCount,
215
+ "Expected the number of samples to be the same for both float and embedding features"
216
+ ))
217
+ {
218
+ return env.Undefined();
219
+ }
118
220
  }
119
- const Napi::Array catRow = catFeatures[0u].As<Napi::Array>();
120
- if (catRow.Length() == 0 || catRow[0u].IsNumber()) {
121
- return CalcPredictionHash(env, floatFeatureValues, catFeatures);
221
+
222
+
223
+ if (catFeaturesAreHashes) {
224
+ return CalcPredictionWithCatFeaturesAsHashes(
225
+ env,
226
+ sampleCount,
227
+ floatFeatures,
228
+ catFeatures,
229
+ textFeatures,
230
+ embeddingFeatures
231
+ );
122
232
  }
123
- return CalcPredictionString(env, floatFeatureValues, catFeatures);
233
+ return CalcPredictionWithCatFeaturesAsStrings(
234
+ env,
235
+ sampleCount,
236
+ floatFeatures,
237
+ catFeatures,
238
+ textFeatures,
239
+ embeddingFeatures
240
+ );
124
241
  }
125
242
 
126
243
  void TModel::EvaluateOnGPU(const Napi::CallbackInfo& info) {
127
244
  Napi::Env env = info.Env();
128
245
  if (!NHelper::Check(env, info.Length() >= 1, "Wrong number of arguments - expected 1") ||
129
- !NHelper::Check(env, info[0].IsNumber(),
130
- "Expected the first argument to be a numeric deviceId")) {
246
+ !NHelper::Check(env, info[0].IsNumber(), "Expected the first argument to be a numeric deviceId"))
247
+ {
131
248
  return;
132
249
  }
133
250
 
@@ -150,6 +267,20 @@ Napi::Value TModel::GetModelCatFeaturesCount(const Napi::CallbackInfo& info) {
150
267
  return Napi::Number::New(env, count);
151
268
  }
152
269
 
270
+ Napi::Value TModel::GetModelTextFeaturesCount(const Napi::CallbackInfo& info) {
271
+ Napi::Env env = info.Env();
272
+ const size_t count = GetTextFeaturesCount(this->Handle);
273
+
274
+ return Napi::Number::New(env, count);
275
+ }
276
+
277
+ Napi::Value TModel::GetModelEmbeddingFeaturesCount(const Napi::CallbackInfo& info) {
278
+ Napi::Env env = info.Env();
279
+ const size_t count = GetEmbeddingFeaturesCount(this->Handle);
280
+
281
+ return Napi::Number::New(env, count);
282
+ }
283
+
153
284
  Napi::Value TModel::GetModelTreeCount(const Napi::CallbackInfo& info) {
154
285
  Napi::Env env = info.Env();
155
286
  const size_t count = GetTreeCount(this->Handle);
@@ -164,68 +295,337 @@ Napi::Value TModel::GetModelDimensionsCount(const Napi::CallbackInfo& info) {
164
295
  return Napi::Number::New(env, count);
165
296
  }
166
297
 
167
- Napi::Array TModel::CalcPredictionHash(Napi::Env env,
168
- const TVector<float>& floatFeatures,
169
- const Napi::Array& catFeatures) {
170
- const uint32_t docsCount = catFeatures.Length();
171
- const uint32_t catFeaturesSize = catFeatures[0u].As<Napi::Array>().Length();
172
- const uint32_t floatFeaturesSize = floatFeatures.size() / docsCount;
298
+ Napi::Value TModel::GetPredictionDimensionsCount(const Napi::CallbackInfo& info) {
299
+ Napi::Env env = info.Env();
300
+ const size_t count = ::GetPredictionDimensionsCount(this->Handle);
301
+
302
+ return Napi::Number::New(env, count);
303
+ }
304
+
305
+ static void GetNumericFeaturesData(
306
+ const uint32_t sampleCount,
307
+ const Napi::Array& floatFeatures,
308
+ uint32_t* floatFeatureCount,
309
+ std::vector<float>* storage,
310
+ std::vector<const float*>* sampleDataPtrs
311
+ ) {
312
+ const uint32_t floatFeatureCountLocal = floatFeatures[0u].As<Napi::Array>().Length();
313
+ *floatFeatureCount = floatFeatureCountLocal;
314
+
315
+ storage->clear();
316
+ storage->reserve(floatFeatureCountLocal * sampleCount);
317
+
318
+ for (uint32_t i = 0; i < sampleCount; ++i) {
319
+ const Napi::Array row = floatFeatures[i].As<Napi::Array>();
320
+ for (uint32_t j = 0; j < floatFeatureCountLocal; ++j) {
321
+ storage->push_back(row[j].As<Napi::Number>().FloatValue());
322
+ }
323
+ }
173
324
 
174
- TVector<int> catHashValues;
175
- catHashValues.reserve(catFeaturesSize * docsCount);
325
+ *sampleDataPtrs = CollectMatrixRowPointers<float>(*storage, floatFeatureCountLocal);
326
+ }
176
327
 
177
- for (uint32_t i = 0; i < docsCount; ++i) {
178
- const Napi::Array row = catFeatures[i].As<Napi::Array>();
179
- for (uint32_t j = 0; j < catFeaturesSize; ++j) {
180
- catHashValues.push_back(row[j].As<Napi::Number>().Int32Value());
328
+ static void GetTextFeaturesData(
329
+ const uint32_t sampleCount,
330
+ const Napi::Value& textFeatures, // array or empty
331
+ uint32_t* textFeatureCount,
332
+ std::vector<std::string>* storage,
333
+ std::vector<const char*>* dataPtrsStorage,
334
+ std::vector<const char**>* sampleDataPtrs
335
+ ) {
336
+ storage->clear();
337
+ dataPtrsStorage->clear();
338
+ sampleDataPtrs->clear();
339
+ if (textFeatures.IsEmpty()) {
340
+ *textFeatureCount = 0;
341
+ } else {
342
+ const Napi::Array textFeaturesArray = textFeatures.As<Napi::Array>();
343
+ const uint32_t textFeatureCountLocal = textFeaturesArray[0u].As<Napi::Array>().Length();
344
+ *textFeatureCount = textFeatureCountLocal;
345
+
346
+ storage->reserve(textFeatureCountLocal * sampleCount);
347
+ dataPtrsStorage->reserve(textFeatureCountLocal * sampleCount);
348
+
349
+ for (uint32_t i = 0; i < sampleCount; ++i) {
350
+ const Napi::Array row = textFeaturesArray[i].As<Napi::Array>();
351
+ for (uint32_t j = 0; j < textFeatureCountLocal; ++j) {
352
+ storage->push_back(row[j].As<Napi::String>().Utf8Value());
353
+ dataPtrsStorage->push_back(storage->back().c_str());
354
+ }
181
355
  }
356
+
357
+ *sampleDataPtrs = CollectMatrixRowPointers<const char*, const char**>(
358
+ *dataPtrsStorage,
359
+ textFeatureCountLocal
360
+ );
182
361
  }
362
+ }
183
363
 
184
- TVector<double> resultValues;
185
- resultValues.resize(docsCount);
364
+ static bool GetEmbeddingFeaturesData(
365
+ Napi::Env env,
366
+ const uint32_t sampleCount,
367
+ const Napi::Value& embeddingFeatures, // array or empty
368
+ uint32_t* embeddingFeatureCount,
369
+ std::vector<size_t>* embeddingDimensions,
370
+ std::vector<float>* storage,
371
+ std::vector<const float*>* dataPtrsStorage,
372
+ std::vector<const float**>* sampleDataPtrs
373
+ ) {
374
+ embeddingDimensions->clear();
375
+ storage->clear();
376
+ dataPtrsStorage->clear();
377
+ sampleDataPtrs->clear();
378
+ if (embeddingFeatures.IsEmpty()) {
379
+ *embeddingFeatureCount = 0;
380
+ } else {
381
+ const Napi::Array embeddingsFeaturesArray = embeddingFeatures.As<Napi::Array>();
382
+ const uint32_t embeddingFeatureCountLocal = embeddingsFeaturesArray[0u].As<Napi::Array>().Length();
383
+ *embeddingFeatureCount = embeddingFeatureCountLocal;
384
+
385
+ embeddingDimensions->reserve(embeddingFeatureCountLocal);
386
+ // this is a lower bound, final allocation is delayed until the first sample is processed and
387
+ // embedding dimensions become known
388
+ storage->reserve(embeddingFeatureCountLocal * sampleCount);
389
+ dataPtrsStorage->reserve(embeddingFeatureCountLocal * sampleCount);
390
+
391
+ size_t perSampleValuesSize = 0;
392
+
393
+ for (uint32_t i = 0; i < sampleCount; ++i) {
394
+ const Napi::Array row = embeddingsFeaturesArray[i].As<Napi::Array>();
395
+ for (uint32_t j = 0; j < embeddingFeatureCountLocal; ++j) {
396
+ const Napi::Array embeddingValues = row[j].As<Napi::Array>();
397
+ auto embeddingSize = embeddingValues.Length();
398
+ if (i == 0) {
399
+ embeddingDimensions->push_back(embeddingSize);
400
+ } else {
401
+ if (!NHelper::Check(
402
+ env,
403
+ (*embeddingDimensions)[j] == embeddingSize,
404
+ "Embedding values arrays have different lengths"
405
+ ))
406
+ {
407
+ return false;
408
+ }
409
+ }
410
+
411
+ for (uint32_t k = 0; k < embeddingSize; ++k) {
412
+ storage->push_back(embeddingValues[k].As<Napi::Number>().FloatValue());
413
+ }
414
+ // can't update dataPtrsStorage just yet as it is not reserved to final size
415
+ }
416
+ if (i == 0) {
417
+ perSampleValuesSize = storage->size();
418
+ storage->reserve(perSampleValuesSize * sampleCount);
419
+ }
420
+ const float* dataPtr = storage->data() + perSampleValuesSize * i;
421
+ for (uint32_t j = 0; j < embeddingFeatureCountLocal; ++j) {
422
+ dataPtrsStorage->push_back(dataPtr);
423
+ dataPtr += (*embeddingDimensions)[j];
424
+ }
425
+ }
186
426
 
187
- TVector<const float*> floatPtrs = CollectMatrixRowPointers<float>(floatFeatures, floatFeaturesSize);
188
- TVector<const int*> catPtrs = CollectMatrixRowPointers<int>(catHashValues, catFeaturesSize);
189
- NHelper::CheckStatus(env,
190
- CalcModelPredictionWithHashedCatFeatures(this->Handle, docsCount,
191
- floatPtrs.data(), floatFeaturesSize,
192
- catPtrs.data(), catFeaturesSize,
193
- resultValues.data(), docsCount));
427
+ *sampleDataPtrs = CollectMatrixRowPointers<const float*, const float**>(
428
+ *dataPtrsStorage,
429
+ embeddingFeatureCountLocal
430
+ );
431
+ }
194
432
 
195
- return NHelper::ConvertToArray(env, resultValues);
433
+ return true;
196
434
  }
197
435
 
198
- Napi::Array TModel::CalcPredictionString(Napi::Env env,
199
- const TVector<float>& floatFeatures,
200
- const Napi::Array& catFeatures) {
201
- const uint32_t docsCount = catFeatures.Length();
202
- const uint32_t catFeaturesSize = catFeatures[0u].As<Napi::Array>().Length();
203
- const uint32_t floatFeaturesSize = floatFeatures.size() / docsCount;
204
436
 
205
- TVector<std::string> catStrings;
206
- TVector<const char*> catStringValues;
207
- catStrings.reserve(catFeaturesSize * docsCount);
208
- catStringValues.reserve(catFeaturesSize * docsCount);
437
+ Napi::Array TModel::CalcPredictionWithCatFeaturesAsHashes(
438
+ Napi::Env env,
439
+ const uint32_t sampleCount,
440
+ const Napi::Array& floatFeatures,
441
+ const Napi::Value& catFeatures,
442
+ const Napi::Value& textFeatures,
443
+ const Napi::Value& embeddingFeatures
444
+ ) {
445
+ uint32_t floatFeaturesSize = 0;
446
+ std::vector<float> floatFeaturesStorage;
447
+ std::vector<const float*> floatPtrs;
448
+
449
+ GetNumericFeaturesData(sampleCount, floatFeatures, &floatFeaturesSize, &floatFeaturesStorage, &floatPtrs);
209
450
 
210
- for (uint32_t i = 0; i < docsCount; ++i) {
211
- const Napi::Array row = catFeatures[i].As<Napi::Array>();
212
- for (uint32_t j = 0; j < catFeaturesSize; ++j) {
213
- catStrings.push_back(row[j].As<Napi::String>().Utf8Value());
214
- catStringValues.push_back(catStrings.back().c_str());
451
+
452
+ uint32_t catFeaturesSize = 0;
453
+ std::vector<int> catHashValues;
454
+ std::vector<const int*> catPtrs;
455
+
456
+ if (!catFeatures.IsEmpty()) {
457
+ const Napi::Array catFeaturesArray = catFeatures.As<Napi::Array>();
458
+ catFeaturesSize = catFeaturesArray[0u].As<Napi::Array>().Length();
459
+
460
+ catHashValues.reserve(catFeaturesSize * sampleCount);
461
+
462
+ for (uint32_t i = 0; i < sampleCount; ++i) {
463
+ const Napi::Array row = catFeaturesArray[i].As<Napi::Array>();
464
+ for (uint32_t j = 0; j < catFeaturesSize; ++j) {
465
+ catHashValues.push_back(row[j].As<Napi::Number>().Int32Value());
466
+ }
467
+ }
468
+
469
+ catPtrs = CollectMatrixRowPointers<int>(catHashValues, catFeaturesSize);
470
+ }
471
+
472
+
473
+ uint32_t textFeaturesSize = 0;
474
+ std::vector<std::string> textFeaturesStorage;
475
+ std::vector<const char*> textFeaturesDataPtrsStorage;
476
+ std::vector<const char**> textFeaturesSampleDataPtrs;
477
+
478
+ GetTextFeaturesData(
479
+ sampleCount,
480
+ textFeatures,
481
+ &textFeaturesSize,
482
+ &textFeaturesStorage,
483
+ &textFeaturesDataPtrsStorage,
484
+ &textFeaturesSampleDataPtrs
485
+ );
486
+
487
+
488
+ uint32_t embeddingFeaturesSize = 0;
489
+ std::vector<size_t> embeddingDimensions;
490
+ std::vector<float> embeddingFeaturesStorage;
491
+ std::vector<const float*> embeddingFeaturesDataPtrsStorage;
492
+ std::vector<const float**> embeddingFeaturesSampleDataPtrs;
493
+
494
+ if (!NHelper::Check(
495
+ env,
496
+ GetEmbeddingFeaturesData(
497
+ env,
498
+ sampleCount,
499
+ embeddingFeatures,
500
+ &embeddingFeaturesSize,
501
+ &embeddingDimensions,
502
+ &embeddingFeaturesStorage,
503
+ &embeddingFeaturesDataPtrsStorage,
504
+ &embeddingFeaturesSampleDataPtrs
505
+ ),
506
+ "Failed to get embedding features data"
507
+ ))
508
+ {
509
+ return Napi::Array::New(env);
510
+ }
511
+
512
+
513
+ const auto predictionDimensions = ::GetPredictionDimensionsCount(this->Handle);
514
+ std::vector<double> resultValues;
515
+ resultValues.resize(sampleCount * predictionDimensions);
516
+
517
+ NHelper::CheckStatus(
518
+ env,
519
+ CalcModelPredictionWithHashedCatFeaturesAndTextAndEmbeddingFeatures(
520
+ this->Handle,
521
+ sampleCount,
522
+ floatPtrs.data(), floatFeaturesSize,
523
+ catPtrs.data(), catFeaturesSize,
524
+ textFeaturesSampleDataPtrs.data(), textFeaturesSize,
525
+ embeddingFeaturesSampleDataPtrs.data(), embeddingDimensions.data(), embeddingFeaturesSize,
526
+ resultValues.data(), resultValues.size()
527
+ )
528
+ );
529
+
530
+ return NHelper::ConvertToArray(env, resultValues);
531
+ }
532
+
533
+ Napi::Array TModel::CalcPredictionWithCatFeaturesAsStrings(
534
+ Napi::Env env,
535
+ const uint32_t sampleCount,
536
+ const Napi::Array& floatFeatures,
537
+ const Napi::Value& catFeatures,
538
+ const Napi::Value& textFeatures,
539
+ const Napi::Value& embeddingFeatures
540
+ ) {
541
+ uint32_t floatFeaturesSize = 0;
542
+ std::vector<float> floatFeaturesStorage;
543
+ std::vector<const float*> floatPtrs;
544
+
545
+ GetNumericFeaturesData(sampleCount, floatFeatures, &floatFeaturesSize, &floatFeaturesStorage, &floatPtrs);
546
+
547
+ uint32_t catFeaturesSize = 0;
548
+ std::vector<std::string> catStrings;
549
+ std::vector<const char*> catStringValues;
550
+ std::vector<const char**> catPtrs;
551
+
552
+ if (!catFeatures.IsEmpty()) {
553
+ const Napi::Array catFeaturesArray = catFeatures.As<Napi::Array>();
554
+ catFeaturesSize = catFeaturesArray[0u].As<Napi::Array>().Length();
555
+
556
+ catStrings.reserve(catFeaturesSize * sampleCount);
557
+ catStringValues.reserve(catFeaturesSize * sampleCount);
558
+
559
+ for (uint32_t i = 0; i < sampleCount; ++i) {
560
+ const Napi::Array row = catFeaturesArray[i].As<Napi::Array>();
561
+ for (uint32_t j = 0; j < catFeaturesSize; ++j) {
562
+ catStrings.push_back(row[j].As<Napi::String>().Utf8Value());
563
+ catStringValues.push_back(catStrings.back().c_str());
564
+ }
215
565
  }
566
+ catPtrs = CollectMatrixRowPointers<const char*, const char**>(
567
+ catStringValues,
568
+ catFeaturesSize
569
+ );
216
570
  }
217
571
 
218
- TVector<double> resultValues;
219
- resultValues.resize(docsCount);
572
+ uint32_t textFeaturesSize = 0;
573
+ std::vector<std::string> textFeaturesStorage;
574
+ std::vector<const char*> textFeaturesDataPtrsStorage;
575
+ std::vector<const char**> textFeaturesSampleDataPtrs;
576
+
577
+ GetTextFeaturesData(
578
+ sampleCount,
579
+ textFeatures,
580
+ &textFeaturesSize,
581
+ &textFeaturesStorage,
582
+ &textFeaturesDataPtrsStorage,
583
+ &textFeaturesSampleDataPtrs
584
+ );
585
+
586
+
587
+ uint32_t embeddingFeaturesSize = 0;
588
+ std::vector<size_t> embeddingDimensions;
589
+ std::vector<float> embeddingFeaturesStorage;
590
+ std::vector<const float*> embeddingFeaturesDataPtrsStorage;
591
+ std::vector<const float**> embeddingFeaturesSampleDataPtrs;
592
+
593
+ if (!NHelper::Check(
594
+ env,
595
+ GetEmbeddingFeaturesData(
596
+ env,
597
+ sampleCount,
598
+ embeddingFeatures,
599
+ &embeddingFeaturesSize,
600
+ &embeddingDimensions,
601
+ &embeddingFeaturesStorage,
602
+ &embeddingFeaturesDataPtrsStorage,
603
+ &embeddingFeaturesSampleDataPtrs
604
+ ),
605
+ "Failed to get embedding features data"
606
+ ))
607
+ {
608
+ return Napi::Array::New(env);
609
+ }
220
610
 
221
- TVector<const float*> floatPtrs = CollectMatrixRowPointers<float>(floatFeatures, floatFeaturesSize);
222
- TVector<const char**> catPtrs = CollectMatrixRowPointers<const char*, const char**>(catStringValues, catFeaturesSize);
223
611
 
224
- if (!NHelper::CheckStatus(env,
225
- CalcModelPrediction(this->Handle, docsCount,
226
- floatPtrs.data(), floatFeaturesSize,
227
- catPtrs.data(), catFeaturesSize,
228
- resultValues.data(), docsCount))) {
612
+ const auto predictionDimensions = ::GetPredictionDimensionsCount(this->Handle);
613
+ std::vector<double> resultValues;
614
+ resultValues.resize(sampleCount * predictionDimensions);
615
+
616
+ if (!NHelper::CheckStatus(
617
+ env,
618
+ CalcModelPredictionTextAndEmbeddings(
619
+ this->Handle,
620
+ sampleCount,
621
+ floatPtrs.data(), floatFeaturesSize,
622
+ catPtrs.data(), catFeaturesSize,
623
+ textFeaturesSampleDataPtrs.data(), textFeaturesSize,
624
+ embeddingFeaturesSampleDataPtrs.data(), embeddingDimensions.data(), embeddingFeaturesSize,
625
+ resultValues.data(), resultValues.size()
626
+ )
627
+ ))
628
+ {
229
629
  return Napi::Array::New(env);
230
630
  }
231
631