@datagrok/eda 1.4.12 → 1.4.13

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -8,9 +8,8 @@ import * as DG from 'datagrok-api/dg';
8
8
  //@ts-ignore: no types
9
9
  import * as jStat from 'jstat';
10
10
 
11
- import {Cutoff, DescriptorStatistics, SigmoidParams} from './pmpo-defs';
12
-
13
- const SQRT_2_PI = Math.sqrt(2 * Math.PI);
11
+ import {ConfusionMatrix, Cutoff, DescriptorStatistics, ModelEvaluationResult,
12
+ ROC_TRESHOLDS, ROC_TRESHOLDS_COUNT, SigmoidParams} from './pmpo-defs';
14
13
 
15
14
  /** Splits the dataframe into desired and non-desired tables based on the desirability column */
16
15
  export function getDesiredTables(df: DG.DataFrame, desirability: DG.Column) {
@@ -70,6 +69,8 @@ export function getDescriptorStatistics(des: DG.Column, nonDes: DG.Column): Desc
70
69
  nonDesAvg: nonDesAvg,
71
70
  nonDesStd: nonDesStd,
72
71
  nonSesLen: nonDesLen,
72
+ min: Math.min(des.stats.min, nonDes.stats.min),
73
+ max: Math.max(des.stats.max, nonDes.stats.max),
73
74
  tstat: t,
74
75
  pValue: pValue,
75
76
  };
@@ -163,6 +164,140 @@ export function sigmoidS(x: number, x0: number, b: number, c: number): number {
163
164
  }
164
165
 
165
166
/**
 * Gaussian desirability function: an unnormalized bell curve that equals 1 at x = mu
 * and decays towards 0 as x moves away from mu.
 * Unlike the normal probability density, it is NOT divided by sigma * sqrt(2 * PI),
 * so for finite inputs its values lie in [0, 1].
 * @param x - value to score
 * @param mu - center (peak position) of the curve
 * @param sigma - width of the curve
 * @return desirability of x
 */
export function gaussDesirabilityFunc(x: number, mu: number, sigma: number): number {
  // Zero width: the curve degenerates to an indicator of x === mu (avoids 0 / 0 = NaN at the peak)
  if (sigma === 0)
    return (x === mu) ? 1 : 0;

  return Math.exp(-((x - mu)**2) / (2 * sigma**2));
}
170
+
171
/** Computes the confusion matrix given desirability (labels) and prediction columns.
 * A row is predicted positive when its prediction score is greater than or equal to the threshold.
 * @param desirability - desirability column (boolean), the actual labels
 * @param prediction - prediction column (numeric), the predicted scores
 * @param threshold - threshold to convert prediction scores to binary labels
 * @return ConfusionMatrix object with TP, TN, FP, FN counts
 * @throws Error if the columns have different lengths, or the column types are unexpected
 */
export function getConfusionMatrix(desirability: DG.Column, prediction: DG.Column, threshold: number): ConfusionMatrix {
  if (desirability.length !== prediction.length)
    throw new Error('Failed to compute confusion matrix: columns have different lengths.');

  if (desirability.type !== DG.COLUMN_TYPE.BOOL)
    throw new Error('Failed to compute confusion matrix: desirability column must be boolean.');

  if (!prediction.isNumerical)
    throw new Error('Failed to compute confusion matrix: prediction column must be numerical.');

  let TP = 0;
  let TN = 0;
  let FP = 0;
  let FN = 0;

  // The boolean column's raw data packs values as bits of 32-bit words:
  // the label of row i is bit (i % 32) of word (i / 32)
  const desRaw = desirability.getRawData();
  const predRaw = prediction.getRawData();
  const size = prediction.length;

  for (let i = 0; i < size; ++i) {
    const isActualPositive = ((desRaw[i >>> 5] >>> (i & 31)) & 1) === 1;
    const isPredictedPositive = predRaw[i] >= threshold;

    if (isActualPositive) {
      if (isPredictedPositive)
        ++TP;
      else
        ++FN;
    } else if (isPredictedPositive)
      ++FP;
    else
      ++TN;
  }

  return {TP: TP, TN: TN, FP: FP, FN: FN};
} // getConfusionMatrix
230
+
231
/** Computes Area Under Curve (AUC) given TPR and FPR arrays, using the trapezoidal rule.
 * Points are taken in the given order. The absolute FPR step is used, so curves listed with
 * either increasing or decreasing FPR give the same area.
 * @param tpr - True Positive Rate values (any array-like of numbers, e.g. Float32Array)
 * @param fpr - False Positive Rate values, aligned with tpr
 * @return AUC value (0 when fewer than two points are given)
 * @throws Error if tpr and fpr have different lengths
 */
