catboost 1.25.1 → 1.26.0

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (289)
  1. package/DEPLOYMENT.md +22 -15
  2. package/README.md +37 -27
  3. package/binding.gyp +5 -7
  4. package/build_scripts/bootstrap.js +2 -1
  5. package/build_scripts/out/build.js +46 -68
  6. package/build_scripts/out/build_model.js +1 -1
  7. package/build_scripts/out/{build_ya.js → build_native.js} +1 -1
  8. package/build_scripts/out/ci.js +5 -5
  9. package/build_scripts/out/config.js +32 -18
  10. package/build_scripts/out/install.js +5 -3
  11. package/build_scripts/out/package_prepublish.js +1 -1
  12. package/build_scripts/out/packaging.js +1 -19
  13. package/build_scripts/out/run_tests.js +1 -1
  14. package/build_scripts/out/test.js +8 -3
  15. package/config.json +18 -11
  16. package/inc/catboost/libs/model_interface/c_api.h +349 -3
  17. package/lib/catboost.d.ts +65 -21
  18. package/package.json +4 -4
  19. package/src/api_helpers.cpp +100 -24
  20. package/src/api_helpers.h +8 -7
  21. package/src/api_module.cpp +1 -2
  22. package/src/model.cpp +483 -83
  23. package/src/model.h +24 -9
  24. package/inc/contrib/libs/cxxsupp/system_stl/include/stlfwd +0 -14
  25. package/inc/util/charset/recode_result.h +0 -9
  26. package/inc/util/charset/unicode_table.h +0 -123
  27. package/inc/util/charset/unidata.h +0 -421
  28. package/inc/util/charset/utf8.h +0 -384
  29. package/inc/util/charset/wide.h +0 -843
  30. package/inc/util/charset/wide_specific.h +0 -22
  31. package/inc/util/datetime/base.h +0 -669
  32. package/inc/util/datetime/constants.h +0 -7
  33. package/inc/util/datetime/cputimer.h +0 -124
  34. package/inc/util/datetime/parser.h +0 -292
  35. package/inc/util/datetime/systime.h +0 -47
  36. package/inc/util/datetime/uptime.h +0 -8
  37. package/inc/util/digest/city.h +0 -88
  38. package/inc/util/digest/fnv.h +0 -73
  39. package/inc/util/digest/multi.h +0 -14
  40. package/inc/util/digest/murmur.h +0 -57
  41. package/inc/util/digest/numeric.h +0 -86
  42. package/inc/util/digest/sequence.h +0 -48
  43. package/inc/util/draft/date.h +0 -129
  44. package/inc/util/draft/datetime.h +0 -184
  45. package/inc/util/draft/enum.h +0 -136
  46. package/inc/util/draft/holder_vector.h +0 -102
  47. package/inc/util/draft/ip.h +0 -131
  48. package/inc/util/draft/matrix.h +0 -108
  49. package/inc/util/draft/memory.h +0 -40
  50. package/inc/util/folder/dirent_win.h +0 -46
  51. package/inc/util/folder/dirut.h +0 -121
  52. package/inc/util/folder/filelist.h +0 -81
  53. package/inc/util/folder/fts.h +0 -108
  54. package/inc/util/folder/iterator.h +0 -109
  55. package/inc/util/folder/lstat_win.h +0 -20
  56. package/inc/util/folder/path.h +0 -225
  57. package/inc/util/folder/pathsplit.h +0 -113
  58. package/inc/util/folder/tempdir.h +0 -42
  59. package/inc/util/generic/adaptor.h +0 -134
  60. package/inc/util/generic/algorithm.h +0 -765
  61. package/inc/util/generic/array_ref.h +0 -282
  62. package/inc/util/generic/array_size.h +0 -24
  63. package/inc/util/generic/benchmark/vector_count_ctor/f.h +0 -9
  64. package/inc/util/generic/bitmap.h +0 -1115
  65. package/inc/util/generic/bitops.h +0 -459
  66. package/inc/util/generic/bt_exception.h +0 -24
  67. package/inc/util/generic/buffer.h +0 -232
  68. package/inc/util/generic/cast.h +0 -176
  69. package/inc/util/generic/deque.h +0 -24
  70. package/inc/util/generic/explicit_type.h +0 -42
  71. package/inc/util/generic/fastqueue.h +0 -55
  72. package/inc/util/generic/flags.h +0 -244
  73. package/inc/util/generic/function.h +0 -103
  74. package/inc/util/generic/fwd.h +0 -171
  75. package/inc/util/generic/guid.h +0 -61
  76. package/inc/util/generic/hash.h +0 -2032
  77. package/inc/util/generic/hash_primes.h +0 -140
  78. package/inc/util/generic/hash_set.h +0 -490
  79. package/inc/util/generic/hide_ptr.h +0 -3
  80. package/inc/util/generic/intrlist.h +0 -876
  81. package/inc/util/generic/is_in.h +0 -53
  82. package/inc/util/generic/iterator.h +0 -137
  83. package/inc/util/generic/iterator_range.h +0 -105
  84. package/inc/util/generic/lazy_value.h +0 -66
  85. package/inc/util/generic/list.h +0 -22
  86. package/inc/util/generic/map.h +0 -44
  87. package/inc/util/generic/mapfindptr.h +0 -60
  88. package/inc/util/generic/maybe.h +0 -713
  89. package/inc/util/generic/maybe_traits.h +0 -164
  90. package/inc/util/generic/mem_copy.h +0 -55
  91. package/inc/util/generic/noncopyable.h +0 -38
  92. package/inc/util/generic/object_counter.h +0 -53
  93. package/inc/util/generic/ptr.h +0 -1113
  94. package/inc/util/generic/queue.h +0 -57
  95. package/inc/util/generic/refcount.h +0 -162
  96. package/inc/util/generic/reserve.h +0 -11
  97. package/inc/util/generic/scope.h +0 -65
  98. package/inc/util/generic/serialized_enum.h +0 -406
  99. package/inc/util/generic/set.h +0 -42
  100. package/inc/util/generic/singleton.h +0 -136
  101. package/inc/util/generic/size_literals.h +0 -65
  102. package/inc/util/generic/stack.h +0 -18
  103. package/inc/util/generic/store_policy.h +0 -120
  104. package/inc/util/generic/strbase.h +0 -612
  105. package/inc/util/generic/strbuf.h +0 -552
  106. package/inc/util/generic/strfcpy.h +0 -17
  107. package/inc/util/generic/string.h +0 -1572
  108. package/inc/util/generic/string_hash.h +0 -21
  109. package/inc/util/generic/string_ut.h +0 -1175
  110. package/inc/util/generic/type_name.h +0 -34
  111. package/inc/util/generic/typelist.h +0 -114
  112. package/inc/util/generic/typetraits.h +0 -325
  113. package/inc/util/generic/utility.h +0 -132
  114. package/inc/util/generic/va_args.h +0 -400
  115. package/inc/util/generic/variant.h +0 -631
  116. package/inc/util/generic/variant_traits.h +0 -171
  117. package/inc/util/generic/vector.h +0 -119
  118. package/inc/util/generic/xrange.h +0 -258
  119. package/inc/util/generic/yexception.h +0 -212
  120. package/inc/util/generic/yexception_ut.h +0 -14
  121. package/inc/util/generic/ylimits.h +0 -92
  122. package/inc/util/generic/ymath.h +0 -206
  123. package/inc/util/memory/addstorage.h +0 -93
  124. package/inc/util/memory/alloc.h +0 -27
  125. package/inc/util/memory/blob.h +0 -296
  126. package/inc/util/memory/mmapalloc.h +0 -8
  127. package/inc/util/memory/pool.h +0 -432
  128. package/inc/util/memory/segmented_string_pool.h +0 -194
  129. package/inc/util/memory/segpool_alloc.h +0 -118
  130. package/inc/util/memory/smallobj.h +0 -141
  131. package/inc/util/memory/tempbuf.h +0 -111
  132. package/inc/util/network/address.h +0 -136
  133. package/inc/util/network/endpoint.h +0 -61
  134. package/inc/util/network/hostip.h +0 -16
  135. package/inc/util/network/init.h +0 -60
  136. package/inc/util/network/interface.h +0 -17
  137. package/inc/util/network/iovec.h +0 -65
  138. package/inc/util/network/ip.h +0 -116
  139. package/inc/util/network/nonblock.h +0 -8
  140. package/inc/util/network/pair.h +0 -9
  141. package/inc/util/network/poller.h +0 -58
  142. package/inc/util/network/pollerimpl.h +0 -707
  143. package/inc/util/network/sock.h +0 -608
  144. package/inc/util/network/socket.h +0 -421
  145. package/inc/util/random/common_ops.h +0 -130
  146. package/inc/util/random/easy.h +0 -47
  147. package/inc/util/random/entropy.h +0 -21
  148. package/inc/util/random/fast.h +0 -101
  149. package/inc/util/random/init_atfork.h +0 -3
  150. package/inc/util/random/lcg_engine.h +0 -66
  151. package/inc/util/random/mersenne.h +0 -46
  152. package/inc/util/random/mersenne32.h +0 -50
  153. package/inc/util/random/mersenne64.h +0 -50
  154. package/inc/util/random/normal.h +0 -38
  155. package/inc/util/random/random.h +0 -30
  156. package/inc/util/random/shuffle.h +0 -39
  157. package/inc/util/str_stl.h +0 -266
  158. package/inc/util/stream/aligned.h +0 -99
  159. package/inc/util/stream/buffer.h +0 -119
  160. package/inc/util/stream/buffered.h +0 -225
  161. package/inc/util/stream/debug.h +0 -53
  162. package/inc/util/stream/direct_io.h +0 -43
  163. package/inc/util/stream/file.h +0 -108
  164. package/inc/util/stream/format.h +0 -444
  165. package/inc/util/stream/fwd.h +0 -100
  166. package/inc/util/stream/hex.h +0 -8
  167. package/inc/util/stream/holder.h +0 -44
  168. package/inc/util/stream/input.h +0 -273
  169. package/inc/util/stream/labeled.h +0 -19
  170. package/inc/util/stream/length.h +0 -100
  171. package/inc/util/stream/mem.h +0 -255
  172. package/inc/util/stream/multi.h +0 -32
  173. package/inc/util/stream/null.h +0 -61
  174. package/inc/util/stream/output.h +0 -304
  175. package/inc/util/stream/pipe.h +0 -112
  176. package/inc/util/stream/printf.h +0 -25
  177. package/inc/util/stream/str.h +0 -207
  178. package/inc/util/stream/tee.h +0 -28
  179. package/inc/util/stream/tempbuf.h +0 -21
  180. package/inc/util/stream/tokenizer.h +0 -214
  181. package/inc/util/stream/trace.h +0 -60
  182. package/inc/util/stream/walk.h +0 -35
  183. package/inc/util/stream/zerocopy.h +0 -91
  184. package/inc/util/stream/zerocopy_output.h +0 -57
  185. package/inc/util/stream/zlib.h +0 -173
  186. package/inc/util/string/ascii.h +0 -236
  187. package/inc/util/string/builder.h +0 -39
  188. package/inc/util/string/cast.h +0 -347
  189. package/inc/util/string/cstriter.h +0 -14
  190. package/inc/util/string/escape.h +0 -70
  191. package/inc/util/string/hex.h +0 -59
  192. package/inc/util/string/join.h +0 -194
  193. package/inc/util/string/printf.h +0 -13
  194. package/inc/util/string/reverse.h +0 -16
  195. package/inc/util/string/split.h +0 -1080
  196. package/inc/util/string/strip.h +0 -257
  197. package/inc/util/string/strspn.h +0 -65
  198. package/inc/util/string/subst.h +0 -56
  199. package/inc/util/string/type.h +0 -50
  200. package/inc/util/string/util.h +0 -195
  201. package/inc/util/string/vector.h +0 -132
  202. package/inc/util/system/align.h +0 -50
  203. package/inc/util/system/atexit.h +0 -22
  204. package/inc/util/system/atomic.h +0 -51
  205. package/inc/util/system/atomic_gcc.h +0 -90
  206. package/inc/util/system/atomic_ops.h +0 -189
  207. package/inc/util/system/atomic_win.h +0 -114
  208. package/inc/util/system/backtrace.h +0 -39
  209. package/inc/util/system/byteorder.h +0 -186
  210. package/inc/util/system/compat.h +0 -84
  211. package/inc/util/system/compiler.h +0 -620
  212. package/inc/util/system/condvar.h +0 -71
  213. package/inc/util/system/context.h +0 -181
  214. package/inc/util/system/context_aarch64.h +0 -8
  215. package/inc/util/system/context_i686.h +0 -9
  216. package/inc/util/system/context_x86.h +0 -12
  217. package/inc/util/system/context_x86_64.h +0 -7
  218. package/inc/util/system/cpu_id.h +0 -159
  219. package/inc/util/system/daemon.h +0 -28
  220. package/inc/util/system/datetime.h +0 -98
  221. package/inc/util/system/defaults.h +0 -149
  222. package/inc/util/system/demangle.h +0 -5
  223. package/inc/util/system/demangle_impl.h +0 -23
  224. package/inc/util/system/direct_io.h +0 -71
  225. package/inc/util/system/dynlib.h +0 -119
  226. package/inc/util/system/env.h +0 -32
  227. package/inc/util/system/error.h +0 -95
  228. package/inc/util/system/event.h +0 -122
  229. package/inc/util/system/execpath.h +0 -17
  230. package/inc/util/system/fasttime.h +0 -6
  231. package/inc/util/system/fhandle.h +0 -27
  232. package/inc/util/system/file.h +0 -210
  233. package/inc/util/system/file_lock.h +0 -34
  234. package/inc/util/system/filemap.h +0 -383
  235. package/inc/util/system/flock.h +0 -35
  236. package/inc/util/system/fs.h +0 -156
  237. package/inc/util/system/fs_win.h +0 -29
  238. package/inc/util/system/fstat.h +0 -46
  239. package/inc/util/system/getpid.h +0 -12
  240. package/inc/util/system/guard.h +0 -179
  241. package/inc/util/system/hi_lo.h +0 -139
  242. package/inc/util/system/hostname.h +0 -10
  243. package/inc/util/system/hp_timer.h +0 -36
  244. package/inc/util/system/info.h +0 -12
  245. package/inc/util/system/interrupt_signals.h +0 -22
  246. package/inc/util/system/madvise.h +0 -30
  247. package/inc/util/system/maxlen.h +0 -32
  248. package/inc/util/system/mem_info.h +0 -18
  249. package/inc/util/system/mincore.h +0 -38
  250. package/inc/util/system/mktemp.h +0 -11
  251. package/inc/util/system/mlock.h +0 -43
  252. package/inc/util/system/mutex.h +0 -67
  253. package/inc/util/system/nice.h +0 -3
  254. package/inc/util/system/pipe.h +0 -90
  255. package/inc/util/system/platform.h +0 -246
  256. package/inc/util/system/progname.h +0 -13
  257. package/inc/util/system/protect.h +0 -25
  258. package/inc/util/system/rusage.h +0 -26
  259. package/inc/util/system/rwlock.h +0 -78
  260. package/inc/util/system/sanitizers.h +0 -122
  261. package/inc/util/system/sem.h +0 -41
  262. package/inc/util/system/shellcommand.h +0 -472
  263. package/inc/util/system/shmat.h +0 -32
  264. package/inc/util/system/sigset.h +0 -78
  265. package/inc/util/system/spin_wait.h +0 -10
  266. package/inc/util/system/spinlock.h +0 -121
  267. package/inc/util/system/src_location.h +0 -25
  268. package/inc/util/system/src_root.h +0 -68
  269. package/inc/util/system/sys_alloc.h +0 -43
  270. package/inc/util/system/sysstat.h +0 -52
  271. package/inc/util/system/tempfile.h +0 -34
  272. package/inc/util/system/thread.h +0 -167
  273. package/inc/util/system/tls.h +0 -307
  274. package/inc/util/system/types.h +0 -119
  275. package/inc/util/system/unaligned_mem.h +0 -67
  276. package/inc/util/system/user.h +0 -5
  277. package/inc/util/system/utime.h +0 -6
  278. package/inc/util/system/valgrind.h +0 -48
  279. package/inc/util/system/winint.h +0 -43
  280. package/inc/util/system/yassert.h +0 -121
  281. package/inc/util/system/yield.h +0 -4
  282. package/inc/util/thread/factory.h +0 -65
  283. package/inc/util/thread/fwd.h +0 -30
  284. package/inc/util/thread/lfqueue.h +0 -406
  285. package/inc/util/thread/lfstack.h +0 -188
  286. package/inc/util/thread/pool.h +0 -388
  287. package/inc/util/thread/singleton.h +0 -42
  288. package/inc/util/ysafeptr.h +0 -427
  289. package/inc/util/ysaveload.h +0 -700
