@elaraai/east-py-datascience 0.0.2-beta.8 → 0.0.2-beta.80
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- package/README.md +58 -1
- package/dist/src/alns/alns.d.ts +528 -0
- package/dist/src/alns/alns.d.ts.map +1 -0
- package/dist/src/alns/alns.js +238 -0
- package/dist/src/alns/alns.js.map +1 -0
- package/dist/src/google_or/google_or.d.ts +2422 -0
- package/dist/src/google_or/google_or.d.ts.map +1 -0
- package/dist/src/google_or/google_or.js +542 -0
- package/dist/src/google_or/google_or.js.map +1 -0
- package/dist/{gp → src/gp}/gp.d.ts +185 -136
- package/dist/src/gp/gp.d.ts.map +1 -0
- package/dist/{gp → src/gp}/gp.js +64 -12
- package/dist/src/gp/gp.js.map +1 -0
- package/dist/src/index.d.ts +34 -0
- package/dist/src/index.d.ts.map +1 -0
- package/dist/src/index.js +57 -0
- package/dist/src/index.js.map +1 -0
- package/dist/src/lightgbm/lightgbm.d.ts +575 -0
- package/dist/src/lightgbm/lightgbm.d.ts.map +1 -0
- package/dist/{lightgbm → src/lightgbm}/lightgbm.js +104 -18
- package/dist/src/lightgbm/lightgbm.js.map +1 -0
- package/dist/src/lightning/lightning.d.ts +1594 -0
- package/dist/src/lightning/lightning.d.ts.map +1 -0
- package/dist/src/lightning/lightning.js +468 -0
- package/dist/src/lightning/lightning.js.map +1 -0
- package/dist/{mads → src/mads}/mads.d.ts +109 -112
- package/dist/src/mads/mads.d.ts.map +1 -0
- package/dist/{mads → src/mads}/mads.js +6 -8
- package/dist/src/mads/mads.js.map +1 -0
- package/dist/src/mapie/mapie.d.ts +3680 -0
- package/dist/src/mapie/mapie.d.ts.map +1 -0
- package/dist/src/mapie/mapie.js +616 -0
- package/dist/src/mapie/mapie.js.map +1 -0
- package/dist/{ngboost → src/ngboost}/ngboost.d.ts +192 -142
- package/dist/src/ngboost/ngboost.d.ts.map +1 -0
- package/dist/{ngboost → src/ngboost}/ngboost.js +67 -14
- package/dist/src/ngboost/ngboost.js.map +1 -0
- package/dist/src/optimization/optimization.d.ts +420 -0
- package/dist/src/optimization/optimization.d.ts.map +1 -0
- package/dist/src/optimization/optimization.js +257 -0
- package/dist/src/optimization/optimization.js.map +1 -0
- package/dist/{optuna → src/optuna}/optuna.d.ts +374 -314
- package/dist/src/optuna/optuna.d.ts.map +1 -0
- package/dist/{optuna → src/optuna}/optuna.js +2 -0
- package/dist/src/optuna/optuna.js.map +1 -0
- package/dist/src/pymc/pymc.d.ts +2932 -0
- package/dist/src/pymc/pymc.d.ts.map +1 -0
- package/dist/src/pymc/pymc.js +688 -0
- package/dist/src/pymc/pymc.js.map +1 -0
- package/dist/src/scipy/scipy.d.ts +2205 -0
- package/dist/src/scipy/scipy.d.ts.map +1 -0
- package/dist/src/scipy/scipy.js +884 -0
- package/dist/src/scipy/scipy.js.map +1 -0
- package/dist/src/shap/shap.d.ts +2988 -0
- package/dist/src/shap/shap.d.ts.map +1 -0
- package/dist/src/shap/shap.js +500 -0
- package/dist/src/shap/shap.js.map +1 -0
- package/dist/{simanneal → src/simanneal}/simanneal.d.ts +257 -160
- package/dist/src/simanneal/simanneal.d.ts.map +1 -0
- package/dist/{simanneal → src/simanneal}/simanneal.js +105 -8
- package/dist/src/simanneal/simanneal.js.map +1 -0
- package/dist/src/simulation/simulation.d.ts +431 -0
- package/dist/src/simulation/simulation.d.ts.map +1 -0
- package/dist/src/simulation/simulation.js +306 -0
- package/dist/src/simulation/simulation.js.map +1 -0
- package/dist/src/sklearn/sklearn.d.ts +6362 -0
- package/dist/src/sklearn/sklearn.d.ts.map +1 -0
- package/dist/src/sklearn/sklearn.js +1508 -0
- package/dist/src/sklearn/sklearn.js.map +1 -0
- package/dist/src/torch/torch.d.ts +1205 -0
- package/dist/src/torch/torch.d.ts.map +1 -0
- package/dist/{torch → src/torch}/torch.js +109 -18
- package/dist/src/torch/torch.js.map +1 -0
- package/dist/src/types.d.ts +43 -0
- package/dist/src/types.d.ts.map +1 -0
- package/dist/src/types.js +44 -0
- package/dist/src/types.js.map +1 -0
- package/dist/src/xgboost/xgboost.d.ts +1424 -0
- package/dist/src/xgboost/xgboost.d.ts.map +1 -0
- package/dist/src/xgboost/xgboost.js +432 -0
- package/dist/src/xgboost/xgboost.js.map +1 -0
- package/package.json +12 -12
- package/dist/gp/gp.d.ts.map +0 -1
- package/dist/gp/gp.js.map +0 -1
- package/dist/index.d.ts +0 -27
- package/dist/index.d.ts.map +0 -1
- package/dist/index.js +0 -41
- package/dist/index.js.map +0 -1
- package/dist/lightgbm/lightgbm.d.ts +0 -494
- package/dist/lightgbm/lightgbm.d.ts.map +0 -1
- package/dist/lightgbm/lightgbm.js.map +0 -1
- package/dist/mads/mads.d.ts.map +0 -1
- package/dist/mads/mads.js.map +0 -1
- package/dist/ngboost/ngboost.d.ts.map +0 -1
- package/dist/ngboost/ngboost.js.map +0 -1
- package/dist/optuna/optuna.d.ts.map +0 -1
- package/dist/optuna/optuna.js.map +0 -1
- package/dist/scipy/scipy.d.ts +0 -1260
- package/dist/scipy/scipy.d.ts.map +0 -1
- package/dist/scipy/scipy.js +0 -413
- package/dist/scipy/scipy.js.map +0 -1
- package/dist/shap/shap.d.ts +0 -657
- package/dist/shap/shap.d.ts.map +0 -1
- package/dist/shap/shap.js +0 -241
- package/dist/shap/shap.js.map +0 -1
- package/dist/simanneal/simanneal.d.ts.map +0 -1
- package/dist/simanneal/simanneal.js.map +0 -1
- package/dist/sklearn/sklearn.d.ts +0 -2691
- package/dist/sklearn/sklearn.d.ts.map +0 -1
- package/dist/sklearn/sklearn.js +0 -524
- package/dist/sklearn/sklearn.js.map +0 -1
- package/dist/torch/torch.d.ts +0 -1081
- package/dist/torch/torch.d.ts.map +0 -1
- package/dist/torch/torch.js.map +0 -1
- package/dist/tsconfig.tsbuildinfo +0 -1
- package/dist/types.d.ts +0 -80
- package/dist/types.d.ts.map +0 -1
- package/dist/types.js +0 -81
- package/dist/types.js.map +0 -1
- package/dist/xgboost/xgboost.d.ts +0 -504
- package/dist/xgboost/xgboost.d.ts.map +0 -1
- package/dist/xgboost/xgboost.js +0 -177
- package/dist/xgboost/xgboost.js.map +0 -1
|
@@ -0,0 +1,1508 @@
|
|
|
1
|
+
/**
|
|
2
|
+
* Copyright (c) 2025 Elara AI Pty Ltd
|
|
3
|
+
* Dual-licensed under AGPL-3.0 and commercial license. See LICENSE for details.
|
|
4
|
+
*/
|
|
5
|
+
/**
|
|
6
|
+
* Scikit-learn platform functions for East.
|
|
7
|
+
*
|
|
8
|
+
* Provides core machine learning utilities: preprocessing, model selection, and metrics.
|
|
9
|
+
* Uses ONNX for model serialization to enable portable inference.
|
|
10
|
+
*
|
|
11
|
+
* @packageDocumentation
|
|
12
|
+
*/
|
|
13
|
+
import { East, StructType, VariantType, OptionType, IntegerType, BooleanType, FloatType, BlobType, ArrayType, StringType, NullType, } from "@elaraai/east";
|
|
14
|
+
import { VectorType, MatrixType } from "../types.js";
|
|
15
|
+
import { XGBoostConfigType } from "../xgboost/xgboost.js";
|
|
16
|
+
import { LightGBMConfigType } from "../lightgbm/lightgbm.js";
|
|
17
|
+
import { NGBoostConfigType } from "../ngboost/ngboost.js";
|
|
18
|
+
import { GPConfigType } from "../gp/gp.js";
|
|
19
|
+
// Re-export shared types for convenience
|
|
20
|
+
export { VectorType, MatrixType } from "../types.js";
|
|
21
|
+
// ============================================================================
|
|
22
|
+
// Class Weight Types
|
|
23
|
+
// ============================================================================
|
|
24
|
+
/**
 * Strategy used to derive per-class weights.
 */
export const ClassWeightModeType = VariantType({
    /** Weight each class inversely to how frequently it occurs */
    balanced: NullType,
});
|
|
31
|
+
// ============================================================================
|
|
32
|
+
// Confusion Matrix Types
|
|
33
|
+
// ============================================================================
|
|
34
|
+
/**
 * Output of a confusion-matrix computation.
 */
export const ConfusionMatrixResultType = StructType({
    /** The confusion matrix, sized n_classes x n_classes */
    matrix: MatrixType(FloatType),
    /** The class labels, listed in order */
    classes: VectorType(IntegerType),
});
|
|
43
|
+
// Re-export config types used in RegressorChain
|
|
44
|
+
export { XGBoostConfigType } from "../xgboost/xgboost.js";
|
|
45
|
+
export { LightGBMConfigType } from "../lightgbm/lightgbm.js";
|
|
46
|
+
export { NGBoostConfigType } from "../ngboost/ngboost.js";
|
|
47
|
+
export { GPConfigType } from "../gp/gp.js";
|
|
48
|
+
// ============================================================================
|
|
49
|
+
// Config Types
|
|
50
|
+
// ============================================================================
|
|
51
|
+
/**
 * Options controlling how a dataset is partitioned.
 *
 * Typical `split_sizes` values:
 * - [0.8, 0.2]              -> train/test (2-way)
 * - [0.7, 0.15, 0.15]       -> train/val/test (3-way)
 * - [0.6, 0.1, 0.15, 0.15]  -> train/val/calib/test (4-way)
 */
export const SplitConfigType = StructType({
    /** Proportion assigned to each split; the values must sum to 1.0. */
    split_sizes: ArrayType(FloatType),
    /** Seed for the random number generator, for reproducible splits */
    random_state: OptionType(IntegerType),
    /** Shuffle the data before splitting (default true) */
    shuffle: OptionType(BooleanType),
    /**
     * Columns to stratify on, so that label proportions are kept across splits.
     * The columns are combined into compound strata; each inner array holds one
     * column of labels.
     * Stratifying does NOT guarantee overlap - that is what `overlap` is for.
     */
    stratify: OptionType(MatrixType(IntegerType)),
    /**
     * Columns whose values must be represented in every split.
     * Columns are checked independently of one another; a value that does not
     * appear in all splits is rejected.
     * Each inner array holds one column of labels (same length as X).
     */
    overlap: OptionType(MatrixType(IntegerType)),
    /**
     * Overlap columns in which a sample may carry MULTIPLE values (a set).
     * Layout: array of columns -> array of samples -> array of values.
     * Every unique value (across all samples) is required to appear in all splits.
     * Intended for samples that can belong to several categories over time.
     */
    multi_overlap: OptionType(ArrayType(ArrayType(VectorType(IntegerType)))),
    /**
     * Smallest number of samples an overlap value needs; values with fewer
     * samples are rejected (default = n_splits).
     * Guarantees there are enough samples to spread over all splits.
     */
    min_overlap: OptionType(IntegerType),
});
|
|
91
|
+
// ============================================================================
|
|
92
|
+
// Result Types
|
|
93
|
+
// ============================================================================
|
|
94
|
+
/**
 * Output of a data split.
 */
export const SplitResultType = StructType({
    /** One feature matrix per split, in the same order as split_sizes */
    X_splits: ArrayType(MatrixType(FloatType)),
    /** One target matrix per split, in the same order as split_sizes */
    Y_splits: ArrayType(MatrixType(FloatType)),
    /** Row indices rejected because of rare stratify classes or missing overlap values */
    rejected_indices: ArrayType(IntegerType),
});
|
|
105
|
+
/**
 * Options for filtering on categorical overlap.
 */
export const OverlapConfigType = StructType({
    /** Indices of the feature-matrix columns that are categorical */
    cat_indices: VectorType(IntegerType),
});
|
|
112
|
+
/**
 * Output of overlap filtering.
 */
export const OverlapResultType = StructType({
    /** One filtered feature matrix per target (rows with unseen categories removed) */
    X_filtered: ArrayType(MatrixType(FloatType)),
    /** One filtered target matrix per target, filtered in sync with X */
    Y_filtered: ArrayType(MatrixType(FloatType)),
    /** How many rows were rejected, per target */
    rejected_counts: VectorType(IntegerType),
    /** For each categorical column, the sorted list of values known from the reference */
    known_categories: ArrayType(VectorType(IntegerType)),
});
|
|
125
|
+
// ============================================================================
|
|
126
|
+
// Flexible Metrics Types
|
|
127
|
+
// ============================================================================
|
|
128
|
+
/**
 * Regression metrics that can be requested (backed by sklearn.metrics).
 *
 * Cases with a NullType payload take no parameter; cases with a FloatType
 * payload carry the metric's parameter, as described on each case.
 */
export const RegressionMetricType = VariantType({
    /** Mean Squared Error (sklearn.metrics.mean_squared_error) */
    mse: NullType,
    /** Root Mean Squared Error, i.e. sqrt(MSE) */
    rmse: NullType,
    /** Mean Absolute Error (sklearn.metrics.mean_absolute_error) */
    mae: NullType,
    /** R², the coefficient of determination (sklearn.metrics.r2_score) */
    r2: NullType,
    /** Mean Absolute Percentage Error (sklearn.metrics.mean_absolute_percentage_error) */
    mape: NullType,
    /** Explained Variance Score (sklearn.metrics.explained_variance_score) */
    explained_variance: NullType,
    /** Max Error (sklearn.metrics.max_error) */
    max_error: NullType,
    /** Median Absolute Error (sklearn.metrics.median_absolute_error) */
    median_ae: NullType,
    /** Mean Error (bias): mean(pred - true); should be ~0 for unbiased predictions */
    mean_error: NullType,
    /** Pinball Loss, a proper scoring rule for quantile regression (requires alpha parameter) */
    pinball_loss: FloatType,
    /** Huber Loss, robust to outliers (requires delta parameter, default 1.0) */
    huber: FloatType,
    /** Mean Tweedie Deviance, for skewed distributions (requires power parameter) */
    mean_tweedie_deviance: FloatType,
});
|
|
157
|
+
/**
 * One computed metric, reported as a scalar.
 */
export const MetricResultType = StructType({
    /** The metric that was computed */
    metric: RegressionMetricType,
    /** The scalar value of that metric */
    value: FloatType,
});
|
|
166
|
+
/**
 * A collection of computed metrics (array of MetricResultType).
 */
export const MetricsResultType = ArrayType(MetricResultType);
|
|
170
|
+
/**
 * How metrics are combined when there are several targets.
 */
export const MetricAggregationType = VariantType({
    /** Report the metric separately for every target (default) */
    per_target: NullType,
    /** Average over all targets with uniform weights */
    uniform_average: NullType,
});
|
|
179
|
+
/**
 * Options for computing metrics over multiple targets.
 */
export const MultiMetricsConfigType = StructType({
    /** Aggregation applied across targets (default: per_target) */
    aggregation: OptionType(MetricAggregationType),
});
|
|
186
|
+
/**
 * One computed metric for a multi-target problem.
 */
export const MultiMetricResultType = StructType({
    /** The metric that was computed */
    metric: RegressionMetricType,
    /** The metric value, either aggregated or broken down by target */
    value: VariantType({
        /** A single aggregated scalar */
        scalar: FloatType,
        /** One value per target: [target_0, target_1, ...] */
        per_target: VectorType(FloatType),
    }),
});
|
|
200
|
+
/**
 * A collection of computed multi-target metrics (array of MultiMetricResultType).
 */
export const MultiMetricsResultType = ArrayType(MultiMetricResultType);
|
|
204
|
+
/**
 * Weighting scheme for the Cohen's Kappa score.
 */
export const CohenKappaWeightsType = VariantType({
    /** Unweighted (default) */
    none: NullType,
    /** Linear weighting: disagreements are penalized linearly */
    linear: NullType,
    /** Quadratic weighting: disagreements are penalized quadratically */
    quadratic: NullType,
});
|
|
215
|
+
/**
 * Classification metrics that can be requested (backed by sklearn.metrics).
 */
export const ClassificationMetricType = VariantType({
    /** Accuracy (sklearn.metrics.accuracy_score) */
    accuracy: NullType,
    /** Balanced Accuracy (sklearn.metrics.balanced_accuracy_score) */
    balanced_accuracy: NullType,
    /** Precision (sklearn.metrics.precision_score) */
    precision: NullType,
    /** Recall (sklearn.metrics.recall_score) */
    recall: NullType,
    /** F1 Score (sklearn.metrics.f1_score) */
    f1: NullType,
    /** Matthews Correlation Coefficient (sklearn.metrics.matthews_corrcoef) */
    matthews_corrcoef: NullType,
    /** Cohen's Kappa (sklearn.metrics.cohen_kappa_score); the payload selects the optional weights */
    cohen_kappa: CohenKappaWeightsType,
    /** Jaccard Score (sklearn.metrics.jaccard_score) */
    jaccard: NullType,
});
|
|
236
|
+
/**
 * How multi-class classification metrics are averaged.
 */
export const ClassificationAverageType = VariantType({
    /** Compute the metric per label and return the unweighted mean */
    macro: NullType,
    /** Compute the metric globally from the total TP, FP and FN counts */
    micro: NullType,
    /** Compute the metric per label and return the mean weighted by support */
    weighted: NullType,
    /** Binary classification only */
    binary: NullType,
});
|
|
249
|
+
/**
 * Strategy for computing ROC AUC with more than two classes.
 */
export const RocAucMultiClassType = VariantType({
    /** One-vs-rest (OvR): AUC of each class against all the others */
    ovr: NullType,
    /** One-vs-one (OvO): pairwise AUC values, averaged */
    ovo: NullType,
});
|
|
258
|
+
/**
 * Options for the ROC AUC score.
 */
export const RocAucConfigType = StructType({
    /** Strategy for the multi-class case (default: ovr) */
    multi_class: OptionType(RocAucMultiClassType),
    /**
     * Averaging for the multi-class case: 'macro' or 'weighted' (default: macro).
     * NOTE(review): the field is typed as the full ClassificationAverageType, which
     * also has `micro` and `binary` cases - confirm how the platform handles those.
     */
    average: OptionType(ClassificationAverageType),
});
|
|
267
|
+
/**
 * Options for computing classification metrics.
 */
export const ClassificationMetricsConfigType = StructType({
    /** Averaging applied in the multi-class case (default: macro) */
    average: OptionType(ClassificationAverageType),
});
|
|
274
|
+
/**
 * One computed classification metric.
 */
export const ClassificationMetricResultType = StructType({
    /** The metric that was computed */
    metric: ClassificationMetricType,
    /** The scalar value of that metric */
    value: FloatType,
});
|
|
283
|
+
/**
 * A collection of computed classification metrics
 * (array of ClassificationMetricResultType).
 */
export const ClassificationMetricResultsType = ArrayType(ClassificationMetricResultType);
|
|
287
|
+
/**
 * Options for computing classification metrics over multiple targets.
 */
export const MultiClassificationConfigType = StructType({
    /** Averaging applied in the multi-class case (default: macro) */
    average: OptionType(ClassificationAverageType),
    /** Aggregation applied across targets (default: per_target) */
    aggregation: OptionType(MetricAggregationType),
});
|
|
296
|
+
/**
 * One computed classification metric for a multi-target problem.
 */
export const MultiClassificationMetricResultType = StructType({
    /** The metric that was computed */
    metric: ClassificationMetricType,
    /** The metric value, either aggregated or broken down by target */
    value: VariantType({
        /** A single aggregated scalar */
        scalar: FloatType,
        /** One value per target */
        per_target: VectorType(FloatType),
    }),
});
|
|
310
|
+
/**
 * A collection of computed multi-target classification metrics
 * (array of MultiClassificationMetricResultType).
 */
export const MultiClassificationMetricResultsType = ArrayType(MultiClassificationMetricResultType);
|
|
314
|
+
// ============================================================================
|
|
315
|
+
// GMM Types
|
|
316
|
+
// ============================================================================
|
|
317
|
+
/**
 * Covariance structure used by Gaussian Mixture Models.
 */
export const GMMCovarianceType = VariantType({
    /** Every component has its own general covariance matrix */
    full: NullType,
    /** A single general covariance matrix is shared by all components */
    tied: NullType,
    /** Every component has its own diagonal covariance matrix */
    diag: NullType,
    /** Every component has its own single variance */
    spherical: NullType,
});
|
|
330
|
+
/**
 * Options for fitting a Gaussian Mixture Model.
 */
export const GMMConfigType = StructType({
    /** How many mixture components to fit (default 1) */
    n_components: OptionType(IntegerType),
    /** Covariance structure (default full) */
    covariance_type: OptionType(GMMCovarianceType),
    /** Upper bound on the number of EM iterations (default 100) */
    max_iter: OptionType(IntegerType),
    /** How many initializations to perform (default 1) */
    n_init: OptionType(IntegerType),
    /** Tolerance used to decide convergence (default 1e-3) */
    tol: OptionType(FloatType),
    /** Regularization added to the diagonal of the covariance (default 1e-6) */
    reg_covar: OptionType(FloatType),
    /** Seed for the random number generator, for reproducible fits */
    random_state: OptionType(IntegerType),
});
|
|
349
|
+
// ============================================================================
|
|
350
|
+
// Model Blob Types
|
|
351
|
+
// ============================================================================
|
|
352
|
+
/**
 * Serialized form of a fitted sklearn model.
 *
 * There is one variant case per model kind. Each case bundles the serialized
 * bytes (ONNX for the scalers, cloudpickle for the other cases, as noted on
 * the fields) with a small amount of metadata.
 */
export const SklearnModelBlobType = VariantType({
    /** A fitted StandardScaler */
    standard_scaler: StructType({
        /** The model as ONNX bytes */
        onnx: BlobType,
        /** Count of input features */
        n_features: IntegerType,
    }),
    /** A fitted MinMaxScaler */
    min_max_scaler: StructType({
        /** The model as ONNX bytes */
        onnx: BlobType,
        /** Count of input features */
        n_features: IntegerType,
    }),
    /** A fitted RobustScaler */
    robust_scaler: StructType({
        /** The model as ONNX bytes */
        onnx: BlobType,
        /** Count of input features */
        n_features: IntegerType,
    }),
    /** A fitted LabelEncoder */
    label_encoder: StructType({
        /** The encoder, serialized with cloudpickle */
        data: BlobType,
        /** Count of unique classes */
        n_classes: IntegerType,
    }),
    /** A fitted OrdinalEncoder */
    ordinal_encoder: StructType({
        /** The encoder, serialized with cloudpickle */
        data: BlobType,
        /** Count of features */
        n_features: IntegerType,
    }),
    /** A fitted RegressorChain */
    regressor_chain: StructType({
        /** The chain, serialized with cloudpickle */
        data: BlobType,
        /** Count of input features */
        n_features: IntegerType,
        /** Count of target outputs */
        n_targets: IntegerType,
        /** Type name of the base estimator */
        base_estimator_type: StringType,
    }),
    /** A fitted Gaussian Mixture Model */
    gaussian_mixture: StructType({
        /** The GMM, serialized with cloudpickle */
        data: BlobType,
        /** Count of input features */
        n_features: IntegerType,
        /** Count of mixture components */
        n_components: IntegerType,
    }),
});
|
|
414
|
+
// ============================================================================
|
|
415
|
+
// RegressorChain Types
|
|
416
|
+
// ============================================================================
|
|
417
|
+
/**
 * Choice of base estimator for a RegressorChain.
 * The selected case names the estimator AND carries that estimator's configuration.
 */
export const RegressorChainBaseConfigType = VariantType({
    /** Use an XGBoost regressor */
    xgboost: XGBoostConfigType,
    /** Use a LightGBM regressor */
    lightgbm: LightGBMConfigType,
    /** Use an NGBoost regressor */
    ngboost: NGBoostConfigType,
    /** Use a Gaussian Process regressor */
    gp: GPConfigType,
});
|
|
431
|
+
/**
 * Options for a RegressorChain.
 */
export const RegressorChainConfigType = StructType({
    /** The base estimator together with its configuration */
    base_estimator: RegressorChainBaseConfigType,
    /** Order of the chain, as target indices. None = natural order [0,1,2,...] */
    order: OptionType(ArrayType(IntegerType)),
    /** Seed for the random number generator, for reproducibility */
    random_state: OptionType(IntegerType),
});
|
|
442
|
+
// ============================================================================
|
|
443
|
+
// Platform Functions
|
|
444
|
+
// ============================================================================
|
|
445
|
+
/**
 * Partition feature and target matrices into N subsets
 * (train/test, train/val/test, and so on).
 *
 * @param X - Feature matrix
 * @param Y - Target matrix
 * @param config - SplitConfigType value (split_sizes, stratify, overlap, ...)
 * @returns SplitResultType value carrying the X_splits and Y_splits arrays
 *          together with rejected_indices
 *
 * @example
 * ```ts
 * // Train/test (2-way)
 * const result = Sklearn.split(X, Y, { split_sizes: [0.8, 0.2], ... });
 * const [X_train, X_test] = [result.X_splits.get(0n), result.X_splits.get(1n)];
 *
 * // Train/val/test (3-way)
 * const result = Sklearn.split(X, Y, { split_sizes: [0.7, 0.15, 0.15], ... });
 *
 * // Stratifying on several columns
 * const result = Sklearn.split(X, Y, {
 *     split_sizes: [0.7, 0.15, 0.15],
 *     stratify: variant('some', [origin_labels, category_labels]),
 *     overlap: variant('some', [class_labels]),
 * });
 * ```
 */
export const sklearn_split = East.platform(
    "sklearn_split",
    [MatrixType(FloatType), MatrixType(FloatType), SplitConfigType],
    SplitResultType,
);
|
|
471
|
+
/**
 * Keep only those target rows whose categorical values are present in a reference matrix.
 *
 * Takes a reference feature matrix (e.g. training data) plus one or more target
 * matrices (e.g. validation, calibration). A row is dropped from a target when
 * any of its categorical columns holds a value not seen in the reference.
 *
 * @param X_reference - Reference feature matrix (defines the known categories)
 * @param X_targets - Target feature matrices to filter
 * @param Y_targets - Target label matrices, filtered in sync with X_targets
 * @param config - OverlapConfigType value with cat_indices
 * @returns OverlapResultType value with X_filtered, Y_filtered, rejected_counts, known_categories
 *
 * @example
 * ```ts
 * // Once per-head filtering is done, make sure val/calib only use categories seen in train
 * const result = Sklearn.overlap(X_train, [X_val, X_calib], [Y_val, Y_calib], { cat_indices: cat_features });
 * const X_val_clean = result.X_filtered.get(0n);
 * const X_calib_clean = result.X_filtered.get(1n);
 * ```
 */
export const sklearn_overlap = East.platform(
    "sklearn_overlap",
    [
        MatrixType(FloatType),
        ArrayType(MatrixType(FloatType)),
        ArrayType(MatrixType(FloatType)),
        OverlapConfigType,
    ],
    OverlapResultType,
);
|
|
493
|
+
/**
 * Fit a StandardScaler on training data.
 *
 * The scaler standardizes features: it removes the mean and scales to unit variance.
 *
 * @param X - Feature matrix to fit on
 * @returns Model blob holding the fitted scaler
 */
export const sklearn_standard_scaler_fit = East.platform(
    "sklearn_standard_scaler_fit",
    [MatrixType(FloatType)],
    SklearnModelBlobType,
);
|
|
502
|
+
/**
 * Apply a fitted StandardScaler to data.
 *
 * @param model - Model blob of the fitted scaler
 * @param X - Feature matrix to be transformed
 * @returns The transformed feature matrix
 */
export const sklearn_standard_scaler_transform = East.platform(
    "sklearn_standard_scaler_transform",
    [SklearnModelBlobType, MatrixType(FloatType)],
    MatrixType(FloatType),
);
|
|
510
|
+
/**
 * Fit a MinMaxScaler on training data.
 *
 * The scaler maps features onto a given range (default [0, 1]).
 *
 * @param X - Feature matrix to fit on
 * @returns Model blob holding the fitted scaler
 */
export const sklearn_min_max_scaler_fit = East.platform(
    "sklearn_min_max_scaler_fit",
    [MatrixType(FloatType)],
    SklearnModelBlobType,
);
|
|
519
|
+
/**
 * Apply a fitted MinMaxScaler to data.
 *
 * @param model - Model blob of the fitted scaler
 * @param X - Feature matrix to be transformed
 * @returns The transformed feature matrix
 */
export const sklearn_min_max_scaler_transform = East.platform(
    "sklearn_min_max_scaler_transform",
    [SklearnModelBlobType, MatrixType(FloatType)],
    MatrixType(FloatType),
);
|
|
527
|
+
/**
 * Fit a RobustScaler on training data.
 *
 * The scaler relies on outlier-robust statistics: data is centered on the
 * median and scaled by the interquartile range (IQR).
 *
 * @param X - Feature matrix to fit on
 * @returns Model blob holding the fitted scaler
 */
export const sklearn_robust_scaler_fit = East.platform(
    "sklearn_robust_scaler_fit",
    [MatrixType(FloatType)],
    SklearnModelBlobType,
);
|
|
537
|
+
/**
 * Apply a fitted RobustScaler to data.
 *
 * @param model - Model blob of the fitted scaler
 * @param X - Feature matrix to be transformed
 * @returns The transformed feature matrix
 */
export const sklearn_robust_scaler_transform = East.platform(
    "sklearn_robust_scaler_transform",
    [SklearnModelBlobType, MatrixType(FloatType)],
    MatrixType(FloatType),
);
|
|
545
|
+
/**
 * Platform function that fits a LabelEncoder on target labels.
 *
 * The encoder maps labels onto the values 0 .. n_classes-1.
 *
 * @param y - Target labels as a 1D integer vector
 * @returns Model blob holding the fitted encoder
 */
export const sklearn_label_encoder_fit = East.platform(
    "sklearn_label_encoder_fit",
    [VectorType(IntegerType)], // (y)
    SklearnModelBlobType
);
|
|
554
|
+
/**
 * Platform function that encodes labels with a fitted LabelEncoder.
 *
 * @param model - Model blob of the fitted encoder
 * @param y - Labels to encode
 * @returns Encoded labels in the range 0 .. n_classes-1
 */
export const sklearn_label_encoder_transform = East.platform(
    "sklearn_label_encoder_transform",
    [SklearnModelBlobType, VectorType(IntegerType)], // (model, y)
    VectorType(IntegerType)
);
|
|
562
|
+
/**
 * Platform function that maps encoded labels back to their original values
 * using a fitted LabelEncoder.
 *
 * @param model - Model blob of the fitted encoder
 * @param y - Encoded labels to decode
 * @returns The original label values
 */
export const sklearn_label_encoder_inverse_transform = East.platform(
    "sklearn_label_encoder_inverse_transform",
    [SklearnModelBlobType, VectorType(IntegerType)], // (model, y)
    VectorType(IntegerType)
);
|
|
570
|
+
/**
 * Platform function that fits an OrdinalEncoder on categorical features.
 *
 * The encoder represents categorical features as ordinal integers.
 *
 * @param X - Feature matrix containing the categorical values (declared as a float matrix)
 * @returns Model blob holding the fitted encoder
 */
export const sklearn_ordinal_encoder_fit = East.platform(
    "sklearn_ordinal_encoder_fit",
    [MatrixType(FloatType)], // (X)
    SklearnModelBlobType
);
|
|
579
|
+
/**
 * Platform function that encodes features with a fitted OrdinalEncoder.
 *
 * @param model - Model blob of the fitted encoder
 * @param X - Feature matrix to be transformed
 * @returns The encoded feature matrix
 */
export const sklearn_ordinal_encoder_transform = East.platform(
    "sklearn_ordinal_encoder_transform",
    [SklearnModelBlobType, MatrixType(FloatType)], // (model, X)
    MatrixType(FloatType)
);
|
|
587
|
+
/**
 * Platform function that derives class weights for balanced training.
 *
 * Weights are inversely proportional to class frequencies, which helps when
 * a classification task has imbalanced classes.
 *
 * @param mode - Weight computation mode (balanced)
 * @param y - Class labels as a 1D integer vector
 * @returns One weight per class, ordered by class index
 */
export const sklearn_compute_class_weight = East.platform(
    "sklearn_compute_class_weight",
    [ClassWeightModeType, VectorType(IntegerType)], // (mode, y)
    VectorType(FloatType)
);
|
|
598
|
+
/**
 * Platform function that builds the confusion matrix of a classification.
 *
 * Entry [i,j] of the matrix counts the samples whose true label is i and
 * whose predicted label is j.
 *
 * @param y_true - True class labels as a 1D integer vector
 * @param y_pred - Predicted class labels as a 1D integer vector
 * @returns Confusion matrix result carrying the matrix and the class labels
 */
export const sklearn_confusion_matrix = East.platform(
    "sklearn_confusion_matrix",
    [VectorType(IntegerType), VectorType(IntegerType)], // (y_true, y_pred)
    ConfusionMatrixResultType
);
|
|
609
|
+
/**
 * Platform function that computes the ROC AUC score of a classification.
 *
 * Binary classification: supply the probabilities of the positive class.
 * Multi-class: supply the full probability matrix (n_samples x n_classes).
 * In both cases the argument is declared as a float matrix.
 *
 * @param y_true - True class labels as a 1D integer vector
 * @param y_proba - Predicted probabilities (matrix: n_samples x n_classes)
 * @param config - Options controlling multi-class handling
 * @returns The ROC AUC score
 */
export const sklearn_roc_auc_score = East.platform(
    "sklearn_roc_auc_score",
    [VectorType(IntegerType), MatrixType(FloatType), RocAucConfigType], // (y_true, y_proba, config)
    FloatType
);
|
|
621
|
+
/**
 * Platform function that computes the log loss (cross-entropy loss) of a
 * classification.
 *
 * @param y_true - True class labels as a 1D integer vector
 * @param y_proba - Predicted probabilities (matrix: n_samples x n_classes)
 * @returns The log loss value
 */
export const sklearn_log_loss = East.platform(
    "sklearn_log_loss",
    [VectorType(IntegerType), MatrixType(FloatType)], // (y_true, y_proba)
    FloatType
);
|
|
629
|
+
/**
 * Platform function that evaluates clustering quality via the silhouette score.
 *
 * The score compares how close each sample is to its own cluster versus the
 * other clusters. It lies between -1 and 1; higher means better-defined
 * clusters.
 *
 * @param X - Feature matrix (n_samples x n_features)
 * @param labels - Cluster label of every sample (1D integer vector)
 * @returns Silhouette score (float, -1 to 1)
 */
export const sklearn_silhouette_score = East.platform(
    "sklearn_silhouette_score",
    [MatrixType(FloatType), VectorType(IntegerType)], // (X, labels)
    FloatType
);
|
|
641
|
+
/**
 * Platform function that trains a RegressorChain for multi-target regression.
 *
 * Every model in the chain receives the preceding targets as extra features,
 * which lets the chain capture dependencies between targets.
 *
 * @param X - Feature matrix
 * @param Y - Target matrix (rows=samples, cols=targets)
 * @param config - Chain configuration
 * @returns Model blob holding the fitted chain
 */
export const sklearn_regressor_chain_train = East.platform(
    "sklearn_regressor_chain_train",
    [MatrixType(FloatType), MatrixType(FloatType), RegressorChainConfigType], // (X, Y, config)
    SklearnModelBlobType
);
|
|
653
|
+
/**
 * Platform function that predicts targets with a fitted RegressorChain.
 *
 * @param model - Model blob of the fitted chain
 * @param X - Feature matrix to predict for
 * @returns The predicted target matrix
 */
export const sklearn_regressor_chain_predict = East.platform(
    "sklearn_regressor_chain_predict",
    [SklearnModelBlobType, MatrixType(FloatType)], // (model, X)
    MatrixType(FloatType)
);
|
|
661
|
+
// ============================================================================
|
|
662
|
+
// GMM Platform Functions
|
|
663
|
+
// ============================================================================
|
|
664
|
+
/**
 * Platform function that fits a Gaussian Mixture Model (GMM) to data.
 *
 * @param X - Feature matrix (n_samples x n_features)
 * @param config - GMM configuration
 * @returns Model blob holding the fitted GMM
 */
export const sklearn_gmm_fit = East.platform(
    "sklearn_gmm_fit",
    [MatrixType(FloatType), GMMConfigType], // (X, config)
    SklearnModelBlobType
);
|
|
672
|
+
/**
 * Platform function that assigns cluster labels using a fitted GMM.
 *
 * @param model - Model blob of the fitted GMM
 * @param X - Feature matrix to predict for
 * @returns Predicted cluster labels (0 to n_components-1)
 */
export const sklearn_gmm_predict = East.platform(
    "sklearn_gmm_predict",
    [SklearnModelBlobType, MatrixType(FloatType)], // (model, X)
    VectorType(IntegerType)
);
|
|
680
|
+
/**
 * Platform function that returns the posterior probability of every GMM
 * component for each sample.
 *
 * @param model - Model blob of the fitted GMM
 * @param X - Feature matrix
 * @returns Probability matrix (n_samples x n_components)
 */
export const sklearn_gmm_predict_proba = East.platform(
    "sklearn_gmm_predict_proba",
    [SklearnModelBlobType, MatrixType(FloatType)], // (model, X)
    MatrixType(FloatType)
);
|
|
688
|
+
/**
 * Platform function that computes the log-likelihood of each sample under a
 * fitted GMM.
 *
 * @param model - Model blob of the fitted GMM
 * @param X - Feature matrix
 * @returns One log-likelihood value per sample
 */
export const sklearn_gmm_score_samples = East.platform(
    "sklearn_gmm_score_samples",
    [SklearnModelBlobType, MatrixType(FloatType)], // (model, X)
    VectorType(FloatType)
);
|
|
696
|
+
/**
 * Platform function that draws random samples from a fitted GMM.
 *
 * @param model - Model blob of the fitted GMM
 * @param n_samples - How many samples to generate
 * @returns Matrix of generated samples (n_samples x n_features)
 */
export const sklearn_gmm_sample = East.platform(
    "sklearn_gmm_sample",
    [SklearnModelBlobType, IntegerType], // (model, n_samples)
    MatrixType(FloatType)
);
|
|
704
|
+
/**
 * Platform function that evaluates the Bayesian Information Criterion (BIC)
 * of a fitted GMM on the given data.
 *
 * A lower BIC indicates a better model, which makes the score useful when
 * choosing n_components.
 *
 * @param model - Model blob of the fitted GMM
 * @param X - Feature matrix
 * @returns The BIC score
 */
export const sklearn_gmm_bic = East.platform(
    "sklearn_gmm_bic",
    [SklearnModelBlobType, MatrixType(FloatType)], // (model, X)
    FloatType
);
|
|
714
|
+
/**
 * Platform function that evaluates the Akaike Information Criterion (AIC)
 * of a fitted GMM on the given data.
 *
 * A lower AIC indicates a better model, which makes the score useful when
 * choosing n_components.
 *
 * @param model - Model blob of the fitted GMM
 * @param X - Feature matrix
 * @returns The AIC score
 */
export const sklearn_gmm_aic = East.platform(
    "sklearn_gmm_aic",
    [SklearnModelBlobType, MatrixType(FloatType)], // (model, X)
    FloatType
);
|
|
724
|
+
/**
 * Platform function that evaluates regression metrics for a single target.
 *
 * @param y_true - True target values (1D vector)
 * @param y_pred - Predicted target values (1D vector)
 * @param metrics - The metrics to compute, as an array
 * @returns Array of metric results, each with a scalar value
 */
export const sklearn_compute_metrics = East.platform(
    "sklearn_compute_metrics",
    [VectorType(FloatType), VectorType(FloatType), ArrayType(RegressionMetricType)], // (y_true, y_pred, metrics)
    MetricsResultType
);
|
|
733
|
+
/**
 * Platform function that evaluates regression metrics for multiple targets.
 *
 * @param Y_true - True target matrix [n_samples, n_targets]
 * @param Y_pred - Predicted target matrix [n_samples, n_targets]
 * @param metrics - The metrics to compute, as an array
 * @param config - Aggregation configuration
 * @returns Array of metric results holding per-target or aggregated values
 */
export const sklearn_compute_metrics_multi = East.platform(
    "sklearn_compute_metrics_multi",
    [
        MatrixType(FloatType), // Y_true
        MatrixType(FloatType), // Y_pred
        ArrayType(RegressionMetricType), // metrics
        MultiMetricsConfigType // config
    ],
    MultiMetricsResultType
);
|
|
743
|
+
/**
 * Platform function that evaluates classification metrics for a single target.
 *
 * @param y_true - True class labels as a 1D integer vector
 * @param y_pred - Predicted class labels as a 1D integer vector
 * @param metrics - The metrics to compute, as an array
 * @param config - Configuration (averaging strategy)
 * @returns Array of metric results, each with a scalar value
 */
export const sklearn_compute_classification_metrics = East.platform(
    "sklearn_compute_classification_metrics",
    [
        VectorType(IntegerType), // y_true
        VectorType(IntegerType), // y_pred
        ArrayType(ClassificationMetricType), // metrics
        ClassificationMetricsConfigType // config
    ],
    ClassificationMetricResultsType
);
|
|
753
|
+
/**
 * Platform function that evaluates classification metrics for multiple targets.
 *
 * Note that both label matrices are declared as float matrices in the
 * signature below, unlike the integer vectors of the single-target variant.
 *
 * @param Y_true - True class labels matrix [n_samples, n_targets]
 * @param Y_pred - Predicted class labels matrix [n_samples, n_targets]
 * @param metrics - The metrics to compute, as an array
 * @param config - Configuration (averaging, aggregation)
 * @returns Array of metric results holding per-target or aggregated values
 */
export const sklearn_compute_classification_metrics_multi = East.platform(
    "sklearn_compute_classification_metrics_multi",
    [
        MatrixType(FloatType), // Y_true
        MatrixType(FloatType), // Y_pred
        ArrayType(ClassificationMetricType), // metrics
        MultiClassificationConfigType // config
    ],
    MultiClassificationMetricResultsType
);
|
|
763
|
+
// ============================================================================
|
|
764
|
+
// Grouped Export
|
|
765
|
+
// ============================================================================
|
|
766
|
+
/**
 * Grouped type definitions used by the sklearn platform functions.
 */
export const SklearnTypes = {
    // --- Class weights, confusion matrix, ROC AUC ---
    /** Mode used when computing class weights */
    ClassWeightModeType,
    /** Result of a confusion matrix computation */
    ConfusionMatrixResultType,
    /** Multi-class strategy for ROC AUC */
    RocAucMultiClassType,
    /** Configuration for ROC AUC */
    RocAucConfigType,

    // --- Splitting and overlap filtering ---
    /** Configuration for splitting */
    SplitConfigType,
    /** Result of a split */
    SplitResultType,
    /** Configuration for overlap filtering */
    OverlapConfigType,
    /** Result of overlap filtering */
    OverlapResultType,

    // --- Models ---
    /** Blob type that carries sklearn models */
    ModelBlobType: SklearnModelBlobType,
    /** Base estimator configuration for RegressorChain */
    RegressorChainBaseConfigType,
    /** Configuration for RegressorChain */
    RegressorChainConfigType,
    /** Covariance type for GMM */
    GMMCovarianceType,
    /** Configuration for GMM */
    GMMConfigType,

    // --- Flexible metrics: regression ---
    /** Variant of regression metrics */
    RegressionMetricType,
    /** Result of a single metric */
    MetricResultType,
    /** Result of multiple metrics */
    MetricsResultType,
    /** How metrics are aggregated */
    MetricAggregationType,
    /** Configuration for multi-target metrics */
    MultiMetricsConfigType,
    /** Result of a multi-target metric */
    MultiMetricResultType,
    /** Result of multi-target metrics */
    MultiMetricsResultType,

    // --- Flexible metrics: classification ---
    /** Weights type for Cohen's Kappa */
    CohenKappaWeightsType,
    /** Variant of classification metrics */
    ClassificationMetricType,
    /** Averaging type for classification */
    ClassificationAverageType,
    /** Configuration for classification metrics */
    ClassificationMetricsConfigType,
    /** Result of a classification metric */
    ClassificationMetricResultType,
    /** Result of classification metrics */
    ClassificationMetricResultsType,
    /** Configuration for multi-target classification */
    MultiClassificationConfigType,
    /** Result of a multi-target classification metric */
    MultiClassificationMetricResultType,
    /** Result of multi-target classification metrics */
    MultiClassificationMetricResultsType,
};
|
|
830
|
+
/**
|
|
831
|
+
* Scikit-learn machine learning utilities.
|
|
832
|
+
*
|
|
833
|
+
* Provides preprocessing, model selection, and metrics for ML workflows.
|
|
834
|
+
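 *
 * NOTE(review): the example below calls `Sklearn.trainTestSplit` with a `test_size` config, while the split member documented in this object is `Sklearn.split` with `split_sizes` - confirm the example matches the current API.
 *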
|
|
835
|
+
* @example
|
|
836
|
+
* ```ts
|
|
837
|
+
* import { East, variant } from "@elaraai/east";
|
|
838
|
+
* import { Sklearn } from "@elaraai/east-py-datascience";
|
|
839
|
+
*
|
|
840
|
+
* const pipeline = East.function([], Sklearn.Types.SplitResultType, $ => {
|
|
841
|
+
* const X = $.let([[1.0, 2.0], [3.0, 4.0], [5.0, 6.0], [7.0, 8.0]]);
|
|
842
|
+
* const y = $.let([1.0, 2.0, 3.0, 4.0]);
|
|
843
|
+
* const config = $.let({
|
|
844
|
+
* test_size: variant('some', 0.25),
|
|
845
|
+
* random_state: variant('some', 42n),
|
|
846
|
+
* shuffle: variant('some', true),
|
|
847
|
+
* });
|
|
848
|
+
* return $.return(Sklearn.trainTestSplit(X, y, config));
|
|
849
|
+
* });
|
|
850
|
+
* ```
|
|
851
|
+
*/
|
|
852
|
+
export const Sklearn = {
|
|
853
|
+
/**
|
|
854
|
+
* Split arrays into N subsets (train/test, train/val/test, etc.).
|
|
855
|
+
*
|
|
856
|
+
* @example
|
|
857
|
+
* ```ts
|
|
858
|
+
* import { East, FloatType, variant } from "@elaraai/east";
|
|
859
|
+
* import { Sklearn, SplitConfigType, MatrixType } from "@elaraai/east-py-datascience";
|
|
860
|
+
*
|
|
861
|
+
* const splitData = East.function(
|
|
862
|
+
* [MatrixType(FloatType), MatrixType(FloatType)],
|
|
863
|
+
* Sklearn.Types.SplitResultType,
|
|
864
|
+
* ($, X, Y) => {
|
|
865
|
+
* const config = $.let({
|
|
866
|
+
* split_sizes: [0.7, 0.15, 0.15],
|
|
867
|
+
* random_state: variant("some", 42n),
|
|
868
|
+
* shuffle: variant("some", true),
|
|
869
|
+
* stratify: variant("none", null),
|
|
870
|
+
* overlap: variant("none", null),
|
|
871
|
+
* multi_overlap: variant("none", null),
|
|
872
|
+
* min_overlap: variant("none", null),
|
|
873
|
+
* }, SplitConfigType);
|
|
874
|
+
* return $.return(Sklearn.split(X, Y, config));
|
|
875
|
+
* }
|
|
876
|
+
* );
|
|
877
|
+
* ```
|
|
878
|
+
*/
|
|
879
|
+
split: sklearn_split,
|
|
880
|
+
/**
|
|
881
|
+
* Filter targets to only contain rows with categorical values seen in reference.
|
|
882
|
+
*
|
|
883
|
+
* @example
|
|
884
|
+
* ```ts
|
|
885
|
+
* import { East, FloatType, IntegerType, ArrayType } from "@elaraai/east";
|
|
886
|
+
* import { Sklearn, OverlapConfigType, MatrixType, VectorType } from "@elaraai/east-py-datascience";
|
|
887
|
+
*
|
|
888
|
+
* const filterOverlap = East.function(
|
|
889
|
+
* [MatrixType(FloatType), ArrayType(MatrixType(FloatType)), ArrayType(MatrixType(FloatType))],
|
|
890
|
+
* Sklearn.Types.OverlapResultType,
|
|
891
|
+
* ($, X_ref, X_targets, Y_targets) => {
|
|
892
|
+
* const config = $.let({
|
|
893
|
+
* cat_indices: new BigInt64Array([0n, 2n]),
|
|
894
|
+
* }, OverlapConfigType);
|
|
895
|
+
* return $.return(Sklearn.overlap(X_ref, X_targets, Y_targets, config));
|
|
896
|
+
* }
|
|
897
|
+
* );
|
|
898
|
+
* ```
|
|
899
|
+
*/
|
|
900
|
+
overlap: sklearn_overlap,
|
|
901
|
+
/**
|
|
902
|
+
* Fit a StandardScaler to training data.
|
|
903
|
+
*
|
|
904
|
+
* Standardizes features by removing the mean and scaling to unit variance.
|
|
905
|
+
*
|
|
906
|
+
* @example
|
|
907
|
+
* ```ts
|
|
908
|
+
* import { East, FloatType } from "@elaraai/east";
|
|
909
|
+
* import { Sklearn, MatrixType } from "@elaraai/east-py-datascience";
|
|
910
|
+
*
|
|
911
|
+
* const fitScaler = East.function(
|
|
912
|
+
* [MatrixType(FloatType)],
|
|
913
|
+
* Sklearn.Types.ModelBlobType,
|
|
914
|
+
* ($, X) => {
|
|
915
|
+
* return $.return(Sklearn.standardScalerFit(X));
|
|
916
|
+
* }
|
|
917
|
+
* );
|
|
918
|
+
* ```
|
|
919
|
+
*/
|
|
920
|
+
standardScalerFit: sklearn_standard_scaler_fit,
|
|
921
|
+
/**
|
|
922
|
+
* Transform data using a fitted StandardScaler.
|
|
923
|
+
*
|
|
924
|
+
* @example
|
|
925
|
+
* ```ts
|
|
926
|
+
* import { East, FloatType } from "@elaraai/east";
|
|
927
|
+
* import { Sklearn, MatrixType } from "@elaraai/east-py-datascience";
|
|
928
|
+
*
|
|
929
|
+
* const transform = East.function(
|
|
930
|
+
* [Sklearn.Types.ModelBlobType, MatrixType(FloatType)],
|
|
931
|
+
* MatrixType(FloatType),
|
|
932
|
+
* ($, model, X) => {
|
|
933
|
+
* return $.return(Sklearn.standardScalerTransform(model, X));
|
|
934
|
+
* }
|
|
935
|
+
* );
|
|
936
|
+
* ```
|
|
937
|
+
*/
|
|
938
|
+
standardScalerTransform: sklearn_standard_scaler_transform,
|
|
939
|
+
/**
|
|
940
|
+
* Fit a MinMaxScaler to training data.
|
|
941
|
+
*
|
|
942
|
+
* Scales features to a given range (default [0, 1]).
|
|
943
|
+
*
|
|
944
|
+
* @example
|
|
945
|
+
* ```ts
|
|
946
|
+
* import { East, FloatType } from "@elaraai/east";
|
|
947
|
+
* import { Sklearn, MatrixType } from "@elaraai/east-py-datascience";
|
|
948
|
+
*
|
|
949
|
+
* const fitScaler = East.function(
|
|
950
|
+
* [MatrixType(FloatType)],
|
|
951
|
+
* Sklearn.Types.ModelBlobType,
|
|
952
|
+
* ($, X) => {
|
|
953
|
+
* return $.return(Sklearn.minMaxScalerFit(X));
|
|
954
|
+
* }
|
|
955
|
+
* );
|
|
956
|
+
* ```
|
|
957
|
+
*/
|
|
958
|
+
minMaxScalerFit: sklearn_min_max_scaler_fit,
|
|
959
|
+
/**
|
|
960
|
+
* Transform data using a fitted MinMaxScaler.
|
|
961
|
+
*
|
|
962
|
+
* @example
|
|
963
|
+
* ```ts
|
|
964
|
+
* import { East, FloatType } from "@elaraai/east";
|
|
965
|
+
* import { Sklearn, MatrixType } from "@elaraai/east-py-datascience";
|
|
966
|
+
*
|
|
967
|
+
* const transform = East.function(
|
|
968
|
+
* [Sklearn.Types.ModelBlobType, MatrixType(FloatType)],
|
|
969
|
+
* MatrixType(FloatType),
|
|
970
|
+
* ($, model, X) => {
|
|
971
|
+
* return $.return(Sklearn.minMaxScalerTransform(model, X));
|
|
972
|
+
* }
|
|
973
|
+
* );
|
|
974
|
+
* ```
|
|
975
|
+
*/
|
|
976
|
+
minMaxScalerTransform: sklearn_min_max_scaler_transform,
|
|
977
|
+
/**
|
|
978
|
+
* Fit a RobustScaler to training data.
|
|
979
|
+
*
|
|
980
|
+
* Scales features using statistics robust to outliers (median and IQR).
|
|
981
|
+
*
|
|
982
|
+
* @example
|
|
983
|
+
* ```ts
|
|
984
|
+
* import { East, FloatType } from "@elaraai/east";
|
|
985
|
+
* import { Sklearn, MatrixType } from "@elaraai/east-py-datascience";
|
|
986
|
+
*
|
|
987
|
+
* const fitScaler = East.function(
|
|
988
|
+
* [MatrixType(FloatType)],
|
|
989
|
+
* Sklearn.Types.ModelBlobType,
|
|
990
|
+
* ($, X) => {
|
|
991
|
+
* return $.return(Sklearn.robustScalerFit(X));
|
|
992
|
+
* }
|
|
993
|
+
* );
|
|
994
|
+
* ```
|
|
995
|
+
*/
|
|
996
|
+
robustScalerFit: sklearn_robust_scaler_fit,
|
|
997
|
+
/**
|
|
998
|
+
* Transform data using a fitted RobustScaler.
|
|
999
|
+
*
|
|
1000
|
+
* @example
|
|
1001
|
+
* ```ts
|
|
1002
|
+
* import { East, FloatType } from "@elaraai/east";
|
|
1003
|
+
* import { Sklearn, MatrixType } from "@elaraai/east-py-datascience";
|
|
1004
|
+
*
|
|
1005
|
+
* const transform = East.function(
|
|
1006
|
+
* [Sklearn.Types.ModelBlobType, MatrixType(FloatType)],
|
|
1007
|
+
* MatrixType(FloatType),
|
|
1008
|
+
* ($, model, X) => {
|
|
1009
|
+
* return $.return(Sklearn.robustScalerTransform(model, X));
|
|
1010
|
+
* }
|
|
1011
|
+
* );
|
|
1012
|
+
* ```
|
|
1013
|
+
*/
|
|
1014
|
+
robustScalerTransform: sklearn_robust_scaler_transform,
|
|
1015
|
+
/**
|
|
1016
|
+
* Fit a LabelEncoder to encode target labels.
|
|
1017
|
+
*
|
|
1018
|
+
* Encodes labels with values between 0 and n_classes-1.
|
|
1019
|
+
*
|
|
1020
|
+
* @example
|
|
1021
|
+
* ```ts
|
|
1022
|
+
* import { East, IntegerType } from "@elaraai/east";
|
|
1023
|
+
* import { Sklearn, VectorType } from "@elaraai/east-py-datascience";
|
|
1024
|
+
*
|
|
1025
|
+
* const fitEncoder = East.function(
|
|
1026
|
+
* [VectorType(IntegerType)],
|
|
1027
|
+
* Sklearn.Types.ModelBlobType,
|
|
1028
|
+
* ($, y) => {
|
|
1029
|
+
* return $.return(Sklearn.labelEncoderFit(y));
|
|
1030
|
+
* }
|
|
1031
|
+
* );
|
|
1032
|
+
* ```
|
|
1033
|
+
*/
|
|
1034
|
+
labelEncoderFit: sklearn_label_encoder_fit,
|
|
1035
|
+
/**
|
|
1036
|
+
* Transform labels using a fitted LabelEncoder.
|
|
1037
|
+
*
|
|
1038
|
+
* @example
|
|
1039
|
+
* ```ts
|
|
1040
|
+
* import { East, IntegerType } from "@elaraai/east";
|
|
1041
|
+
* import { Sklearn, VectorType } from "@elaraai/east-py-datascience";
|
|
1042
|
+
*
|
|
1043
|
+
* const transform = East.function(
|
|
1044
|
+
* [Sklearn.Types.ModelBlobType, VectorType(IntegerType)],
|
|
1045
|
+
* VectorType(IntegerType),
|
|
1046
|
+
* ($, model, y) => {
|
|
1047
|
+
* return $.return(Sklearn.labelEncoderTransform(model, y));
|
|
1048
|
+
* }
|
|
1049
|
+
* );
|
|
1050
|
+
* ```
|
|
1051
|
+
*/
|
|
1052
|
+
labelEncoderTransform: sklearn_label_encoder_transform,
|
|
1053
|
+
/**
|
|
1054
|
+
* Inverse transform encoded labels back to original values.
|
|
1055
|
+
*
|
|
1056
|
+
* @example
|
|
1057
|
+
* ```ts
|
|
1058
|
+
* import { East, IntegerType } from "@elaraai/east";
|
|
1059
|
+
* import { Sklearn, VectorType } from "@elaraai/east-py-datascience";
|
|
1060
|
+
*
|
|
1061
|
+
* const inverse = East.function(
|
|
1062
|
+
* [Sklearn.Types.ModelBlobType, VectorType(IntegerType)],
|
|
1063
|
+
* VectorType(IntegerType),
|
|
1064
|
+
* ($, model, y) => {
|
|
1065
|
+
* return $.return(Sklearn.labelEncoderInverseTransform(model, y));
|
|
1066
|
+
* }
|
|
1067
|
+
* );
|
|
1068
|
+
* ```
|
|
1069
|
+
*/
|
|
1070
|
+
labelEncoderInverseTransform: sklearn_label_encoder_inverse_transform,
|
|
1071
|
+
/**
|
|
1072
|
+
* Fit an OrdinalEncoder to encode categorical features.
|
|
1073
|
+
*
|
|
1074
|
+
* @example
|
|
1075
|
+
* ```ts
|
|
1076
|
+
* import { East, FloatType } from "@elaraai/east";
|
|
1077
|
+
* import { Sklearn, MatrixType } from "@elaraai/east-py-datascience";
|
|
1078
|
+
*
|
|
1079
|
+
* const fitEncoder = East.function(
|
|
1080
|
+
* [MatrixType(FloatType)],
|
|
1081
|
+
* Sklearn.Types.ModelBlobType,
|
|
1082
|
+
* ($, X) => {
|
|
1083
|
+
* return $.return(Sklearn.ordinalEncoderFit(X));
|
|
1084
|
+
* }
|
|
1085
|
+
* );
|
|
1086
|
+
* ```
|
|
1087
|
+
*/
|
|
1088
|
+
ordinalEncoderFit: sklearn_ordinal_encoder_fit,
|
|
1089
|
+
/**
|
|
1090
|
+
* Transform features using a fitted OrdinalEncoder.
|
|
1091
|
+
*
|
|
1092
|
+
* @example
|
|
1093
|
+
* ```ts
|
|
1094
|
+
* import { East, FloatType } from "@elaraai/east";
|
|
1095
|
+
* import { Sklearn, MatrixType } from "@elaraai/east-py-datascience";
|
|
1096
|
+
*
|
|
1097
|
+
* const transform = East.function(
|
|
1098
|
+
* [Sklearn.Types.ModelBlobType, MatrixType(FloatType)],
|
|
1099
|
+
* MatrixType(FloatType),
|
|
1100
|
+
* ($, model, X) => {
|
|
1101
|
+
* return $.return(Sklearn.ordinalEncoderTransform(model, X));
|
|
1102
|
+
* }
|
|
1103
|
+
* );
|
|
1104
|
+
* ```
|
|
1105
|
+
*/
|
|
1106
|
+
ordinalEncoderTransform: sklearn_ordinal_encoder_transform,
|
|
1107
|
+
/**
|
|
1108
|
+
* Compute regression metrics for single-target predictions.
|
|
1109
|
+
*
|
|
1110
|
+
* @example
|
|
1111
|
+
* ```ts
|
|
1112
|
+
* import { East, FloatType, ArrayType, variant } from "@elaraai/east";
|
|
1113
|
+
* import { Sklearn, VectorType, RegressionMetricType } from "@elaraai/east-py-datascience";
|
|
1114
|
+
*
|
|
1115
|
+
* const evaluate = East.function(
|
|
1116
|
+
* [VectorType(FloatType), VectorType(FloatType)],
|
|
1117
|
+
* Sklearn.Types.MetricsResultType,
|
|
1118
|
+
* ($, y_true, y_pred) => {
|
|
1119
|
+
* const metrics = $.let([
|
|
1120
|
+
* variant("mse", null),
|
|
1121
|
+
* variant("mae", null),
|
|
1122
|
+
* variant("r2", null),
|
|
1123
|
+
* ], ArrayType(RegressionMetricType));
|
|
1124
|
+
* return $.return(Sklearn.computeMetrics(y_true, y_pred, metrics));
|
|
1125
|
+
* }
|
|
1126
|
+
* );
|
|
1127
|
+
* ```
|
|
1128
|
+
*/
|
|
1129
|
+
computeMetrics: sklearn_compute_metrics,
|
|
1130
|
+
/**
|
|
1131
|
+
* Compute regression metrics for multi-target predictions.
|
|
1132
|
+
*
|
|
1133
|
+
* @example
|
|
1134
|
+
* ```ts
|
|
1135
|
+
* import { East, FloatType, ArrayType, variant } from "@elaraai/east";
|
|
1136
|
+
* import { Sklearn, MatrixType, RegressionMetricType, MultiMetricsConfigType } from "@elaraai/east-py-datascience";
|
|
1137
|
+
*
|
|
1138
|
+
* const evaluate = East.function(
|
|
1139
|
+
* [MatrixType(FloatType), MatrixType(FloatType)],
|
|
1140
|
+
* Sklearn.Types.MultiMetricsResultType,
|
|
1141
|
+
* ($, Y_true, Y_pred) => {
|
|
1142
|
+
* const metrics = $.let([
|
|
1143
|
+
* variant("mse", null),
|
|
1144
|
+
* variant("r2", null),
|
|
1145
|
+
* ], ArrayType(RegressionMetricType));
|
|
1146
|
+
* const config = $.let({
|
|
1147
|
+
* aggregation: variant("some", variant("per_target", null)),
|
|
1148
|
+
* }, MultiMetricsConfigType);
|
|
1149
|
+
* return $.return(Sklearn.computeMetricsMulti(Y_true, Y_pred, metrics, config));
|
|
1150
|
+
* }
|
|
1151
|
+
* );
|
|
1152
|
+
* ```
|
|
1153
|
+
*/
|
|
1154
|
+
computeMetricsMulti: sklearn_compute_metrics_multi,
|
|
1155
|
+
/**
|
|
1156
|
+
* Compute classification metrics for single-target predictions.
|
|
1157
|
+
*
|
|
1158
|
+
* @example
|
|
1159
|
+
* ```ts
|
|
1160
|
+
* import { East, IntegerType, ArrayType, variant } from "@elaraai/east";
|
|
1161
|
+
* import { Sklearn, VectorType, ClassificationMetricType, ClassificationMetricsConfigType } from "@elaraai/east-py-datascience";
|
|
1162
|
+
*
|
|
1163
|
+
* const evaluate = East.function(
|
|
1164
|
+
* [VectorType(IntegerType), VectorType(IntegerType)],
|
|
1165
|
+
* Sklearn.Types.ClassificationMetricResultsType,
|
|
1166
|
+
* ($, y_true, y_pred) => {
|
|
1167
|
+
* const metrics = $.let([
|
|
1168
|
+
* variant("accuracy", null),
|
|
1169
|
+
* variant("f1", null),
|
|
1170
|
+
* variant("precision", null),
|
|
1171
|
+
* ], ArrayType(ClassificationMetricType));
|
|
1172
|
+
* const config = $.let({
|
|
1173
|
+
* average: variant("some", variant("macro", null)),
|
|
1174
|
+
* }, ClassificationMetricsConfigType);
|
|
1175
|
+
* return $.return(Sklearn.computeClassificationMetrics(y_true, y_pred, metrics, config));
|
|
1176
|
+
* }
|
|
1177
|
+
* );
|
|
1178
|
+
* ```
|
|
1179
|
+
*/
|
|
1180
|
+
computeClassificationMetrics: sklearn_compute_classification_metrics,
|
|
1181
|
+
/**
|
|
1182
|
+
* Compute classification metrics for multi-target predictions.
|
|
1183
|
+
*
|
|
1184
|
+
* @example
|
|
1185
|
+
* ```ts
|
|
1186
|
+
* import { East, FloatType, ArrayType, variant } from "@elaraai/east";
|
|
1187
|
+
* import { Sklearn, MatrixType, ClassificationMetricType, MultiClassificationConfigType } from "@elaraai/east-py-datascience";
|
|
1188
|
+
*
|
|
1189
|
+
* const evaluate = East.function(
|
|
1190
|
+
* [MatrixType(FloatType), MatrixType(FloatType)],
|
|
1191
|
+
* Sklearn.Types.MultiClassificationMetricResultsType,
|
|
1192
|
+
* ($, Y_true, Y_pred) => {
|
|
1193
|
+
* const metrics = $.let([
|
|
1194
|
+
* variant("accuracy", null),
|
|
1195
|
+
* variant("f1", null),
|
|
1196
|
+
* ], ArrayType(ClassificationMetricType));
|
|
1197
|
+
* const config = $.let({
|
|
1198
|
+
* average: variant("some", variant("macro", null)),
|
|
1199
|
+
* aggregation: variant("some", variant("per_target", null)),
|
|
1200
|
+
* }, MultiClassificationConfigType);
|
|
1201
|
+
* return $.return(Sklearn.computeClassificationMetricsMulti(Y_true, Y_pred, metrics, config));
|
|
1202
|
+
* }
|
|
1203
|
+
* );
|
|
1204
|
+
* ```
|
|
1205
|
+
*/
|
|
1206
|
+
computeClassificationMetricsMulti: sklearn_compute_classification_metrics_multi,
|
|
1207
|
+
/**
|
|
1208
|
+
* Compute class weights for balanced training.
|
|
1209
|
+
*
|
|
1210
|
+
* Calculates weights inversely proportional to class frequencies.
|
|
1211
|
+
*
|
|
1212
|
+
* @example
|
|
1213
|
+
* ```ts
|
|
1214
|
+
* import { East, FloatType, IntegerType, variant } from "@elaraai/east";
|
|
1215
|
+
* import { Sklearn, VectorType, ClassWeightModeType } from "@elaraai/east-py-datascience";
|
|
1216
|
+
*
|
|
1217
|
+
* const getWeights = East.function(
|
|
1218
|
+
* [VectorType(IntegerType)],
|
|
1219
|
+
* VectorType(FloatType),
|
|
1220
|
+
* ($, y) => {
|
|
1221
|
+
* const mode = $.let(variant("balanced", null), ClassWeightModeType);
|
|
1222
|
+
* return $.return(Sklearn.computeClassWeight(mode, y));
|
|
1223
|
+
* }
|
|
1224
|
+
* );
|
|
1225
|
+
* ```
|
|
1226
|
+
*/
|
|
1227
|
+
computeClassWeight: sklearn_compute_class_weight,
|
|
1228
|
+
/**
|
|
1229
|
+
* Compute confusion matrix for classification results.
|
|
1230
|
+
*
|
|
1231
|
+
* @example
|
|
1232
|
+
* ```ts
|
|
1233
|
+
* import { East, IntegerType } from "@elaraai/east";
|
|
1234
|
+
* import { Sklearn, VectorType } from "@elaraai/east-py-datascience";
|
|
1235
|
+
*
|
|
1236
|
+
* const getMatrix = East.function(
|
|
1237
|
+
* [VectorType(IntegerType), VectorType(IntegerType)],
|
|
1238
|
+
* Sklearn.Types.ConfusionMatrixResultType,
|
|
1239
|
+
* ($, y_true, y_pred) => {
|
|
1240
|
+
* return $.return(Sklearn.confusionMatrix(y_true, y_pred));
|
|
1241
|
+
* }
|
|
1242
|
+
* );
|
|
1243
|
+
* ```
|
|
1244
|
+
*/
|
|
1245
|
+
confusionMatrix: sklearn_confusion_matrix,
|
|
1246
|
+
/**
|
|
1247
|
+
* Compute the ROC AUC score (area under the receiver operating characteristic curve) from true labels and predicted class probabilities.
|
|
1248
|
+
*
|
|
1249
|
+
* @example
|
|
1250
|
+
* ```ts
|
|
1251
|
+
* import { East, FloatType, IntegerType, variant } from "@elaraai/east";
|
|
1252
|
+
* import { Sklearn, VectorType, MatrixType, RocAucConfigType } from "@elaraai/east-py-datascience";
|
|
1253
|
+
*
|
|
1254
|
+
* const getAuc = East.function(
|
|
1255
|
+
* [VectorType(IntegerType), MatrixType(FloatType)],
|
|
1256
|
+
* FloatType,
|
|
1257
|
+
* ($, y_true, y_proba) => {
|
|
1258
|
+
* const config = $.let({
|
|
1259
|
+
* multi_class: variant("some", variant("ovr", null)),
|
|
1260
|
+
* average: variant("some", variant("macro", null)),
|
|
1261
|
+
* }, RocAucConfigType);
|
|
1262
|
+
* return $.return(Sklearn.rocAucScore(y_true, y_proba, config));
|
|
1263
|
+
* }
|
|
1264
|
+
* );
|
|
1265
|
+
* ```
|
|
1266
|
+
*/
|
|
1267
|
+
rocAucScore: sklearn_roc_auc_score,
|
|
1268
|
+
/**
|
|
1269
|
+
* Compute the log loss (cross-entropy loss) from true labels and predicted class probabilities.
|
|
1270
|
+
*
|
|
1271
|
+
* @example
|
|
1272
|
+
* ```ts
|
|
1273
|
+
* import { East, FloatType, IntegerType } from "@elaraai/east";
|
|
1274
|
+
* import { Sklearn, VectorType, MatrixType } from "@elaraai/east-py-datascience";
|
|
1275
|
+
*
|
|
1276
|
+
* const getLoss = East.function(
|
|
1277
|
+
* [VectorType(IntegerType), MatrixType(FloatType)],
|
|
1278
|
+
* FloatType,
|
|
1279
|
+
* ($, y_true, y_proba) => {
|
|
1280
|
+
* return $.return(Sklearn.logLoss(y_true, y_proba));
|
|
1281
|
+
* }
|
|
1282
|
+
* );
|
|
1283
|
+
* ```
|
|
1284
|
+
*/
|
|
1285
|
+
logLoss: sklearn_log_loss,
|
|
1286
|
+
/**
|
|
1287
|
+
* Train a RegressorChain for multi-target regression.
|
|
1288
|
+
*
|
|
1289
|
+
* Each model in the chain uses the targets that come before it in the chain order as additional features.
|
|
1290
|
+
*
|
|
1291
|
+
* @example
|
|
1292
|
+
* ```ts
|
|
1293
|
+
* import { East, FloatType, variant } from "@elaraai/east";
|
|
1294
|
+
* import { Sklearn, MatrixType, RegressorChainConfigType } from "@elaraai/east-py-datascience";
|
|
1295
|
+
*
|
|
1296
|
+
* const trainChain = East.function(
|
|
1297
|
+
* [MatrixType(FloatType), MatrixType(FloatType)],
|
|
1298
|
+
* Sklearn.Types.ModelBlobType,
|
|
1299
|
+
* ($, X, Y) => {
|
|
1300
|
+
* const config = $.let({
|
|
1301
|
+
* base_estimator: variant("xgboost", {
|
|
1302
|
+
* n_estimators: variant("some", 100n),
|
|
1303
|
+
* max_depth: variant("some", 3n),
|
|
1304
|
+
* learning_rate: variant("some", 0.1),
|
|
1305
|
+
* min_child_weight: variant("none", null),
|
|
1306
|
+
* subsample: variant("none", null),
|
|
1307
|
+
* colsample_bytree: variant("none", null),
|
|
1308
|
+
* reg_alpha: variant("none", null),
|
|
1309
|
+
* reg_lambda: variant("none", null),
|
|
1310
|
+
* gamma: variant("none", null),
|
|
1311
|
+
* random_state: variant("some", 42n),
|
|
1312
|
+
* n_jobs: variant("none", null),
|
|
1313
|
+
* sample_weight: variant("none", null),
|
|
1314
|
+
* categorical_features: variant("none", null),
|
|
1315
|
+
* categorical_n: variant("none", null),
|
|
1316
|
+
* max_cat_to_onehot: variant("none", null),
|
|
1317
|
+
* max_cat_threshold: variant("none", null),
|
|
1318
|
+
* }),
|
|
1319
|
+
* order: variant("none", null),
|
|
1320
|
+
* random_state: variant("some", 42n),
|
|
1321
|
+
* }, RegressorChainConfigType);
|
|
1322
|
+
* return $.return(Sklearn.regressorChainTrain(X, Y, config));
|
|
1323
|
+
* }
|
|
1324
|
+
* );
|
|
1325
|
+
* ```
|
|
1326
|
+
*/
|
|
1327
|
+
regressorChainTrain: sklearn_regressor_chain_train,
|
|
1328
|
+
/**
|
|
1329
|
+
* Predict using a fitted RegressorChain.
|
|
1330
|
+
*
|
|
1331
|
+
* @example
|
|
1332
|
+
* ```ts
|
|
1333
|
+
* import { East, FloatType } from "@elaraai/east";
|
|
1334
|
+
* import { Sklearn, MatrixType } from "@elaraai/east-py-datascience";
|
|
1335
|
+
*
|
|
1336
|
+
* const predict = East.function(
|
|
1337
|
+
* [Sklearn.Types.ModelBlobType, MatrixType(FloatType)],
|
|
1338
|
+
* MatrixType(FloatType),
|
|
1339
|
+
* ($, model, X) => {
|
|
1340
|
+
* return $.return(Sklearn.regressorChainPredict(model, X));
|
|
1341
|
+
* }
|
|
1342
|
+
* );
|
|
1343
|
+
* ```
|
|
1344
|
+
*/
|
|
1345
|
+
regressorChainPredict: sklearn_regressor_chain_predict,
|
|
1346
|
+
/**
|
|
1347
|
+
* Fit a Gaussian Mixture Model to data.
|
|
1348
|
+
*
|
|
1349
|
+
* @example
|
|
1350
|
+
* ```ts
|
|
1351
|
+
* import { East, FloatType, variant } from "@elaraai/east";
|
|
1352
|
+
* import { Sklearn, MatrixType, GMMConfigType } from "@elaraai/east-py-datascience";
|
|
1353
|
+
*
|
|
1354
|
+
* const fitGmm = East.function(
|
|
1355
|
+
* [MatrixType(FloatType)],
|
|
1356
|
+
* Sklearn.Types.ModelBlobType,
|
|
1357
|
+
* ($, X) => {
|
|
1358
|
+
* const config = $.let({
|
|
1359
|
+
* n_components: variant("some", 3n),
|
|
1360
|
+
* covariance_type: variant("some", variant("full", null)),
|
|
1361
|
+
* max_iter: variant("none", null),
|
|
1362
|
+
* n_init: variant("none", null),
|
|
1363
|
+
* tol: variant("none", null),
|
|
1364
|
+
* reg_covar: variant("none", null),
|
|
1365
|
+
* random_state: variant("some", 42n),
|
|
1366
|
+
* }, GMMConfigType);
|
|
1367
|
+
* return $.return(Sklearn.gmmFit(X, config));
|
|
1368
|
+
* }
|
|
1369
|
+
* );
|
|
1370
|
+
* ```
|
|
1371
|
+
*/
|
|
1372
|
+
gmmFit: sklearn_gmm_fit,
|
|
1373
|
+
/**
|
|
1374
|
+
* Predict cluster labels for data using a fitted GMM.
|
|
1375
|
+
*
|
|
1376
|
+
* @example
|
|
1377
|
+
* ```ts
|
|
1378
|
+
* import { East, FloatType, IntegerType } from "@elaraai/east";
|
|
1379
|
+
* import { Sklearn, MatrixType, VectorType } from "@elaraai/east-py-datascience";
|
|
1380
|
+
*
|
|
1381
|
+
* const predict = East.function(
|
|
1382
|
+
* [Sklearn.Types.ModelBlobType, MatrixType(FloatType)],
|
|
1383
|
+
* VectorType(IntegerType),
|
|
1384
|
+
* ($, model, X) => {
|
|
1385
|
+
* return $.return(Sklearn.gmmPredict(model, X));
|
|
1386
|
+
* }
|
|
1387
|
+
* );
|
|
1388
|
+
* ```
|
|
1389
|
+
*/
|
|
1390
|
+
gmmPredict: sklearn_gmm_predict,
|
|
1391
|
+
/**
|
|
1392
|
+
* Predict posterior probabilities for each GMM component.
|
|
1393
|
+
*
|
|
1394
|
+
* @example
|
|
1395
|
+
* ```ts
|
|
1396
|
+
* import { East, FloatType } from "@elaraai/east";
|
|
1397
|
+
* import { Sklearn, MatrixType } from "@elaraai/east-py-datascience";
|
|
1398
|
+
*
|
|
1399
|
+
* const predictProba = East.function(
|
|
1400
|
+
* [Sklearn.Types.ModelBlobType, MatrixType(FloatType)],
|
|
1401
|
+
* MatrixType(FloatType),
|
|
1402
|
+
* ($, model, X) => {
|
|
1403
|
+
* return $.return(Sklearn.gmmPredictProba(model, X));
|
|
1404
|
+
* }
|
|
1405
|
+
* );
|
|
1406
|
+
* ```
|
|
1407
|
+
*/
|
|
1408
|
+
gmmPredictProba: sklearn_gmm_predict_proba,
|
|
1409
|
+
/**
|
|
1410
|
+
* Compute per-sample log-likelihood under the fitted GMM.
|
|
1411
|
+
*
|
|
1412
|
+
* @example
|
|
1413
|
+
* ```ts
|
|
1414
|
+
* import { East, FloatType } from "@elaraai/east";
|
|
1415
|
+
* import { Sklearn, MatrixType, VectorType } from "@elaraai/east-py-datascience";
|
|
1416
|
+
*
|
|
1417
|
+
* const score = East.function(
|
|
1418
|
+
* [Sklearn.Types.ModelBlobType, MatrixType(FloatType)],
|
|
1419
|
+
* VectorType(FloatType),
|
|
1420
|
+
* ($, model, X) => {
|
|
1421
|
+
* return $.return(Sklearn.gmmScoreSamples(model, X));
|
|
1422
|
+
* }
|
|
1423
|
+
* );
|
|
1424
|
+
* ```
|
|
1425
|
+
*/
|
|
1426
|
+
gmmScoreSamples: sklearn_gmm_score_samples,
|
|
1427
|
+
/**
|
|
1428
|
+
* Generate random samples from the fitted GMM.
|
|
1429
|
+
*
|
|
1430
|
+
* @example
|
|
1431
|
+
* ```ts
|
|
1432
|
+
* import { East, FloatType, IntegerType } from "@elaraai/east";
|
|
1433
|
+
* import { Sklearn, MatrixType } from "@elaraai/east-py-datascience";
|
|
1434
|
+
*
|
|
1435
|
+
* const generate = East.function(
|
|
1436
|
+
* [Sklearn.Types.ModelBlobType, IntegerType],
|
|
1437
|
+
* MatrixType(FloatType),
|
|
1438
|
+
* ($, model, n) => {
|
|
1439
|
+
* return $.return(Sklearn.gmmSample(model, n));
|
|
1440
|
+
* }
|
|
1441
|
+
* );
|
|
1442
|
+
* ```
|
|
1443
|
+
*/
|
|
1444
|
+
gmmSample: sklearn_gmm_sample,
|
|
1445
|
+
/**
|
|
1446
|
+
* Compute the Bayesian Information Criterion (BIC) of a fitted GMM on the given data.
|
|
1447
|
+
*
|
|
1448
|
+
* Lower BIC indicates a better model. Useful for selecting n_components.
|
|
1449
|
+
*
|
|
1450
|
+
* @example
|
|
1451
|
+
* ```ts
|
|
1452
|
+
* import { East, FloatType } from "@elaraai/east";
|
|
1453
|
+
* import { Sklearn, MatrixType } from "@elaraai/east-py-datascience";
|
|
1454
|
+
*
|
|
1455
|
+
* const getBic = East.function(
|
|
1456
|
+
* [Sklearn.Types.ModelBlobType, MatrixType(FloatType)],
|
|
1457
|
+
* FloatType,
|
|
1458
|
+
* ($, model, X) => {
|
|
1459
|
+
* return $.return(Sklearn.gmmBic(model, X));
|
|
1460
|
+
* }
|
|
1461
|
+
* );
|
|
1462
|
+
* ```
|
|
1463
|
+
*/
|
|
1464
|
+
gmmBic: sklearn_gmm_bic,
|
|
1465
|
+
/**
|
|
1466
|
+
* Compute the Akaike Information Criterion (AIC) of a fitted GMM on the given data.
|
|
1467
|
+
*
|
|
1468
|
+
* Lower AIC indicates a better model. Useful for selecting n_components.
|
|
1469
|
+
*
|
|
1470
|
+
* @example
|
|
1471
|
+
* ```ts
|
|
1472
|
+
* import { East, FloatType } from "@elaraai/east";
|
|
1473
|
+
* import { Sklearn, MatrixType } from "@elaraai/east-py-datascience";
|
|
1474
|
+
*
|
|
1475
|
+
* const getAic = East.function(
|
|
1476
|
+
* [Sklearn.Types.ModelBlobType, MatrixType(FloatType)],
|
|
1477
|
+
* FloatType,
|
|
1478
|
+
* ($, model, X) => {
|
|
1479
|
+
* return $.return(Sklearn.gmmAic(model, X));
|
|
1480
|
+
* }
|
|
1481
|
+
* );
|
|
1482
|
+
* ```
|
|
1483
|
+
*/
|
|
1484
|
+
gmmAic: sklearn_gmm_aic,
|
|
1485
|
+
/**
|
|
1486
|
+
* Compute the silhouette score for clustering quality evaluation.
|
|
1487
|
+
*
|
|
1488
|
+
* Values range from -1 to 1: higher means better-separated clusters.
|
|
1489
|
+
*
|
|
1490
|
+
* @example
|
|
1491
|
+
* ```ts
|
|
1492
|
+
* import { East, FloatType, IntegerType } from "@elaraai/east";
|
|
1493
|
+
* import { Sklearn, MatrixType, VectorType } from "@elaraai/east-py-datascience";
|
|
1494
|
+
*
|
|
1495
|
+
* const score = East.function(
|
|
1496
|
+
* [MatrixType(FloatType), VectorType(IntegerType)],
|
|
1497
|
+
* FloatType,
|
|
1498
|
+
* ($, X, labels) => {
|
|
1499
|
+
* return $.return(Sklearn.silhouetteScore(X, labels));
|
|
1500
|
+
* }
|
|
1501
|
+
* );
|
|
1502
|
+
* ```
|
|
1503
|
+
*/
|
|
1504
|
+
silhouetteScore: sklearn_silhouette_score,
|
|
1505
|
+
/** Type definitions */
|
|
1506
|
+
Types: SklearnTypes,
|
|
1507
|
+
};
|
|
1508
|
+
//# sourceMappingURL=sklearn.js.map
|