export function getAuc(tpr: ArrayLike<number>, fpr: ArrayLike<number>): number {
  if (tpr.length !== fpr.length)
    throw new Error('Failed to compute AUC: TPR and FPR arrays have different lengths.');

  let auc = 0.0;

  for (let i = 1; i < tpr.length; ++i) {
    const xDiff = Math.abs(fpr[i] - fpr[i - 1]);
    const yAvg = (tpr[i] + tpr[i - 1]) / 2.0;
    auc += xDiff * yAvg;
  }

  return auc;
} // getAuc
250
+
251
/** Converts numeric prediction column to boolean based on the given threshold:
 * a row becomes true when its score is greater than or equal to the threshold.
 * @param numericPrediction - numeric prediction column
 * @param threshold - threshold to convert prediction scores to binary labels
 * @param name - name for the resulting boolean column
 * @return Boolean prediction column
 * @throws Error if the prediction column is not numerical
 */
export function getBoolPredictionColumn(numericPrediction: DG.Column, threshold: number, name: string): DG.Column {
  if (!numericPrediction.isNumerical)
    throw new Error('Failed to create boolean prediction column: prediction column must be numerical.');

  const size = numericPrediction.length;
  const boolPredData = new Array<boolean>(size);
  const predRaw = numericPrediction.getRawData();

  for (let i = 0; i < size; ++i)
    boolPredData[i] = (predRaw[i] >= threshold);

  return DG.Column.fromList(DG.COLUMN_TYPE.BOOL, name, boolPredData);
} // getBoolPredictionColumn
270
+
271
/** Evaluates a pMPO model against the actual labels with ROC analysis.
 * TPR and FPR are computed for every threshold in ROC_TRESHOLDS. The reported threshold is the first
 * one that maximizes Youden's J statistic (TPR - FPR).
 * @param desirability - desirability column (boolean), the actual labels
 * @param prediction - prediction column (numeric), the pMPO scores
 * @return ModelEvaluationResult object with AUC, optimal threshold, TPR and FPR arrays
 */
export function getPmpoEvaluation(desirability: DG.Column, prediction: DG.Column): ModelEvaluationResult {
  // Ratio that falls back to 0 when the denominator is empty
  const safeRatio = (numerator: number, denominator: number): number =>
    (denominator > 0) ? numerator / denominator : 0;

  const tpr = new Float32Array(ROC_TRESHOLDS_COUNT);
  const fpr = new Float32Array(ROC_TRESHOLDS_COUNT);

  let bestThreshold = ROC_TRESHOLDS[0];
  let bestJ = -1;

  for (let k = 0; k < ROC_TRESHOLDS_COUNT; ++k) {
    const threshold = ROC_TRESHOLDS[k];
    const {TP, TN, FP, FN} = getConfusionMatrix(desirability, prediction, threshold);

    tpr[k] = safeRatio(TP, TP + FN);
    fpr[k] = safeRatio(FP, FP + TN);

    // J is taken from the rates as stored in the Float32 arrays
    const youdenJ = tpr[k] - fpr[k];

    // Strict comparison keeps the first threshold reaching the maximum
    if (youdenJ > bestJ) {
      bestJ = youdenJ;
      bestThreshold = threshold;
    }
  }

  return {auc: getAuc(tpr, fpr), threshold: bestThreshold, tpr: tpr, fpr: fpr};
} // getPmpoEvaluation
@@ -5,7 +5,7 @@ import * as ui from 'datagrok-api/ui';
5
5
  import * as DG from 'datagrok-api/dg';
6
6
  import {_package} from '../package-test';
7
7
 
8
- import {category, expect, test} from '@datagrok-libraries/utils/src/test';
8
+ import {category, expect, test} from '@datagrok-libraries/test/src/test';
9
9
 
