@datagrok/eda 1.1.30 → 1.1.32
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- package/CHANGELOG.md +8 -0
- package/README.md +1 -0
- package/dist/23.js.map +1 -1
- package/dist/242.js +1 -1
- package/dist/242.js.map +1 -1
- package/dist/449.js +2 -0
- package/dist/449.js.map +1 -0
- package/dist/738.js +1 -1
- package/dist/738.js.map +1 -1
- package/dist/77573759e3857711e15b.wasm +0 -0
- package/dist/990.js +2 -0
- package/dist/990.js.map +1 -0
- package/dist/package-test.js +1 -1
- package/dist/package-test.js.map +1 -1
- package/dist/package.js +1 -1
- package/dist/package.js.map +1 -1
- package/package.json +92 -91
- package/src/missing-values-imputation/ui.ts +4 -4
- package/src/package-test.ts +2 -0
- package/src/package.ts +65 -3
- package/src/pls/pls-constants.ts +21 -5
- package/src/pls/pls-tools.ts +8 -2
- package/src/tests/classifiers-tests.ts +114 -0
- package/src/tests/linear-methods-tests.ts +150 -0
- package/src/tests/utils.ts +121 -0
- package/src/xgbooster.ts +260 -0
- package/wasm/XGBoostAPI.js +32 -0
- package/wasm/XGBoostAPI.wasm +0 -0
- package/wasm/XGBoostAPIinWebWorker.js +32 -0
- package/wasm/callWasmForWebWorker.js +11 -8
- package/wasm/workers/xgboostWorker.js +67 -0
- package/wasm/xgboost/CMakeLists.txt +23 -0
- package/wasm/xgboost/commands.txt +12 -0
- package/wasm/xgboost/xgboost/README.txt +1 -0
- package/wasm/xgboost/xgboost-api.cpp +134 -0
- package/wasm/xgbooster.js +161 -0
- package/dist/317.js +0 -2
- package/dist/317.js.map +0 -1
package/package.json
CHANGED
|
@@ -1,94 +1,95 @@
|
|
|
1
1
|
{
|
|
2
|
-
|
|
3
|
-
|
|
4
|
-
|
|
5
|
-
|
|
6
|
-
|
|
7
|
-
|
|
8
|
-
|
|
9
|
-
|
|
10
|
-
|
|
11
|
-
|
|
12
|
-
|
|
13
|
-
|
|
14
|
-
|
|
15
|
-
|
|
16
|
-
|
|
17
|
-
|
|
18
|
-
|
|
19
|
-
|
|
20
|
-
|
|
21
|
-
|
|
22
|
-
|
|
23
|
-
|
|
24
|
-
|
|
25
|
-
|
|
26
|
-
|
|
27
|
-
|
|
28
|
-
|
|
29
|
-
|
|
30
|
-
|
|
31
|
-
|
|
32
|
-
|
|
33
|
-
|
|
34
|
-
|
|
35
|
-
|
|
36
|
-
|
|
37
|
-
|
|
38
|
-
|
|
39
|
-
|
|
40
|
-
|
|
41
|
-
|
|
42
|
-
|
|
43
|
-
|
|
44
|
-
|
|
45
|
-
|
|
46
|
-
|
|
47
|
-
|
|
48
|
-
|
|
49
|
-
|
|
50
|
-
|
|
51
|
-
|
|
52
|
-
|
|
53
|
-
|
|
54
|
-
|
|
55
|
-
|
|
56
|
-
|
|
57
|
-
|
|
58
|
-
|
|
59
|
-
|
|
60
|
-
|
|
61
|
-
|
|
62
|
-
|
|
63
|
-
|
|
64
|
-
|
|
65
|
-
|
|
66
|
-
|
|
67
|
-
|
|
68
|
-
|
|
69
|
-
|
|
70
|
-
|
|
71
|
-
|
|
72
|
-
|
|
73
|
-
|
|
74
|
-
|
|
75
|
-
|
|
76
|
-
|
|
77
|
-
|
|
78
|
-
|
|
79
|
-
|
|
80
|
-
|
|
81
|
-
|
|
82
|
-
|
|
83
|
-
|
|
84
|
-
|
|
85
|
-
|
|
86
|
-
|
|
87
|
-
|
|
88
|
-
|
|
89
|
-
|
|
90
|
-
|
|
91
|
-
|
|
92
|
-
|
|
2
|
+
"name": "@datagrok/eda",
|
|
3
|
+
"friendlyName": "EDA",
|
|
4
|
+
"version": "1.1.32",
|
|
5
|
+
"description": "Exploratory Data Analysis Tools",
|
|
6
|
+
"dependencies": {
|
|
7
|
+
"@datagrok-libraries/math": "^1.1.11",
|
|
8
|
+
"@datagrok-libraries/ml": "^6.6.15",
|
|
9
|
+
"@datagrok-libraries/tutorials": "^1.3.13",
|
|
10
|
+
"@datagrok-libraries/utils": "^4.2.20",
|
|
11
|
+
"@keckelt/tsne": "^1.0.2",
|
|
12
|
+
"@webgpu/types": "^0.1.40",
|
|
13
|
+
"cash-dom": "^8.1.1",
|
|
14
|
+
"datagrok-api": "^1.20.1",
|
|
15
|
+
"dayjs": "^1.11.9",
|
|
16
|
+
"jstat": "^1.9.6",
|
|
17
|
+
"source-map-loader": "^4.0.1",
|
|
18
|
+
"umap-js": "^1.3.3",
|
|
19
|
+
"worker-loader": "latest"
|
|
20
|
+
},
|
|
21
|
+
"author": {
|
|
22
|
+
"name": "Viktor Makarichev",
|
|
23
|
+
"email": "vmakarichev@datagrok.ai"
|
|
24
|
+
},
|
|
25
|
+
"devDependencies": {
|
|
26
|
+
"@typescript-eslint/eslint-plugin": "^5.32.0",
|
|
27
|
+
"@typescript-eslint/parser": "^5.32.0",
|
|
28
|
+
"css-loader": "latest",
|
|
29
|
+
"eslint": "^8.21.0",
|
|
30
|
+
"eslint-config-google": "^0.14.0",
|
|
31
|
+
"style-loader": "latest",
|
|
32
|
+
"ts-loader": "latest",
|
|
33
|
+
"typescript": "latest",
|
|
34
|
+
"webpack": "latest",
|
|
35
|
+
"webpack-cli": "latest"
|
|
36
|
+
},
|
|
37
|
+
"scripts": {
|
|
38
|
+
"link-all": "npm link datagrok-api @datagrok-libraries/utils @datagrok-libraries/tutorials",
|
|
39
|
+
"debug-eda": "webpack && grok publish",
|
|
40
|
+
"release-eda": "webpack && grok publish --release",
|
|
41
|
+
"build-eda": "webpack",
|
|
42
|
+
"build": "webpack",
|
|
43
|
+
"debug-eda-dev": "webpack && grok publish dev",
|
|
44
|
+
"release-eda-dev": "webpack && grok publish dev --release",
|
|
45
|
+
"debug-eda-local": "webpack && grok publish local",
|
|
46
|
+
"release-eda-local": "webpack && grok publish local --release",
|
|
47
|
+
"build-all": "npm --prefix ./../../js-api run build && npm --prefix ./../../libraries/utils run build && npm --prefix ./../../libraries/tutorials run build && npm run build"
|
|
48
|
+
},
|
|
49
|
+
"canEdit": [
|
|
50
|
+
"Developers"
|
|
51
|
+
],
|
|
52
|
+
"canView": [
|
|
53
|
+
"All users"
|
|
54
|
+
],
|
|
55
|
+
"repository": {
|
|
56
|
+
"type": "git",
|
|
57
|
+
"url": "https://github.com/datagrok-ai/public.git",
|
|
58
|
+
"directory": "packages/EDA"
|
|
59
|
+
},
|
|
60
|
+
"category": "Machine Learning",
|
|
61
|
+
"sources": [
|
|
62
|
+
"wasm/EDA.js",
|
|
63
|
+
"wasm/XGBoostAPI.js"
|
|
64
|
+
],
|
|
65
|
+
"meta": {
|
|
66
|
+
"menu": {
|
|
67
|
+
"ML": {
|
|
68
|
+
"Tools": {
|
|
69
|
+
"Impute Missing Values...": null,
|
|
70
|
+
"Random Data...": null
|
|
71
|
+
},
|
|
72
|
+
"Cluster": {
|
|
73
|
+
"Cluster...": null,
|
|
74
|
+
"DBSCAN...": null
|
|
75
|
+
},
|
|
76
|
+
"Notebooks": {
|
|
77
|
+
"Browse Notebooks": null,
|
|
78
|
+
"Open in Notebook": null,
|
|
79
|
+
"New Notebook": null
|
|
80
|
+
},
|
|
81
|
+
"Models": {
|
|
82
|
+
"Browse Models": null,
|
|
83
|
+
"Train Model...": null,
|
|
84
|
+
"Apply Model...": null
|
|
85
|
+
},
|
|
86
|
+
"Analyse": {
|
|
87
|
+
"PCA...": null,
|
|
88
|
+
"ANOVA...": null,
|
|
89
|
+
"Multivariate Analysis...": null
|
|
90
|
+
},
|
|
91
|
+
"Reduce Dimensionality": null
|
|
92
|
+
}
|
|
93
93
|
}
|
|
94
|
+
}
|
|
94
95
|
}
|
|
@@ -117,15 +117,15 @@ export async function runKNNImputer(df?: DG.DataFrame): Promise<void> {
|
|
|
117
117
|
|
|
118
118
|
// Target columns components (cols with missing values to be imputed)
|
|
119
119
|
let targetColNames = colsWithMissingVals.map((col) => col.name);
|
|
120
|
-
const targetColInput = ui.input.columns(TITLE.COLUMNS, {table: df, onValueChanged: () => {
|
|
120
|
+
const targetColInput = ui.input.columns(TITLE.COLUMNS, {table: df, value: df.columns.byNames(availableTargetColsNames), onValueChanged: () => {
|
|
121
121
|
targetColNames = targetColInput.value.map((col) => col.name);
|
|
122
122
|
checkApplicability();
|
|
123
|
-
}, available: availableTargetColsNames
|
|
123
|
+
}, available: availableTargetColsNames});
|
|
124
124
|
targetColInput.setTooltip(HINT.TARGET);
|
|
125
125
|
|
|
126
126
|
// Feature columns components
|
|
127
127
|
let selectedFeatureColNames = availableFeatureColsNames as string[];
|
|
128
|
-
const featuresInput = ui.input.columns(TITLE.FEATURES, {table: df, onValueChanged: () => {
|
|
128
|
+
const featuresInput = ui.input.columns(TITLE.FEATURES, {value: df.columns.byNames(availableFeatureColsNames), table: df, onValueChanged: () => {
|
|
129
129
|
selectedFeatureColNames = featuresInput.value.map((col) => col.name);
|
|
130
130
|
|
|
131
131
|
if (selectedFeatureColNames.length > 0) {
|
|
@@ -133,7 +133,7 @@ export async function runKNNImputer(df?: DG.DataFrame): Promise<void> {
|
|
|
133
133
|
metricInfoInputs.forEach((div, name) => div.hidden = !selectedFeatureColNames.includes(name));
|
|
134
134
|
} else
|
|
135
135
|
hideWidgets();
|
|
136
|
-
}, available: availableFeatureColsNames
|
|
136
|
+
}, available: availableFeatureColsNames});
|
|
137
137
|
featuresInput.setTooltip(HINT.FEATURES);
|
|
138
138
|
|
|
139
139
|
/** Hide widgets (use if run is not applicable) */
|
package/src/package-test.ts
CHANGED
|
@@ -1,6 +1,8 @@
|
|
|
1
1
|
import * as DG from 'datagrok-api/dg';
|
|
2
2
|
import {runTests, tests, TestContext} from '@datagrok-libraries/utils/src/test';
|
|
3
3
|
import './tests/dim-reduction-tests';
|
|
4
|
+
import './tests/linear-methods-tests';
|
|
5
|
+
import './tests/classifiers-tests';
|
|
4
6
|
export const _package = new DG.Package();
|
|
5
7
|
export {tests};
|
|
6
8
|
|
package/src/package.ts
CHANGED
|
@@ -34,6 +34,9 @@ import {getLinearRegressionParams, getPredictionByLinearRegression} from './regr
|
|
|
34
34
|
import {PlsModel} from './pls/pls-ml';
|
|
35
35
|
import {SoftmaxClassifier} from './softmax-classifier';
|
|
36
36
|
|
|
37
|
+
import {initXgboost} from '../wasm/xgbooster';
|
|
38
|
+
import {XGBooster} from './xgbooster';
|
|
39
|
+
|
|
37
40
|
export const _package = new DG.Package();
|
|
38
41
|
|
|
39
42
|
//name: info
|
|
@@ -44,6 +47,7 @@ export function info() {
|
|
|
44
47
|
//tags: init
|
|
45
48
|
export async function init(): Promise<void> {
|
|
46
49
|
await _initEDAAPI();
|
|
50
|
+
await initXgboost();
|
|
47
51
|
}
|
|
48
52
|
|
|
49
53
|
//top-menu: ML | Cluster | DBSCAN...
|
|
@@ -193,7 +197,7 @@ export function GetMCLEditor(call: DG.FuncCall): void {
|
|
|
193
197
|
df: params.table, cols: params.columns, metrics: params.distanceMetrics,
|
|
194
198
|
weights: params.weights, aggregationMethod: params.aggreaggregationMethod, preprocessingFuncs: params.preprocessingFunctions,
|
|
195
199
|
preprocessingFuncArgs: params.preprocessingFuncArgs, threshold: params.threshold, maxIterations: params.maxIterations,
|
|
196
|
-
useWebGPU: params.useWebGPU, inflate: params.inflateFactor,
|
|
200
|
+
useWebGPU: params.useWebGPU, inflate: params.inflateFactor, minClusterSize: params.minClusterSize,
|
|
197
201
|
}).call(true);
|
|
198
202
|
}).show();
|
|
199
203
|
} catch (err: any) {
|
|
@@ -219,10 +223,12 @@ export function GetMCLEditor(call: DG.FuncCall): void {
|
|
|
219
223
|
//input: int maxIterations = 10
|
|
220
224
|
//input: bool useWebGPU = false
|
|
221
225
|
//input: double inflate = 2
|
|
226
|
+
//input: int minClusterSize = 5
|
|
222
227
|
//editor: EDA: GetMCLEditor
|
|
223
228
|
export async function MCL(df: DG.DataFrame, cols: DG.Column[], metrics: KnownMetrics[],
|
|
224
229
|
weights: number[], aggregationMethod: DistanceAggregationMethod, preprocessingFuncs: (DG.Func | null | undefined)[],
|
|
225
230
|
preprocessingFuncArgs: any[], threshold: number = 80, maxIterations: number = 10, useWebGPU: boolean = false, inflate: number = 0,
|
|
231
|
+
minClusterSize: number = 5,
|
|
226
232
|
): Promise< DG.ScatterPlotViewer | undefined> {
|
|
227
233
|
const tv = grok.shell.tableView(df.name) ?? grok.shell.addTableView(df);
|
|
228
234
|
const serializedOptions: string = JSON.stringify({
|
|
@@ -236,6 +242,7 @@ export async function MCL(df: DG.DataFrame, cols: DG.Column[], metrics: KnownMet
|
|
|
236
242
|
maxIterations: maxIterations,
|
|
237
243
|
useWebGPU: useWebGPU,
|
|
238
244
|
inflate: inflate,
|
|
245
|
+
minClusterSize: minClusterSize ?? 5,
|
|
239
246
|
} satisfies MCLSerializableOptions);
|
|
240
247
|
df.setTag(MCL_OPTIONS_TAG, serializedOptions);
|
|
241
248
|
|
|
@@ -255,9 +262,12 @@ export async function MCLInitializationFunction(sc: DG.ScatterPlotViewer) {
|
|
|
255
262
|
const options: MCLSerializableOptions = JSON.parse(mclTag);
|
|
256
263
|
const cols = options.cols.map((colName) => df.columns.byName(colName));
|
|
257
264
|
const preprocessingFuncs = options.preprocessingFuncs.map((funcName) => funcName ? DG.Func.byName(funcName) : null);
|
|
265
|
+
// let presetMatrix = null;
|
|
266
|
+
// if (df.temp['sparseMatrix'])
|
|
267
|
+
// presetMatrix = df.temp['sparseMatrix'];
|
|
258
268
|
const res = await markovCluster(df, cols, options.metrics, options.weights,
|
|
259
269
|
options.aggregationMethod, preprocessingFuncs, options.preprocessingFuncArgs, options.threshold,
|
|
260
|
-
options.maxIterations, options.useWebGPU, options.inflate, sc);
|
|
270
|
+
options.maxIterations, options.useWebGPU, options.inflate, options.minClusterSize, sc /**presetMatrix */);
|
|
261
271
|
return res?.sc;
|
|
262
272
|
}
|
|
263
273
|
|
|
@@ -297,7 +307,7 @@ export async function MVA(): Promise<void> {
|
|
|
297
307
|
//description: Multidimensional data analysis using partial least squares (PLS) regression. It identifies latent factors and constructs a linear model based on them.
|
|
298
308
|
//meta.demoPath: Compute | Multivariate analysis
|
|
299
309
|
export async function demoMultivariateAnalysis(): Promise<any> {
|
|
300
|
-
runDemoMVA();
|
|
310
|
+
await runDemoMVA();
|
|
301
311
|
}
|
|
302
312
|
|
|
303
313
|
//name: trainLinearKernelSVM
|
|
@@ -734,3 +744,55 @@ export async function visualizePLSRegression(df: DG.DataFrame, targetColumn: DG.
|
|
|
734
744
|
export function isInteractivePLSRegression(df: DG.DataFrame, predictColumn: DG.Column): boolean {
|
|
735
745
|
return PlsModel.isInteractive(df.columns, predictColumn);
|
|
736
746
|
}
|
|
747
|
+
|
|
748
|
+
//name: trainXGBooster
|
|
749
|
+
//meta.mlname: XGBoost
|
|
750
|
+
//meta.mlrole: train
|
|
751
|
+
//input: dataframe df
|
|
752
|
+
//input: column predictColumn
|
|
753
|
+
//input: int iterations = 20 {min: 1; max: 100} [Number of training iterations]
|
|
754
|
+
//input: double eta = 0.3 {caption: Rate; min: 0; max: 1} [Learning rate]
|
|
755
|
+
//input: int maxDepth = 6 {min: 0; max: 20} [Maximum depth of a tree]
|
|
756
|
+
//input: double lambda = 1 {min: 0; max: 100} [L2 regularization term]
|
|
757
|
+
//input: double alpha = 0 {min: 0; max: 100} [L1 regularization term]
|
|
758
|
+
//output: dynamic model
|
|
759
|
+
export async function trainXGBooster(df: DG.DataFrame, predictColumn: DG.Column,
  iterations: number, eta: number, maxDepth: number, lambda: number, alpha: number): Promise<Uint8Array> {
  // Every column of the input dataframe is handed to the booster as a feature.
  const featureColumns = df.columns;
  const xgbModel = new XGBooster();

  // Fit with the hyperparameters received from the caller, in the order fit() expects them.
  await xgbModel.fit(featureColumns, predictColumn, iterations, eta, maxDepth, lambda, alpha);

  // The trained model is returned in its byte representation.
  const packedModel = xgbModel.toBytes();
  return packedModel;
}
|
|
768
|
+
|
|
769
|
+
//name: applyXGBooster
|
|
770
|
+
//meta.mlname: XGBoost
|
|
771
|
+
//meta.mlrole: apply
|
|
772
|
+
//input: dataframe df
|
|
773
|
+
//input: dynamic model
|
|
774
|
+
//output: dataframe table
|
|
775
|
+
export function applyXGBooster(df: DG.DataFrame, model: any): DG.DataFrame {
  // Build a booster from the supplied model and predict over all columns of df.
  const booster = new XGBooster(model);
  const predictedColumn = booster.predict(df.columns);

  // The result is a new dataframe holding only the prediction column.
  return DG.DataFrame.fromColumns([predictedColumn]);
}
|
|
779
|
+
|
|
780
|
+
//name: isInteractiveXGBooster
|
|
781
|
+
//meta.mlname: XGBoost
|
|
782
|
+
//meta.mlrole: isInteractive
|
|
783
|
+
//input: dataframe df
|
|
784
|
+
//input: column predictColumn
|
|
785
|
+
//output: bool result
|
|
786
|
+
export function isInteractiveXGBooster(df: DG.DataFrame, predictColumn: DG.Column): boolean {
  // The decision is delegated to the XGBooster class; it receives the dataframe's columns and the target.
  const interactive = XGBooster.isInteractive(df.columns, predictColumn);
  return interactive;
}
|
|
789
|
+
|
|
790
|
+
//name: isApplicableXGBooster
|
|
791
|
+
//meta.mlname: XGBoost
|
|
792
|
+
//meta.mlrole: isApplicable
|
|
793
|
+
//input: dataframe df
|
|
794
|
+
//input: column predictColumn
|
|
795
|
+
//output: bool result
|
|
796
|
+
export function isApplicableXGBooster(df: DG.DataFrame, predictColumn: DG.Column): boolean {
  // The decision is delegated to the XGBooster class; it receives the dataframe's columns and the target.
  const applicable = XGBooster.isApplicable(df.columns, predictColumn);
  return applicable;
}
|
package/src/pls/pls-constants.ts
CHANGED
|
@@ -35,6 +35,7 @@ export enum TITLE {
|
|
|
35
35
|
EXPL_VAR = 'Explained Variance',
|
|
36
36
|
EXPLORE = 'Explore',
|
|
37
37
|
FEATURES = 'Feature names',
|
|
38
|
+
BROWSE = 'Browse',
|
|
38
39
|
}
|
|
39
40
|
|
|
40
41
|
/** Tooltips */
|
|
@@ -115,11 +116,26 @@ The method finds the latent factors that
|
|
|
115
116
|
|
|
116
117
|
/** Description of demo results: wizard components */
|
|
117
118
|
export const DEMO_RESULTS = [
|
|
118
|
-
{
|
|
119
|
-
|
|
120
|
-
|
|
121
|
-
|
|
122
|
-
{
|
|
119
|
+
{
|
|
120
|
+
caption: TITLE.MODEL,
|
|
121
|
+
text: 'Closer to the line means better price prediction.',
|
|
122
|
+
},
|
|
123
|
+
{
|
|
124
|
+
caption: TITLE.SCORES,
|
|
125
|
+
text: 'The latent factor values for each sample reflect the similarities and dissimilarities among observations.',
|
|
126
|
+
},
|
|
127
|
+
{
|
|
128
|
+
caption: TITLE.LOADINGS,
|
|
129
|
+
text: 'The impact of each feature on the latent factors: higher loading means stronger influence.',
|
|
130
|
+
},
|
|
131
|
+
{
|
|
132
|
+
caption: TITLE.REGR_COEFS,
|
|
133
|
+
text: 'Parameters of the obtained linear model: features make different contribution to the prediction.',
|
|
134
|
+
},
|
|
135
|
+
{
|
|
136
|
+
caption: TITLE.EXPL_VAR,
|
|
137
|
+
text: 'How well the latent components fit source data: closer to one means better fit.',
|
|
138
|
+
},
|
|
123
139
|
];
|
|
124
140
|
|
|
125
141
|
/** Form results markdown for demo app */
|
package/src/pls/pls-tools.ts
CHANGED
|
@@ -110,7 +110,11 @@ async function performMVA(input: PlsInput, analysisType: PLS_ANALYSIS): Promise<
|
|
|
110
110
|
if (analysisType === PLS_ANALYSIS.COMPUTE_COMPONENTS)
|
|
111
111
|
return;
|
|
112
112
|
|
|
113
|
-
const view = grok.shell.tableView(input.table.name);
|
|
113
|
+
//const view = grok.shell.tableView(input.table.name);
|
|
114
|
+
|
|
115
|
+
const view = (analysisType === PLS_ANALYSIS.DEMO) ?
|
|
116
|
+
(grok.shell.view(TITLE.BROWSE) as DG.BrowseView).preview as DG.TableView :
|
|
117
|
+
grok.shell.tableView(input.table.name);
|
|
114
118
|
|
|
115
119
|
// 0.1 Buffer table
|
|
116
120
|
const buffer = DG.DataFrame.fromColumns([
|
|
@@ -248,7 +252,9 @@ async function performMVA(input: PlsInput, analysisType: PLS_ANALYSIS): Promise<
|
|
|
248
252
|
|
|
249
253
|
/** Run multivariate analysis (PLS) */
|
|
250
254
|
export async function runMVA(analysisType: PLS_ANALYSIS): Promise<void> {
|
|
251
|
-
const table =
|
|
255
|
+
const table = (analysisType === PLS_ANALYSIS.DEMO) ?
|
|
256
|
+
((grok.shell.view(TITLE.BROWSE) as DG.BrowseView).preview as DG.TableView).table :
|
|
257
|
+
grok.shell.t;
|
|
252
258
|
|
|
253
259
|
if (table === null) {
|
|
254
260
|
grok.shell.warning(ERROR_MSG.NO_DF);
|
|
@@ -0,0 +1,114 @@
|
|
|
1
|
+
// Tests for classifiers
|
|
2
|
+
|
|
3
|
+
import * as grok from 'datagrok-api/grok';
|
|
4
|
+
import * as ui from 'datagrok-api/ui';
|
|
5
|
+
import * as DG from 'datagrok-api/dg';
|
|
6
|
+
import {_package} from '../package-test';
|
|
7
|
+
|
|
8
|
+
import {category, expect, test} from '@datagrok-libraries/utils/src/test';
|
|
9
|
+
|
|
10
|
+
import {classificationDataset, accuracy} from './utils';
|
|
11
|
+
import {SoftmaxClassifier} from '../softmax-classifier';
|
|
12
|
+
import {XGBooster} from '../xgbooster';
|
|
13
|
+
|
|
14
|
+
const ROWS_K = 50;
|
|
15
|
+
const MIN_COLS = 2;
|
|
16
|
+
const COLS = 100;
|
|
17
|
+
const TIMEOUT = 8000;
|
|
18
|
+
const MIN_ACCURACY = 0.9;
|
|
19
|
+
|
|
20
|
+
category('Softmax', () => {
  // Benchmark: fit, pack, unpack and apply the softmax classifier on a generated dataset
  // of ROWS_K * 1000 rows. No accuracy check is made here; only the run time matters.
  test(`Performance: ${ROWS_K}K samples, ${COLS} features`, async () => {
    // Data
    const df = classificationDataset(ROWS_K * 1000, COLS, false);
    const features = df.columns;
    // The column at index COLS is taken as the target
    // (presumably the label column placed after the COLS features; confirm in ./utils).
    const target = features.byIndex(COLS);
    features.remove(target.name);

    // Fit & pack trained model
    const model = new SoftmaxClassifier({
      classesCount: target.categories.length,
      featuresCount: features.length,
    });
    await model.fit(features, target);
    const modelBytes = model.toBytes();

    // Unpack & apply model
    const unpackedModel = new SoftmaxClassifier(undefined, modelBytes);
    unpackedModel.predict(features);
  }, {timeout: TIMEOUT, benchmark: true});

  // Correctness: the model restored from bytes must predict the training target
  // with accuracy strictly above MIN_ACCURACY.
  test('Correctness', async () => {
    // Prepare data
    const df = classificationDataset(ROWS_K, MIN_COLS, true);
    const features = df.columns;
    const target = features.byIndex(MIN_COLS);
    features.remove(target.name);

    // Fit & pack trained model
    const model = new SoftmaxClassifier({
      classesCount: target.categories.length,
      featuresCount: features.length,
    });

    await model.fit(features, target);
    const modelBytes = model.toBytes();

    // Unpack & apply model
    const unpackedModel = new SoftmaxClassifier(undefined, modelBytes);
    const prediction = unpackedModel.predict(features);

    // Evaluate accuracy
    const acc = accuracy(target, prediction);
    expect(
      acc > MIN_ACCURACY,
      true,
      // The message states the passing condition, which is acc > MIN_ACCURACY.
      `Softmax failed, too small accuracy: ${acc}; expected: > ${MIN_ACCURACY}`,
    );
  }, {timeout: TIMEOUT});
}); // Softmax
|
|
70
|
+
|
|
71
|
+
category('XGBoost', () => {
  // Benchmark: fit, pack, unpack and apply the booster on a generated dataset
  // of ROWS_K * 1000 rows. No accuracy check is made here; only the run time matters.
  test(`Performance: ${ROWS_K}K samples, ${COLS} features`, async () => {
    // Data
    const df = classificationDataset(ROWS_K * 1000, COLS, false);
    const features = df.columns;
    // The column at index COLS is taken as the target
    // (presumably the label column placed after the COLS features; confirm in ./utils).
    const target = features.byIndex(COLS);
    features.remove(target.name);

    // Fit & pack trained model
    const model = new XGBooster();
    await model.fit(features, target);
    const modelBytes = model.toBytes();

    // Unpack & apply model
    const unpackedModel = new XGBooster(modelBytes);
    unpackedModel.predict(features);
  }, {timeout: TIMEOUT, benchmark: true});

  // Correctness: the model restored from bytes must predict the training target
  // with accuracy strictly above MIN_ACCURACY.
  test('Correctness', async () => {
    // Prepare data
    const df = classificationDataset(ROWS_K, MIN_COLS, true);
    const features = df.columns;
    const target = features.byIndex(MIN_COLS);
    features.remove(target.name);

    // Fit & pack trained model
    const model = new XGBooster();

    await model.fit(features, target);
    const modelBytes = model.toBytes();

    // Unpack & apply model
    const unpackedModel = new XGBooster(modelBytes);
    const prediction = unpackedModel.predict(features);

    // Evaluate accuracy
    const acc = accuracy(target, prediction);
    expect(
      acc > MIN_ACCURACY,
      true,
      // The message states the passing condition, which is acc > MIN_ACCURACY.
      `XGBoost failed, too small accuracy: ${acc}; expected: > ${MIN_ACCURACY}`,
    );
  }, {timeout: TIMEOUT});
}); // XGBoost
|