package/lib/catboost.d.ts CHANGED
@@ -1,31 +1,75 @@
1
- /** CatBoost numeric features for multiple documents. */
1
+ /** CatBoost numeric features for multiple samples. */
2
2
  export type CatBoostFloatFeatures = Array<number[]>;
3
3
  /**
4
- * CatBoost categorial features for multiple documents - either integer hashes
4
+ * CatBoost categorical features for multiple samples - either integer hashes
5
5
  * or string values.
6
6
  */
7
7
  export type CatBoostCategoryFeatures = Array<number[]>|Array<string[]>;
8
+ /** CatBoost text features for multiple samples. */
9
+ export type CatBoostTextFeatures = Array<string[]>;
10
+ /** CatBoost embedding features for multiple samples. */
11
+ export type CatBoostEmbeddingFeatures = Array<Array<number[]>>;
8
12
 
9
13
  /** CatBoost model instance. */
10
14
  export class Model {
11
- constructor(path?: string);
15
+ constructor(path?: string);
12
16
 
13
- /** Loads model from file. */
14
- loadModel(path: string): void;
15
- /**
16
- * Calculate prediction for multiple documents.
17
- * The same number of numeric and categorial features is expected.
18
- */
19
- predict(floatFeatures: CatBoostFloatFeatures,
20
- catFeatures: CatBoostCategoryFeatures): number[];
21
- /** Enable evaluation on GPU device. */
22
- enableGPUEvaluation(deviceId: number): void;
23
- /** The number of numeric features. */
24
- getFloatFeaturesCount(): number;
25
- /** The number of categorial features. */
26
- getCatFeaturesCount(): number;
27
- /** The number of trees in the model. */
28
- getTreeCount(): number;
29
- /** The number of dimensions in the model. */
30
- getDimensionsCount(): number;
17
+ /** Load a model from the file. */
18
+ loadModel(path: string): void;
19
+ /** Set model prediction postprocessing type. Possible value are:
20
+ * RawFormulaVal - raw sum of leaf values for each dimension, this is the default
21
+ * Exponent - exp(sum(leaf values)),
22
+ * RMSEWithUncertainty - pair (prediction, uncertainty),
23
+ * Probability - (probablity for class_0, ..., probablity for class_i,...)
24
+ * MultiProbability - probability for each label (used for multilabel classification)
25
+ * Class - index of a class with the maximum predicted probability
26
+ * */
27
+ setPredictionType(predictionType: string): void;
28
+ /**
29
+ * Calculate the prediction for multiple samples.
30
+ * All defined feature arguments must have the same length.
31
+ *
32
+ * The returned value contains [sampleCount x predictionDimensions] elements
33
+ * (should be accessed using [sampleIndex * predictionDimensions + predictionDimensionIdx],
34
+ * for simple cases when predictionDimensions = 1 it is just [sampleIndex])
35
+ * and its interpretation depends on prediction type (can be set with 'setPredictionType'):
36
+ * - RawFormulaVal (this is the default):
37
+ * array of raw sum of leaf values for each dimension
38
+ * - Exponent:
39
+ * array of exp(sum(leaf values)) for each dimension
40
+ * - RMSEWithUncertainty:
41
+ * array of pairs (prediction, uncertainty)
42
+ * - Probability:
43
+ * - for binary classification models:
44
+ * array of probabilities for positive class (calculated as sigmoid(rawFormulaVal))
45
+ * - for multiclassification models:
46
+ * array of array of probabilities for each class (calculated as softmax(rawFormulaVal))
47
+ * - MultiProbability:
48
+ * array of probabilities for each label (calculated as sigmoid(rawFormulaVal))
49
+ * (used for multilabel classification)
50
+ * - Class:
51
+ * array of predicted class indices.
52
+ *
53
+ * predictionDimensions can be obtained using 'getPredictionDimensionsCount' method.
54
+ */
55
+ predict(floatFeatures: CatBoostFloatFeatures,
56
+ catFeatures?: CatBoostCategoryFeatures,
57
+ textFeatures?: CatBoostTextFeatures,
58
+ embeddingFeatures?: CatBoostEmbeddingFeatures): number[];
59
+ /** Enable evaluation on GPU device. */
60
+ enableGPUEvaluation(deviceId: number): void;
61
+ /** The number of numeric features. */
62
+ getFloatFeaturesCount(): number;
63
+ /** The number of categorical features. */
64
+ getCatFeaturesCount(): number;
65
+ /** The number of text features. */
66
+ getTextFeaturesCount(): number;
67
+ /** The number of embedding features. */
68
+ getEmbeddingFeaturesCount(): number;
69
+ /** The number of trees in the model. */
70
+ getTreeCount(): number;
71
+ /** The number of dimensions in the model. */
72
+ getDimensionsCount(): number;
73
+ /** The number of dimensions in the prediction. */
74
+ getPredictionDimensionsCount(): number;
31
75
  }