10
10
  import {oneWayAnova, FactorizedData} from '../anova/anova-tools';
11
11
 
@@ -5,7 +5,7 @@ import * as ui from 'datagrok-api/ui';
5
5
  import * as DG from 'datagrok-api/dg';
6
6
  import {_package} from '../package-test';
7
7
 
8
- import {category, expect, test} from '@datagrok-libraries/utils/src/test';
8
+ import {category, expect, test} from '@datagrok-libraries/test/src/test';
9
9
 
10
10
  import {classificationDataset, accuracy} from './utils';
11
11
  import {SoftmaxClassifier} from '../softmax-classifier';
@@ -5,7 +5,7 @@ import {_package} from '../package-test';
5
5
 
6
6
  // tests for dimensionality reduction
7
7
 
8
- import {category, expect, test} from '@datagrok-libraries/utils/src/test';
8
+ import {category, expect, test} from '@datagrok-libraries/test/src/test';
9
9
  import {DimReductionMethods} from '@datagrok-libraries/ml/src/multi-column-dimensionality-reduction/types';
10
10
  import {KnownMetrics, NumberMetricsNames, StringMetricsNames} from '@datagrok-libraries/ml/src/typed-metrics';
11
11
  import {multiColReduceDimensionality}
@@ -5,7 +5,7 @@ import * as ui from 'datagrok-api/ui';
5
5
  import * as DG from 'datagrok-api/dg';
6
6
  import {_package} from '../package-test';
7
7
 
8
- import {category, expect, test} from '@datagrok-libraries/utils/src/test';
8
+ import {category, expect, test} from '@datagrok-libraries/test/src/test';
9
9
  import {computePCA} from '../eda-tools';
10
10
  import {getPlsAnalysis} from '../pls/pls-tools';
11
11
  import {PlsModel} from '../pls/pls-ml';
@@ -5,7 +5,7 @@ import * as ui from 'datagrok-api/ui';
5
5
  import * as DG from 'datagrok-api/dg';
6
6
  import {_package} from '../package-test';
7
7
 
8
- import {category, expect, test} from '@datagrok-libraries/utils/src/test';
8
+ import {category, expect, test} from '@datagrok-libraries/test/src/test';
9
9
 
10
10
  import {MetricInfo, DISTANCE_TYPE, impute} from '../missing-values-imputation/knn-imputer';
11
11
  import {getFeatureInputSettings} from '../missing-values-imputation/ui';
@@ -0,0 +1,253 @@
1
+ // Tests for Pareto Front Computations
2
+ // Performance and edge-case tests for the Pareto optimality algorithm
3
+
4
+ import * as grok from 'datagrok-api/grok';
5
+ import * as ui from 'datagrok-api/ui';
6
+ import * as DG from 'datagrok-api/dg';
7
+ import {_package} from '../package-test';
8
+
9
+ import {category, expect, test} from '@datagrok-libraries/test/src/test';
10
+
11
+ import {getParetoMask} from '../pareto-optimization/pareto-computations';
12
+ import {OPT_TYPE, NumericArray} from '../pareto-optimization/defs';
13
+
14
/** Per-test timeout (passed to the `test` options) */
const TIMEOUT = 5000;

// Test dataset sizes
const ROWS_COUNT = 1000000;
const COLS_COUNT = 2;

// Scale used in test titles: ROWS_COUNT is shown as a multiple of M with the matching suffix
const M = 1000000;
const suffix = (M >= 1e6) ? 'M' : 'K';
const DATASET_SIZE_LABEL = `${ROWS_COUNT / M}${suffix} points, ${COLS_COUNT}D`;
22
+
23
/** Generates synthetic numeric data for Pareto front testing.
 * Column d holds nPoints independent, roughly uniform values in [d * 10, d * 10 + 100].
 * The output is deterministic for a given seed.
 * @param nPoints - number of points (rows)
 * @param nDims - number of dimensions (columns)
 * @param seed - seed of the pseudo-random generator; reduced to an unsigned 32-bit integer
 * @return array of nDims Float32Array columns
 */
function generateSyntheticData(nPoints: number, nDims: number, seed: number = 42): NumericArray[] {
  const data: NumericArray[] = [];

  // Linear congruential generator for reproducibility. The state is kept in [0, 2^32), so
  // rng * 1664525 + 1013904223 stays below 2^53 and is computed exactly; normalizing the seed
  // keeps negative or huge seeds from producing negative or imprecise values
  let rng = seed >>> 0;
  const random = (): number => {
    rng = (rng * 1664525 + 1013904223) % 4294967296;
    return rng / 4294967296;
  };

  for (let d = 0; d < nDims; d++) {
    const column = new Float32Array(nPoints);

    // Dimensions are generated independently; the per-dimension offset only shifts the value range
    for (let i = 0; i < nPoints; i++)
      column[i] = random() * 100 + (d * 10);

    data.push(column);
  }

  return data;
}
45
+
46
/** Builds the optimization sense (MIN/MAX) for each of nDims dimensions.
 * 'all-min' and 'all-max' use a single sense everywhere; 'mixed' starts with MIN and alternates.
 */
function generateSense(nDims: number, pattern: 'all-min' | 'all-max' | 'mixed'): OPT_TYPE[] {
  const senseOf = (dim: number): OPT_TYPE => {
    switch (pattern) {
    case 'all-min':
      return OPT_TYPE.MIN;
    case 'all-max':
      return OPT_TYPE.MAX;
    default:
      // Mixed: even dimensions are minimized, odd ones maximized
      return (dim % 2 === 0) ? OPT_TYPE.MIN : OPT_TYPE.MAX;
    }
  };

  const result: OPT_TYPE[] = [];
  for (let dim = 0; dim < nDims; ++dim)
    result.push(senseOf(dim));

  return result;
}
63
+
64
/** Generates an evenly spaced set of row indices to be treated as missing values.
 * @param nPoints - total number of points
 * @param nullRatio - fraction of points to mark as missing; clamped to [0, 1]
 * @return floor(nPoints * ratio) distinct indices starting at 0, spaced floor(nPoints / count) apart
 */
function generateNullIndices(nPoints: number, nullRatio: number): Set<number> {
  const nullIndices = new Set<number>();
  const ratio = Math.min(Math.max(nullRatio, 0), 1);
  const nullCount = Math.floor(nPoints * ratio);

  // Also covers NaN inputs
  if (!(nullCount > 0))
    return nullIndices;

  // nullCount <= nPoints, so the step is at least 1: indices are distinct and below nPoints
  const step = Math.floor(nPoints / nullCount);
  for (let i = 0; i < nullCount; i++)
    nullIndices.add(i * step);

  return nullIndices;
}
77
+
78
/** Checks that a Pareto mask has one entry per point and marks at least one point as optimal.
 * Only warns (does not fail) when every point is marked optimal.
 * @param mask - computed Pareto mask
 * @param nPoints - expected number of points
 * @throws Error on a length mismatch or when no point is optimal
 */
function validateParetoMask(mask: boolean[], nPoints: number): void {
  if (mask.length !== nPoints)
    throw new Error(`Invalid mask length: expected ${nPoints}, got ${mask.length}`);

  let optimalCount = 0;
  for (const isOptimal of mask) {
    if (isOptimal)
      ++optimalCount;
  }

  if (optimalCount === 0)
    throw new Error('No optimal points found');

  if (optimalCount === nPoints)
    grok.shell.warning('All points are optimal - data may be degenerate');
}
93
+
94
/** Runs a Pareto mask computation, capturing (and displaying) any thrown error instead of propagating it
 * @param compute - computation producing the mask
 * @return the computed mask with a null error, or a null mask with the captured error
 */
function computeParetoMaskSafely(compute: () => boolean[]): {mask: boolean[] | null, error: Error | null} {
  try {
    return {mask: compute(), error: null};
  } catch (e) {
    const error = (e instanceof Error) ? e : new Error(String(e));
    grok.shell.error(error.message);
    return {mask: null, error: error};
  }
}

/** Computes and validates the Pareto mask of the synthetic ROWS_COUNT x COLS_COUNT dataset
 * @param pattern - optimization sense pattern
 * @param nullRatio - fraction of points marked as missing; no null indices are passed when omitted
 */