package/package.json CHANGED
@@ -1,6 +1,6 @@
1
1
  {
2
2
  "name": "catboost",
3
- "version": "1.25.1",
3
+ "version": "1.26.0",
4
4
  "description": "Node bindings for CatBoost library to apply models. CatBoost is a machine learning method based on gradient boosting over decision trees.",
5
5
  "keywords": [
6
6
  "catboost",
@@ -16,11 +16,11 @@
16
16
  "main": "lib/index.js",
17
17
  "types": "lib/index.d.ts",
18
18
  "dependencies": {
19
- "node-addon-api": "^1.1.0"
19
+ "node-addon-api": "^8.2.2"
20
20
  },
21
21
  "scripts": {
22
22
  "install": "node ./build_scripts/bootstrap.js install",
23
- "build": "node ./build_scripts/bootstrap.js build_ya",
23
+ "build": "node ./build_scripts/bootstrap.js build_native",
24
24
  "ci": "node ./build_scripts/bootstrap.js ci",
25
25
  "compile": "node ./build_scripts/bootstrap.js compile",
26
26
  "package_prepublish": "node ./build_scripts/bootstrap.js package_prepublish",
@@ -36,4 +36,4 @@
36
36
  "@types/node": "^7.10.14",
37
37
  "typescript": "^3.9.9"
38
38
  }
39
- }
39
+ }
@@ -1,48 +1,124 @@
1
1
  #include "api_helpers.h"