function checkParetoPerformance(pattern: 'all-min' | 'all-max' | 'mixed', nullRatio?: number): void {
  const {mask, error} = computeParetoMaskSafely(() => {
    const data = generateSyntheticData(ROWS_COUNT, COLS_COUNT);
    const sense = generateSense(COLS_COUNT, pattern);
    const result = (nullRatio === undefined) ?
      getParetoMask(data, sense, ROWS_COUNT) :
      getParetoMask(data, sense, ROWS_COUNT, generateNullIndices(ROWS_COUNT, nullRatio));
    validateParetoMask(result, ROWS_COUNT);
    return result;
  });

  // Check the captured error first: its message explains why the mask is missing
  expect(error === null, true, error?.message ?? '');
  expect(mask !== null, true, 'Failed to compute Pareto mask');
}

category('Pareto optimization', () => {
  test(`Performance: ${DATASET_SIZE_LABEL}`, async () => {
    checkParetoPerformance('mixed');
  }, {timeout: TIMEOUT});

  // Tests for different optimization patterns
  test(`Performance: ${DATASET_SIZE_LABEL}, all minimize`, async () => {
    checkParetoPerformance('all-min');
  }, {timeout: TIMEOUT});

  test(`Performance: ${DATASET_SIZE_LABEL}, all maximize`, async () => {
    checkParetoPerformance('all-max');
  }, {timeout: TIMEOUT});

  // Tests with missing values
  test(`Performance: ${DATASET_SIZE_LABEL} with 10% null indices`, async () => {
    checkParetoPerformance('mixed', 0.1);
  }, {timeout: TIMEOUT});

  test(`Performance: ${DATASET_SIZE_LABEL} with 25% null indices`, async () => {
    checkParetoPerformance('mixed', 0.25);
  }, {timeout: TIMEOUT});

  // Edge cases
  test('Edge case: Empty dataset', async () => {
    const {mask, error} = computeParetoMaskSafely(() => {
      const data: NumericArray[] = [new Float32Array(0), new Float32Array(0)];
      return getParetoMask(data, generateSense(COLS_COUNT, 'mixed'), 0);
    });

    expect(error === null, true, error?.message ?? '');
    expect(mask !== null, true, 'Failed to compute Pareto mask');
    expect(mask!.length, 0, 'Empty dataset should return empty mask');
  }, {timeout: TIMEOUT});

  test('Edge case: Single point', async () => {
    const {mask, error} = computeParetoMaskSafely(() => {
      const data: NumericArray[] = [new Float32Array([1.0]), new Float32Array([2.0])];
      return getParetoMask(data, generateSense(COLS_COUNT, 'mixed'), 1);
    });

    expect(error === null, true, error?.message ?? '');
    expect(mask !== null, true, 'Failed to compute Pareto mask');
    expect(mask!.length, 1, 'Single point dataset should return mask with one element');
    expect(mask![0], true, 'Single point should be optimal');
  }, {timeout: TIMEOUT});

  test('Edge case: All identical points', async () => {
    const nPoints = 100;
    const {mask, error} = computeParetoMaskSafely(() => {
      const data: NumericArray[] = [
        new Float32Array(nPoints).fill(5.0),
        new Float32Array(nPoints).fill(10.0),
      ];
      return getParetoMask(data, generateSense(COLS_COUNT, 'mixed'), nPoints);
    });

    expect(error === null, true, error?.message ?? '');
    expect(mask !== null, true, 'Failed to compute Pareto mask');
    expect(mask!.length, nPoints, 'Should return mask with correct length');
    const optimalCount = mask!.filter((x) => x).length;
    expect(optimalCount > 0, true, 'At least some identical points should be optimal');
  }, {timeout: TIMEOUT});
});
@@ -0,0 +1,157 @@
1
+ // Tests for Probabilistic MPO (pMPO)
2
+ // Reference scores are pre-computed and stored in the 'drugs-props-train-scores.csv' file.
3
+ // These scores are computed using the library: https://github.com/Merck/pmpo
4
+
5
+ import * as grok from 'datagrok-api/grok';
6
+ import * as ui from 'datagrok-api/ui';
7
+ import * as DG from 'datagrok-api/dg';
8
+ import {_package} from '../package-test';
9
+
10
+ import {category, expect, test} from '@datagrok-libraries/test/src/test';
11
+
12
+ import {Pmpo} from '../probabilistic-scoring/prob-scoring';
13
+ import {P_VAL_TRES_DEFAULT, Q_CUTOFF_DEFAULT, R2_DEFAULT, SCORES_PATH,
14
+ SOURCE_PATH} from '../probabilistic-scoring/pmpo-defs';
15
+ import {getSynteticPmpoData} from '../probabilistic-scoring/data-generator';
16
+
17
/** Per-test timeout (passed to the `test` options) */
const TIMEOUT = 10000;
/** Maximum allowed absolute deviation of computed pMPO scores from the reference ones */
const MAD_THRESH = 1E-6;

// Column names
const DESIRABILITY_COL_NAME = 'CNS';
const DESCRIPTOR_NAMES = [
  'TPSA', 'TPSA_S', 'HBA', 'HBD', 'MW', 'nAtoms',
  'cLogD_ACD_v15', 'mapKa', 'cLogP_Biobyte', 'mbpKa', 'cLogP_ACD_v15', 'ALogP98',
];
const SCORES_NAME = 'Score';
const DRUG = 'Drug';

// pMPO modes; each mode name is also the name of its reference-scores column
const SIGMOIDAL = 'Sigmoidal';
const GAUSSIAN = 'Gaussian';
const PMPO_MODES = [SIGMOIDAL, GAUSSIAN];

// Size of the synthetic dataset used by the performance test
const SAMPLES_K = 100;
const SAMPLES_COUNT = SAMPLES_K * 1000;
32
+
33
/** Computes the maximum absolute deviation between pMPO scores of the same drugs in two tables.
 * Drugs are matched by name; if a name occurs several times in the reference, its first occurrence is used.
 * @param sourceDrugCol - drug names of the source data
 * @param sourceScores - scores computed for the source data (numeric)
 * @param referenceDrugCol - drug names of the reference data
 * @param referenceScores - reference scores (numeric)
 * @return maximum absolute score difference over all source drugs (0 if there are none)
 * @throws Error if a source drug is missing in the reference data
 */