2
2
 
3
- #include <util/system/yassert.h>
3
+ #include <assert.h>
4
4
 
5
5
  namespace NHelper {
6
6
 
7
- bool IsMatrix(const Napi::Value& value, ENApiType type) {
8
- if (!value.IsArray()) {
7
+ bool CheckIsMatrix(Napi::Env env, const Napi::Value& value, ENApiType type, const std::string& errorPrefix) {
8
+ if (!Check(env, value.IsArray(), errorPrefix + "is not an array")) {
9
9
  return false;
10
10
  }
11
- const Napi::Array floatFeatures = value.As<Napi::Array>();
12
- const uint32_t rowsCount = floatFeatures.Length();
11
+ const Napi::Array matrix = value.As<Napi::Array>();
12
+ const uint32_t rowsCount = matrix.Length();
13
13
  if (rowsCount == 0) {
14
14
  return true;
15
15
  }
16
16
 
17
- if (!floatFeatures[0u].IsArray()) {
17
+ if (!Check(env, matrix[0u].IsArray(), errorPrefix + "the first element of the matrix is not an array")) {
18
18
  return false;
19
19
  }
20
- const uint32_t columnsCount = floatFeatures[0u].As<Napi::Array>().Length();
20
+ const uint32_t columnsCount = matrix[0u].As<Napi::Array>().Length();
21
+ size_t numberCount = 0;
22
+ size_t strCount = 0;
23
+
24
+ std::function<bool(const Napi::Value&)> checkElement;
25
+
26
+ switch (type) {
27
+ case ENApiType::NAT_NUMBER:
28
+ checkElement = [&] (const Napi::Value& value) -> bool {
29
+ return Check(
30
+ env,
31
+ value.IsNumber(),
32
+ "non-numeric type in the matrix elements"
33
+ );
34
+ };
35
+ break;
36
+ case ENApiType::NAT_STRING:
37
+ checkElement = [&] (const Napi::Value& value) -> bool {
38
+ return Check(
39
+ env,
40
+ value.IsString(),
41
+ "non-string type in the matrix elements"
42
+ );
43
+ };
44
+ break;
45
+ case ENApiType::NAT_NUMBER_OR_STRING:
46
+ checkElement = [&] (const Napi::Value& value) -> bool {
47
+ if (value.IsNumber()) {
48
+ ++numberCount;
49
+ } else if (value.IsString()) {
50
+ ++strCount;
51
+ } else {
52
+ Check(env, false, errorPrefix + "invalid type found: " + std::to_string(value.Type()));
53
+ return false;
54
+ }
55
+ return true;
56
+ };
57
+ break;
58
+ case ENApiType::NAT_ARRAY_OR_NUMBERS:
59
+ checkElement = [&] (const Napi::Value& value) -> bool {
60
+ if (!Check(
61
+ env,
62
+ value.IsArray(),
63
+ "the matrix contains non-array elements"
64
+ ))
65
+ {
66
+ return false;
67
+ }
68
+
69
+ const Napi::Array subArray = value.As<Napi::Array>();
70
+ const uint32_t subArraySize = subArray.Length();
71
+
72
+ for (uint32_t k = 0; k < subArraySize; ++k) {
73
+ if (!Check(
74
+ env,
75
+ subArray[k].IsNumber(),
76
+ "an array in the matrix element contains a non-number element"
77
+ ))
78
+ {
79
+ return false;
80
+ }
81
+ }
82
+ return true;
83
+ };
84
+ break;
85
+ }
86
+
21
87
 
22
88
  for (uint32_t i = 0; i < rowsCount; ++i) {
23
- if (!floatFeatures[i].IsArray()) {
89
+ if (!Check(
90
+ env,
91
+ matrix[i].IsArray(),
92
+ errorPrefix + std::to_string(i) + "-th element of the matrix is not an array"
93
+ ))
94
+ {
24
95
  return false;
25
96
  }
26
97
 
27
- const Napi::Array row = floatFeatures[i].As<Napi::Array>();
28
- if (row.Length() != columnsCount) {
98
+ const Napi::Array row = matrix[i].As<Napi::Array>();
99
+ if (!Check(
100
+ env,
101
+ row.Length() == columnsCount,
102
+ errorPrefix + "invalid length of " + std::to_string(i) + "-th row"
103
+ ))
104
+ {
29
105
  return false;
30
106
  }
31
107
 
32
- for (uint32_t j = 0; j < rowsCount; ++j) {
33
- switch (type) {
34
- case ENApiType::NAT_NUMBER:
35
- if (!row[j].IsNumber()) {
36
- return false;
37
- }
38
- break;
39
- case ENApiType::NAT_STRING:
40
- if (!row[j].IsString()) {
41
- return false;
42
- }
43
- break;
44
- default:
45
- Y_ASSERT(false);
108
+ for (uint32_t j = 0; j < columnsCount; ++j) {
109
+ if (!checkElement(row[j])) {
110
+ return false;
111
+ }
112
+ }
113
+
114
+ if (type == ENApiType::NAT_NUMBER_OR_STRING) {
115
+ if (!Check(
116
+ env,
117
+ !(numberCount > 0 && strCount > 0),
118
+ errorPrefix + "mixed strings and numbers in the matrix"
119
+ ))
120
+ {
121
+ return false;
46
122
  }
47
123
  }
48
124
  }
package/src/api_helpers.h CHANGED
@@ -4,11 +4,11 @@
4
4
  #include <napi.h>
5
5
 
6
6
  // Catboost C API
7
- #include <c_api.h>
7
+ #include <catboost/libs/model_interface/c_api.h>
8
8
 
9
- #include <util/generic/vector.h>
9
+ #include <vector>
10
10
 
11
- // Using STD version of string as it is used by N-API.
11
+ // used by N-API.
12
12
  #include <string>
13
13
 
14
14
  namespace NHelper {
@@ -17,8 +17,7 @@ namespace NHelper {
17
17
  // Returns false if check failed.
18
18
  inline bool Check(Napi::Env env, bool condition, const std::string& message) {
19
19
  if (!condition) {
20
- Napi::TypeError::New(env, message)
21
- .ThrowAsJavaScriptException();
20
+ Napi::TypeError::New(env, message).ThrowAsJavaScriptException();
22
21
  }
23
22
 
24
23
  return condition;
@@ -50,14 +49,16 @@ inline bool CheckStatus(Napi::Env& env, bool status) {
50
49
  enum ENApiType {
51
50
  NAT_NUMBER,
52
51
  NAT_STRING,
52
+ NAT_NUMBER_OR_STRING,
53
+ NAT_ARRAY_OR_NUMBERS
53
54
  };
54
55
 
55
56
  // Checks if the value a matrix with element of a given type.
56
- bool IsMatrix(const Napi::Value& value, ENApiType type);
57
+ bool CheckIsMatrix(Napi::Env env, const Napi::Value& value, ENApiType type, const std::string& errorStr);
57
58
 
58
59
  // Converts vector of numbers to N-API array.
59
60
  template <typename T>
60
- Napi::Array ConvertToArray(Napi::Env env, const TVector<T>& values) {
61
+ Napi::Array ConvertToArray(Napi::Env env, const std::vector<T>& values) {
61
62
  Napi::Array result = Napi::Array::New(env);
62
63
  uint32_t index = 0;
63
64
  for (const auto value: values) {
@@ -3,8 +3,7 @@
3
3
  #include <napi.h>
4
4
 
5
5
  Napi::Object Init(Napi::Env env, Napi::Object exports) {
6
- exports.Set(Napi::String::New(env, "Model"),
7
- NNodeCatBoost::TModel::GetClass(env));
6
+ exports.Set(Napi::String::New(env, "Model"), NNodeCatBoost::TModel::GetClass(env));
8
7
  return exports;
9
8
  }
10
9