function getScoreMaxDeviation(sourceDrugCol: DG.Column, sourceScores: DG.Column,
  referenceDrugCol: DG.Column, referenceScores: DG.Column): number {
  const sourceDrugList = sourceDrugCol.toList();
  const referenceDrugList = referenceDrugCol.toList();

  const sourceScoresRaw = sourceScores.getRawData();
  const referenceScoresRaw = referenceScores.getRawData();

  // Index reference drugs once instead of a linear search per source drug;
  // keeping the first occurrence matches the lookup semantics of indexOf
  const referenceIndex = new Map<unknown, number>();
  referenceDrugList.forEach((name, idx) => {
    if (!referenceIndex.has(name))
      referenceIndex.set(name, idx);
  });

  let mad = 0;

  sourceDrugList.forEach((name, idx) => {
    const refIdx = referenceIndex.get(name);

    if (refIdx === undefined)
      throw new Error(`Failed to compare pMPO scores: the "${name}" drug is missing in the reference data.`);

    mad = Math.max(mad, Math.abs(sourceScoresRaw[idx] - referenceScoresRaw[refIdx]));
  });

  return mad;
} // getScoreMaxDeviation
55
+
56
category('Probabilistic MPO', () => {
  // Correctness tests: compare pMPO scores with reference scores
  PMPO_MODES.forEach((refScoreName) => {
    const useSigmoid = (refScoreName === SIGMOIDAL);

    test('Correctness: ' + refScoreName, async () => {
      let sourceDf: DG.DataFrame | null = null;
      let referenceDf: DG.DataFrame | null = null;
      let desirability: DG.Column | null = null;
      let descriptors: DG.Column[] = [];
      let sourceDrugCol: DG.Column | null = null;
      let referenceDrugCol: DG.Column | null = null;
      let referencePrediction: DG.Column | null = null;
      let mad: number | null = null;
      let error: Error | null = null;

      try {
        // Load data
        sourceDf = await grok.dapi.files.readCsv(SOURCE_PATH);
        referenceDf = await grok.dapi.files.readCsv(SCORES_PATH);

        // Extract training items
        desirability = sourceDf.col(DESIRABILITY_COL_NAME);
        descriptors = sourceDf.columns.byNames(DESCRIPTOR_NAMES);

        if (desirability == null)
          throw new Error('Inconsistent source data: no column ' + DESIRABILITY_COL_NAME);

        // Train pMPO model
        const trainRes = Pmpo.fit(
          sourceDf,
          DG.DataFrame.fromColumns(descriptors).columns,
          desirability,
          P_VAL_TRES_DEFAULT,
          R2_DEFAULT,
          Q_CUTOFF_DEFAULT,
        );

        // Apply pMPO
        const prediction = Pmpo.predict(sourceDf, trainRes.params, useSigmoid, SCORES_NAME);

        // Compare with reference scores
        sourceDrugCol = sourceDf.col(DRUG);
        referenceDrugCol = referenceDf.col(DRUG);
        referencePrediction = referenceDf.col(refScoreName);

        mad = getScoreMaxDeviation(sourceDrugCol!, prediction, referenceDrugCol!, referencePrediction!);
      } catch (e) {
        error = (e instanceof Error) ? e : new Error(String(e));
        grok.shell.error(error.message);
      }

      // Data consistency checks give the most specific diagnostics, so they go first
      expect(sourceDf !== null, true, 'Failed to load the source data: ' + SOURCE_PATH);
      expect(referenceDf !== null, true, 'Failed to load the scores data: ' + SCORES_PATH);
      expect(desirability !== null, true, 'Inconsistent source data: no column ' + DESIRABILITY_COL_NAME);
      expect(descriptors.length, DESCRIPTOR_NAMES.length, 'Inconsistent source data: not enough columns');
      expect(sourceDrugCol !== null, true, 'Inconsistent source data: no column ' + DRUG);
      expect(referenceDrugCol !== null, true, 'Inconsistent reference data: no column ' + DRUG);
      expect(referencePrediction !== null, true, 'Inconsistent reference data: no column ' + refScoreName);

      // Surface failures of training/prediction/comparison with their own message
      expect(error === null, true, error?.message ?? '');
      expect(mad !== null, true, 'Failed to compare pMPO scores with the reference data');
      expect(mad! < MAD_THRESH, true, `Max absolute deviation of pMPO scores exceeds the threshold (${MAD_THRESH})`);
    }, {timeout: TIMEOUT});
  });

  // Performance tests: measure time of pMPO training
  test('Performance: ' + SAMPLES_K + 'K drugs, ' + DESCRIPTOR_NAMES.length + ' descriptors', async () => {
    let sourceDf: DG.DataFrame | null = null;
    let desirability: DG.Column | null = null;
    let descriptors: DG.Column[] = [];
    let error: Error | null = null;

    try {
      // Generate synthetic data
      sourceDf = await getSynteticPmpoData(SAMPLES_COUNT);

      // Extract training items
      desirability = sourceDf.col(DESIRABILITY_COL_NAME);
      descriptors = sourceDf.columns.byNames(DESCRIPTOR_NAMES);

      if (desirability == null)
        throw new Error('Inconsistent source data: no column ' + DESIRABILITY_COL_NAME);

      // Train pMPO model
      const trainRes = Pmpo.fit(
        sourceDf,
        DG.DataFrame.fromColumns(descriptors).columns,
        desirability,
        P_VAL_TRES_DEFAULT,
        R2_DEFAULT,
        Q_CUTOFF_DEFAULT,
      );

      // Apply pMPO
      Pmpo.predict(sourceDf, trainRes.params, true, SCORES_NAME);
    } catch (e) {
      error = (e instanceof Error) ? e : new Error(String(e));
      grok.shell.error(error.message);
    }

    expect(sourceDf !== null, true, 'Failed to generate synthetic data');
    expect(desirability !== null, true, 'Inconsistent source data: no column ' + DESIRABILITY_COL_NAME);
    expect(descriptors.length, DESCRIPTOR_NAMES.length, 'Inconsistent source data: not enough columns');

    // Without this check, a failing fit/predict would only be displayed and the test would pass
    expect(error === null, true, error?.message ?? '');
  }, {timeout: TIMEOUT});
});