@elaraai/east-py-datascience 0.0.2-beta.8 → 0.0.2-beta.80
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- package/README.md +58 -1
- package/dist/src/alns/alns.d.ts +528 -0
- package/dist/src/alns/alns.d.ts.map +1 -0
- package/dist/src/alns/alns.js +238 -0
- package/dist/src/alns/alns.js.map +1 -0
- package/dist/src/google_or/google_or.d.ts +2422 -0
- package/dist/src/google_or/google_or.d.ts.map +1 -0
- package/dist/src/google_or/google_or.js +542 -0
- package/dist/src/google_or/google_or.js.map +1 -0
- package/dist/{gp → src/gp}/gp.d.ts +185 -136
- package/dist/src/gp/gp.d.ts.map +1 -0
- package/dist/{gp → src/gp}/gp.js +64 -12
- package/dist/src/gp/gp.js.map +1 -0
- package/dist/src/index.d.ts +34 -0
- package/dist/src/index.d.ts.map +1 -0
- package/dist/src/index.js +57 -0
- package/dist/src/index.js.map +1 -0
- package/dist/src/lightgbm/lightgbm.d.ts +575 -0
- package/dist/src/lightgbm/lightgbm.d.ts.map +1 -0
- package/dist/{lightgbm → src/lightgbm}/lightgbm.js +104 -18
- package/dist/src/lightgbm/lightgbm.js.map +1 -0
- package/dist/src/lightning/lightning.d.ts +1594 -0
- package/dist/src/lightning/lightning.d.ts.map +1 -0
- package/dist/src/lightning/lightning.js +468 -0
- package/dist/src/lightning/lightning.js.map +1 -0
- package/dist/{mads → src/mads}/mads.d.ts +109 -112
- package/dist/src/mads/mads.d.ts.map +1 -0
- package/dist/{mads → src/mads}/mads.js +6 -8
- package/dist/src/mads/mads.js.map +1 -0
- package/dist/src/mapie/mapie.d.ts +3680 -0
- package/dist/src/mapie/mapie.d.ts.map +1 -0
- package/dist/src/mapie/mapie.js +616 -0
- package/dist/src/mapie/mapie.js.map +1 -0
- package/dist/{ngboost → src/ngboost}/ngboost.d.ts +192 -142
- package/dist/src/ngboost/ngboost.d.ts.map +1 -0
- package/dist/{ngboost → src/ngboost}/ngboost.js +67 -14
- package/dist/src/ngboost/ngboost.js.map +1 -0
- package/dist/src/optimization/optimization.d.ts +420 -0
- package/dist/src/optimization/optimization.d.ts.map +1 -0
- package/dist/src/optimization/optimization.js +257 -0
- package/dist/src/optimization/optimization.js.map +1 -0
- package/dist/{optuna → src/optuna}/optuna.d.ts +374 -314
- package/dist/src/optuna/optuna.d.ts.map +1 -0
- package/dist/{optuna → src/optuna}/optuna.js +2 -0
- package/dist/src/optuna/optuna.js.map +1 -0
- package/dist/src/pymc/pymc.d.ts +2932 -0
- package/dist/src/pymc/pymc.d.ts.map +1 -0
- package/dist/src/pymc/pymc.js +688 -0
- package/dist/src/pymc/pymc.js.map +1 -0
- package/dist/src/scipy/scipy.d.ts +2205 -0
- package/dist/src/scipy/scipy.d.ts.map +1 -0
- package/dist/src/scipy/scipy.js +884 -0
- package/dist/src/scipy/scipy.js.map +1 -0
- package/dist/src/shap/shap.d.ts +2988 -0
- package/dist/src/shap/shap.d.ts.map +1 -0
- package/dist/src/shap/shap.js +500 -0
- package/dist/src/shap/shap.js.map +1 -0
- package/dist/{simanneal → src/simanneal}/simanneal.d.ts +257 -160
- package/dist/src/simanneal/simanneal.d.ts.map +1 -0
- package/dist/{simanneal → src/simanneal}/simanneal.js +105 -8
- package/dist/src/simanneal/simanneal.js.map +1 -0
- package/dist/src/simulation/simulation.d.ts +431 -0
- package/dist/src/simulation/simulation.d.ts.map +1 -0
- package/dist/src/simulation/simulation.js +306 -0
- package/dist/src/simulation/simulation.js.map +1 -0
- package/dist/src/sklearn/sklearn.d.ts +6362 -0
- package/dist/src/sklearn/sklearn.d.ts.map +1 -0
- package/dist/src/sklearn/sklearn.js +1508 -0
- package/dist/src/sklearn/sklearn.js.map +1 -0
- package/dist/src/torch/torch.d.ts +1205 -0
- package/dist/src/torch/torch.d.ts.map +1 -0
- package/dist/{torch → src/torch}/torch.js +109 -18
- package/dist/src/torch/torch.js.map +1 -0
- package/dist/src/types.d.ts +43 -0
- package/dist/src/types.d.ts.map +1 -0
- package/dist/src/types.js +44 -0
- package/dist/src/types.js.map +1 -0
- package/dist/src/xgboost/xgboost.d.ts +1424 -0
- package/dist/src/xgboost/xgboost.d.ts.map +1 -0
- package/dist/src/xgboost/xgboost.js +432 -0
- package/dist/src/xgboost/xgboost.js.map +1 -0
- package/package.json +12 -12
- package/dist/gp/gp.d.ts.map +0 -1
- package/dist/gp/gp.js.map +0 -1
- package/dist/index.d.ts +0 -27
- package/dist/index.d.ts.map +0 -1
- package/dist/index.js +0 -41
- package/dist/index.js.map +0 -1
- package/dist/lightgbm/lightgbm.d.ts +0 -494
- package/dist/lightgbm/lightgbm.d.ts.map +0 -1
- package/dist/lightgbm/lightgbm.js.map +0 -1
- package/dist/mads/mads.d.ts.map +0 -1
- package/dist/mads/mads.js.map +0 -1
- package/dist/ngboost/ngboost.d.ts.map +0 -1
- package/dist/ngboost/ngboost.js.map +0 -1
- package/dist/optuna/optuna.d.ts.map +0 -1
- package/dist/optuna/optuna.js.map +0 -1
- package/dist/scipy/scipy.d.ts +0 -1260
- package/dist/scipy/scipy.d.ts.map +0 -1
- package/dist/scipy/scipy.js +0 -413
- package/dist/scipy/scipy.js.map +0 -1
- package/dist/shap/shap.d.ts +0 -657
- package/dist/shap/shap.d.ts.map +0 -1
- package/dist/shap/shap.js +0 -241
- package/dist/shap/shap.js.map +0 -1
- package/dist/simanneal/simanneal.d.ts.map +0 -1
- package/dist/simanneal/simanneal.js.map +0 -1
- package/dist/sklearn/sklearn.d.ts +0 -2691
- package/dist/sklearn/sklearn.d.ts.map +0 -1
- package/dist/sklearn/sklearn.js +0 -524
- package/dist/sklearn/sklearn.js.map +0 -1
- package/dist/torch/torch.d.ts +0 -1081
- package/dist/torch/torch.d.ts.map +0 -1
- package/dist/torch/torch.js.map +0 -1
- package/dist/tsconfig.tsbuildinfo +0 -1
- package/dist/types.d.ts +0 -80
- package/dist/types.d.ts.map +0 -1
- package/dist/types.js +0 -81
- package/dist/types.js.map +0 -1
- package/dist/xgboost/xgboost.d.ts +0 -504
- package/dist/xgboost/xgboost.d.ts.map +0 -1
- package/dist/xgboost/xgboost.js +0 -177
- package/dist/xgboost/xgboost.js.map +0 -1
|
@@ -0,0 +1,1205 @@
|
|
|
1
|
+
/**
|
|
2
|
+
* Copyright (c) 2025 Elara AI Pty Ltd
|
|
3
|
+
* Dual-licensed under AGPL-3.0 and commercial license. See LICENSE for details.
|
|
4
|
+
*/
|
|
5
|
+
/**
|
|
6
|
+
* PyTorch platform functions for East.
|
|
7
|
+
*
|
|
8
|
+
* Provides neural network models using PyTorch.
|
|
9
|
+
* Uses cloudpickle for model serialization.
|
|
10
|
+
*
|
|
11
|
+
* @packageDocumentation
|
|
12
|
+
*/
|
|
13
|
+
import { StructType, VariantType, OptionType, IntegerType, FloatType, BlobType, ArrayType, NullType } from "@elaraai/east";
|
|
14
|
+
import { VectorType, MatrixType } from "../types.js";
|
|
15
|
+
export { VectorType, MatrixType } from "../types.js";
|
|
16
|
+
/**
 * Activation applied after each hidden layer of an MLP.
 *
 * Every case is payload-free (`NullType`); pick one by its tag.
 */
export declare const TorchActivationType: VariantType<{
    /** ReLU: max(0, x) */
    readonly relu: NullType;
    /** tanh (hyperbolic tangent) */
    readonly tanh: NullType;
    /** Logistic sigmoid */
    readonly sigmoid: NullType;
    /** Leaky ReLU (small non-zero slope for negative inputs) */
    readonly leaky_relu: NullType;
}>;
|
|
29
|
+
/**
 * Objective minimised during MLP training.
 *
 * Several cases only make sense with a particular `output_activation`;
 * see the note on each case.
 */
export declare const TorchLossType: VariantType<{
    /** Mean squared error - regression */
    readonly mse: NullType;
    /** Mean absolute error - regression */
    readonly mae: NullType;
    /** Cross entropy - multi-class classification with integer class targets */
    readonly cross_entropy: NullType;
    /** Kullback-Leibler divergence - distribution matching; pair with a softmax output */
    readonly kl_div: NullType;
    /** Binary cross entropy - multi-label binary targets; requires a sigmoid output */
    readonly bce: NullType;
    /** Binary cross entropy on logits - sigmoid is applied internally (more numerically stable); do NOT combine with a sigmoid output_activation */
    readonly bce_with_logits: NullType;
}>;
|
|
46
|
+
/**
 * Gradient-based optimizer used to fit the MLP weights.
 */
export declare const TorchOptimizerType: VariantType<{
    /** Adam */
    readonly adam: NullType;
    /** Plain stochastic gradient descent */
    readonly sgd: NullType;
    /** AdamW (Adam with decoupled weight decay) */
    readonly adamw: NullType;
    /** RMSprop */
    readonly rmsprop: NullType;
}>;
|
|
59
|
+
/**
 * Activation applied to the final (output) layer only.
 *
 * Hidden layers are controlled separately by the hidden-layer activation.
 */
export declare const TorchOutputActivationType: VariantType<{
    /** Identity / linear output (the default) */
    readonly none: NullType;
    /** Softmax - outputs form a probability distribution summing to 1 */
    readonly softmax: NullType;
    /** Sigmoid - each output squashed independently into [0, 1] */
    readonly sigmoid: NullType;
}>;
|
|
71
|
+
/**
 * Configuration for MLP architecture.
 *
 * All fields must be present; `OptionType` fields accept `none` to use the
 * documented default.
 */
export declare const TorchMLPConfigType: StructType<{
    /** Hidden layer sizes, e.g., [64, 32] */
    readonly hidden_layers: ArrayType<IntegerType>;
    /** Activation function for hidden layers (default relu) */
    readonly activation: OptionType<VariantType<{
        /** Rectified Linear Unit */
        readonly relu: NullType;
        /** Hyperbolic tangent */
        readonly tanh: NullType;
        /** Sigmoid function */
        readonly sigmoid: NullType;
        /** Leaky ReLU */
        readonly leaky_relu: NullType;
    }>>;
    /** Output activation function applied to the final layer (default none/linear) */
    readonly output_activation: OptionType<VariantType<{
        /** No activation (linear output) - default */
        readonly none: NullType;
        /** Softmax (outputs sum to 1, for probability distributions) */
        readonly softmax: NullType;
        /** Sigmoid (each output independently in [0,1]) */
        readonly sigmoid: NullType;
    }>>;
    /** Dropout rate (default 0.0) */
    readonly dropout: OptionType<FloatType>;
    /**
     * Output dimension. Defaults to 1 for `torch_mlp_train`; `torch_mlp_train_multi`
     * infers it from the number of columns of its target matrix when unset.
     */
    readonly output_dim: OptionType<IntegerType>;
}>;
|
|
102
|
+
/**
 * Training hyperparameters for an MLP.
 *
 * Each field is an Option; where a default is documented below it applies
 * when the field is `none`.
 */
export declare const TorchTrainConfigType: StructType<{
    /** Number of training epochs (default 100) */
    readonly epochs: OptionType<IntegerType>;
    /** Mini-batch size (default 32) */
    readonly batch_size: OptionType<IntegerType>;
    /** Optimizer learning rate (default 0.001) */
    readonly learning_rate: OptionType<FloatType>;
    /** Objective to minimise (default mse) */
    readonly loss: OptionType<VariantType<{
        /** Mean squared error - regression */
        readonly mse: NullType;
        /** Mean absolute error - regression */
        readonly mae: NullType;
        /** Cross entropy - multi-class classification with integer class targets */
        readonly cross_entropy: NullType;
        /** Kullback-Leibler divergence - distribution matching; pair with a softmax output */
        readonly kl_div: NullType;
        /** Binary cross entropy - multi-label binary targets; requires a sigmoid output */
        readonly bce: NullType;
        /** Binary cross entropy on logits - sigmoid is applied internally (more numerically stable); do NOT combine with a sigmoid output_activation */
        readonly bce_with_logits: NullType;
    }>>;
    /** Optimizer (default adam) */
    readonly optimizer: OptionType<VariantType<{
        /** Adam */
        readonly adam: NullType;
        /** Plain stochastic gradient descent */
        readonly sgd: NullType;
        /** AdamW (Adam with decoupled weight decay) */
        readonly adamw: NullType;
        /** RMSprop */
        readonly rmsprop: NullType;
    }>>;
    /** Early-stopping patience in epochs; 0 disables early stopping */
    readonly early_stopping: OptionType<IntegerType>;
    /** Fraction of samples held out for validation (default 0.2) */
    readonly validation_split: OptionType<FloatType>;
    /** Seed for reproducible training */
    readonly random_state: OptionType<IntegerType>;
}>;
|
|
145
|
+
/**
 * Loss history and best-epoch marker produced by MLP training.
 */
export declare const TorchTrainResultType: StructType<{
    /** Training-set loss, one entry per epoch */
    readonly train_losses: VectorType<FloatType>;
    /** Validation-set loss, one entry per epoch */
    readonly val_losses: VectorType<FloatType>;
    /** Epoch with the best result (used by early stopping) */
    readonly best_epoch: IntegerType;
}>;
|
|
156
|
+
/**
 * Everything returned by an MLP training call: the serialized model plus
 * its training metrics.
 */
export declare const TorchTrainOutputType: StructType<{
    /** Trained model blob */
    readonly model: VariantType<{
        readonly torch_mlp: StructType<{
            /** Cloudpickle-serialized model */
            readonly data: BlobType;
            /** Number of input features */
            readonly n_features: IntegerType;
            /** Hidden layer sizes */
            readonly hidden_layers: ArrayType<IntegerType>;
            /** Output dimension */
            readonly output_dim: IntegerType;
        }>;
    }>;
    /** Loss history and best epoch */
    readonly result: StructType<{
        /** Training-set loss, one entry per epoch */
        readonly train_losses: VectorType<FloatType>;
        /** Validation-set loss, one entry per epoch */
        readonly val_losses: VectorType<FloatType>;
        /** Epoch with the best result (used by early stopping) */
        readonly best_epoch: IntegerType;
    }>;
}>;
|
|
179
|
+
/**
 * Serialized PyTorch model together with the shape metadata needed to use it.
 */
export declare const TorchModelBlobType: VariantType<{
    /** Multi-layer perceptron */
    readonly torch_mlp: StructType<{
        /** Model bytes, serialized with cloudpickle */
        readonly data: BlobType;
        /** Input feature count */
        readonly n_features: IntegerType;
        /** Size of each hidden layer */
        readonly hidden_layers: ArrayType<IntegerType>;
        /** Width of the output layer */
        readonly output_dim: IntegerType;
    }>;
}>;
|
|
195
|
+
/**
 * Train a PyTorch MLP model.
 *
 * @param X - Feature matrix (n_samples x n_features)
 * @param y - Target vector
 * @param mlp_config - MLP architecture configuration
 * @param train_config - Training configuration
 * @returns Model blob and training result
 */
export declare const torch_mlp_train: import("@elaraai/east").PlatformDefinition<[MatrixType<FloatType>, VectorType<FloatType>, StructType<{
    /** Hidden layer sizes, e.g., [64, 32] */
    readonly hidden_layers: ArrayType<IntegerType>;
    /** Activation function for hidden layers (default relu) */
    readonly activation: OptionType<VariantType<{
        /** Rectified Linear Unit */
        readonly relu: NullType;
        /** Hyperbolic tangent */
        readonly tanh: NullType;
        /** Sigmoid function */
        readonly sigmoid: NullType;
        /** Leaky ReLU */
        readonly leaky_relu: NullType;
    }>>;
    /** Output activation function applied to the final layer (default none/linear) */
    readonly output_activation: OptionType<VariantType<{
        /** No activation (linear output) - default */
        readonly none: NullType;
        /** Softmax (outputs sum to 1, for probability distributions) */
        readonly softmax: NullType;
        /** Sigmoid (each output independently in [0,1]) */
        readonly sigmoid: NullType;
    }>>;
    /** Dropout rate (default 0.0) */
    readonly dropout: OptionType<FloatType>;
    /** Output dimension (default 1) */
    readonly output_dim: OptionType<IntegerType>;
}>, StructType<{
    /** Number of epochs (default 100) */
    readonly epochs: OptionType<IntegerType>;
    /** Batch size (default 32) */
    readonly batch_size: OptionType<IntegerType>;
    /** Learning rate (default 0.001) */
    readonly learning_rate: OptionType<FloatType>;
    /** Loss function (default mse) */
    readonly loss: OptionType<VariantType<{
        /** Mean Squared Error (regression) */
        readonly mse: NullType;
        /** Mean Absolute Error (regression) */
        readonly mae: NullType;
        /** Cross Entropy (multi-class classification with integer targets) */
        readonly cross_entropy: NullType;
        /** KL Divergence (distribution matching, use with softmax output) */
        readonly kl_div: NullType;
        /** Binary Cross Entropy (multi-label binary, requires sigmoid output) */
        readonly bce: NullType;
        /** Binary Cross Entropy with Logits (more stable, applies sigmoid internally - do NOT use with sigmoid output_activation) */
        readonly bce_with_logits: NullType;
    }>>;
    /** Optimizer (default adam) */
    readonly optimizer: OptionType<VariantType<{
        /** Adam optimizer */
        readonly adam: NullType;
        /** Stochastic Gradient Descent */
        readonly sgd: NullType;
        /** AdamW with weight decay */
        readonly adamw: NullType;
        /** RMSprop optimizer */
        readonly rmsprop: NullType;
    }>>;
    /** Early stopping patience, 0 = disabled */
    readonly early_stopping: OptionType<IntegerType>;
    /** Validation split fraction (default 0.2) */
    readonly validation_split: OptionType<FloatType>;
    /** Random seed for reproducibility */
    readonly random_state: OptionType<IntegerType>;
}>], StructType<{
    /** Trained model blob */
    readonly model: VariantType<{
        readonly torch_mlp: StructType<{
            /** Cloudpickle serialized model */
            readonly data: BlobType;
            /** Number of input features */
            readonly n_features: IntegerType;
            /** Hidden layer sizes */
            readonly hidden_layers: ArrayType<IntegerType>;
            /** Output dimension */
            readonly output_dim: IntegerType;
        }>;
    }>;
    /** Training result with losses */
    readonly result: StructType<{
        /** Training loss per epoch */
        readonly train_losses: VectorType<FloatType>;
        /** Validation loss per epoch */
        readonly val_losses: VectorType<FloatType>;
        /** Best epoch (for early stopping) */
        readonly best_epoch: IntegerType;
    }>;
}>>;
|
|
290
|
+
/**
 * Run a trained PyTorch MLP forward on new data.
 *
 * @param model - Serialized MLP, e.g. the `model` field returned by `torch_mlp_train`
 * @param X - Feature matrix to predict for
 * @returns Vector of predicted values
 */
export declare const torch_mlp_predict: import("@elaraai/east").PlatformDefinition<[VariantType<{
    /** Multi-layer perceptron */
    readonly torch_mlp: StructType<{
        /** Model bytes, serialized with cloudpickle */
        readonly data: BlobType;
        /** Input feature count */
        readonly n_features: IntegerType;
        /** Size of each hidden layer */
        readonly hidden_layers: ArrayType<IntegerType>;
        /** Width of the output layer */
        readonly output_dim: IntegerType;
    }>;
}>, MatrixType<FloatType>], VectorType<FloatType>>;
|
|
310
|
+
/**
 * Train a PyTorch MLP model with multi-output support.
 *
 * Supports multi-output regression (predicting multiple values per sample)
 * and autoencoders (where input equals target for reconstruction learning).
 * Output dimension is inferred from y.shape[1] unless overridden in config.
 *
 * @param X - Feature matrix (n_samples x n_features)
 * @param y - Target matrix (n_samples x n_outputs)
 * @param mlp_config - MLP architecture configuration
 * @param train_config - Training configuration
 * @returns Model blob and training result
 */
export declare const torch_mlp_train_multi: import("@elaraai/east").PlatformDefinition<[MatrixType<FloatType>, MatrixType<FloatType>, StructType<{
    /** Hidden layer sizes, e.g., [64, 32] */
    readonly hidden_layers: ArrayType<IntegerType>;
    /** Activation function for hidden layers (default relu) */
    readonly activation: OptionType<VariantType<{
        /** Rectified Linear Unit */
        readonly relu: NullType;
        /** Hyperbolic tangent */
        readonly tanh: NullType;
        /** Sigmoid function */
        readonly sigmoid: NullType;
        /** Leaky ReLU */
        readonly leaky_relu: NullType;
    }>>;
    /** Output activation function applied to the final layer (default none/linear) */
    readonly output_activation: OptionType<VariantType<{
        /** No activation (linear output) - default */
        readonly none: NullType;
        /** Softmax (outputs sum to 1, for probability distributions) */
        readonly softmax: NullType;
        /** Sigmoid (each output independently in [0,1]) */
        readonly sigmoid: NullType;
    }>>;
    /** Dropout rate (default 0.0) */
    readonly dropout: OptionType<FloatType>;
    /** Output dimension (default: number of columns of y) */
    readonly output_dim: OptionType<IntegerType>;
}>, StructType<{
    /** Number of epochs (default 100) */
    readonly epochs: OptionType<IntegerType>;
    /** Batch size (default 32) */
    readonly batch_size: OptionType<IntegerType>;
    /** Learning rate (default 0.001) */
    readonly learning_rate: OptionType<FloatType>;
    /** Loss function (default mse) */
    readonly loss: OptionType<VariantType<{
        /** Mean Squared Error (regression) */
        readonly mse: NullType;
        /** Mean Absolute Error (regression) */
        readonly mae: NullType;
        /** Cross Entropy (multi-class classification with integer targets) */
        readonly cross_entropy: NullType;
        /** KL Divergence (distribution matching, use with softmax output) */
        readonly kl_div: NullType;
        /** Binary Cross Entropy (multi-label binary, requires sigmoid output) */
        readonly bce: NullType;
        /** Binary Cross Entropy with Logits (more stable, applies sigmoid internally - do NOT use with sigmoid output_activation) */
        readonly bce_with_logits: NullType;
    }>>;
    /** Optimizer (default adam) */
    readonly optimizer: OptionType<VariantType<{
        /** Adam optimizer */
        readonly adam: NullType;
        /** Stochastic Gradient Descent */
        readonly sgd: NullType;
        /** AdamW with weight decay */
        readonly adamw: NullType;
        /** RMSprop optimizer */
        readonly rmsprop: NullType;
    }>>;
    /** Early stopping patience, 0 = disabled */
    readonly early_stopping: OptionType<IntegerType>;
    /** Validation split fraction (default 0.2) */
    readonly validation_split: OptionType<FloatType>;
    /** Random seed for reproducibility */
    readonly random_state: OptionType<IntegerType>;
}>], StructType<{
    /** Trained model blob */
    readonly model: VariantType<{
        readonly torch_mlp: StructType<{
            /** Cloudpickle serialized model */
            readonly data: BlobType;
            /** Number of input features */
            readonly n_features: IntegerType;
            /** Hidden layer sizes */
            readonly hidden_layers: ArrayType<IntegerType>;
            /** Output dimension */
            readonly output_dim: IntegerType;
        }>;
    }>;
    /** Training result with losses */
    readonly result: StructType<{
        /** Training loss per epoch */
        readonly train_losses: VectorType<FloatType>;
        /** Validation loss per epoch */
        readonly val_losses: VectorType<FloatType>;
        /** Best epoch (for early stopping) */
        readonly best_epoch: IntegerType;
    }>;
}>>;
|
|
409
|
+
/**
 * Multi-output prediction with a trained PyTorch MLP.
 *
 * The result has one row per input sample; each row holds that sample's
 * predicted outputs.
 *
 * @param model - Serialized MLP, e.g. the `model` field returned by `torch_mlp_train_multi`
 * @param X - Feature matrix, n_samples x n_features
 * @returns Prediction matrix, n_samples x n_outputs
 */
export declare const torch_mlp_predict_multi: import("@elaraai/east").PlatformDefinition<[VariantType<{
    /** Multi-layer perceptron */
    readonly torch_mlp: StructType<{
        /** Model bytes, serialized with cloudpickle */
        readonly data: BlobType;
        /** Input feature count */
        readonly n_features: IntegerType;
        /** Size of each hidden layer */
        readonly hidden_layers: ArrayType<IntegerType>;
        /** Width of the output layer */
        readonly output_dim: IntegerType;
    }>;
}>, MatrixType<FloatType>], MatrixType<FloatType>>;
|
|
431
|
+
/**
 * Extract intermediate layer activations (embeddings) from a trained MLP.
 *
 * For autoencoders, this allows extracting the bottleneck representation.
 * The layer_index specifies which hidden layer's output to return (0-indexed).
 *
 * For an autoencoder with architecture [input -> 8 -> 2 -> 8 -> output]
 * (hidden_layers: [8, 2, 8]):
 * - layer_index=0: output after first hidden layer (8 features)
 * - layer_index=1: output after second hidden layer (2 features) <- bottleneck
 * - layer_index=2: output after third hidden layer (8 features)
 *
 * @param model - Trained MLP model blob
 * @param X - Feature matrix (n_samples x n_features)
 * @param layer_index - Which hidden layer's output to return (0-indexed)
 * @returns Embedding matrix (n_samples x hidden_dim at that layer)
 *
 * @example
 * ```ts
 * // Train autoencoder: 4 features -> 8 -> 2 (bottleneck) -> 8 -> 4 features
 * // Every TorchMLPConfigType field is supplied; Option fields use
 * // variant('none', null) for their default, and payload-free (NullType)
 * // variant cases take a null payload.
 * const mlp_config = $.let({
 *     hidden_layers: [8n, 2n, 8n],
 *     activation: variant('some', variant('relu', null)),
 *     output_activation: variant('none', null),
 *     dropout: variant('none', null),
 *     output_dim: variant('none', null),
 * });
 * // train_config: a TorchTrainConfigType value defined earlier
 * const output = $.let(Torch.mlpTrainMulti(X, X, mlp_config, train_config));
 *
 * // Extract bottleneck embeddings (layer_index=1 for the 2-dim bottleneck)
 * const embeddings = $.let(Torch.mlpEncode(output.model, X, 1n));
 * // embeddings is now (n_samples x 2)
 * ```
 */
export declare const torch_mlp_encode: import("@elaraai/east").PlatformDefinition<[VariantType<{
    /** PyTorch MLP model */
    readonly torch_mlp: StructType<{
        /** Cloudpickle serialized model */
        readonly data: BlobType;
        /** Number of input features */
        readonly n_features: IntegerType;
        /** Hidden layer sizes */
        readonly hidden_layers: ArrayType<IntegerType>;
        /** Output dimension */
        readonly output_dim: IntegerType;
    }>;
}>, MatrixType<FloatType>, IntegerType], MatrixType<FloatType>>;
|
|
477
|
+
/**
 * Feed hidden-layer activations (embeddings) through the rest of an MLP.
 *
 * This is the complement to mlpEncode: given activations taken at hidden
 * layer `layer_index`, the remaining hidden layers and the output layer are
 * applied. For an autoencoder this turns bottleneck codes back into
 * reconstructed outputs.
 *
 * With hidden_layers [8, 2, 8] (architecture input -> 8 -> 2 -> 8 -> output):
 * - layer_index=0: embeddings are the 8-dim first-layer activations; layers 1+ and the output layer are applied
 * - layer_index=1: embeddings are the 2-dim bottleneck; layers 2+ and the output layer are applied
 *
 * Typical use: form a weighted average of several encoded rows (for example,
 * a blend of origins) and decode it to obtain the reconstructed blend
 * weight distribution.
 *
 * @param model - Trained MLP model blob
 * @param embeddings - Activations at `layer_index` (n_samples x hidden_dim at that layer)
 * @param layer_index - 0-based index of the hidden layer the embeddings come from
 * @returns Decoded output matrix (n_samples x output_dim)
 *
 * @example
 * ```ts
 * // After training autoencoder and extracting embeddings...
 * const origin_embeddings = $.let(Torch.mlpEncode(output.model, X_onehot, 1n));
 *
 * // Compute weighted blend embedding (e.g., 50% origin A + 50% origin B)
 * const blend_embedding = $.let(...); // weighted average of origin embeddings
 *
 * // Decode back to weight distribution
 * const reconstructed = $.let(Torch.mlpDecode(output.model, blend_embedding, 1n));
 * ```
 */
export declare const torch_mlp_decode: import("@elaraai/east").PlatformDefinition<[VariantType<{
    /** Multi-layer perceptron */
    readonly torch_mlp: StructType<{
        /** Model bytes, serialized with cloudpickle */
        readonly data: BlobType;
        /** Input feature count */
        readonly n_features: IntegerType;
        /** Size of each hidden layer */
        readonly hidden_layers: ArrayType<IntegerType>;
        /** Width of the output layer */
        readonly output_dim: IntegerType;
    }>;
}>, MatrixType<FloatType>, IntegerType], MatrixType<FloatType>>;
|
|
521
|
+
/**
 * Type definitions for PyTorch functions.
 *
 * The same descriptors are also exposed as `Torch.Types`.
 */
export declare const TorchTypes: {
    /** Activation function type for hidden layers */
    readonly TorchActivationType: VariantType<{
        /** Rectified Linear Unit */
        readonly relu: NullType;
        /** Hyperbolic tangent */
        readonly tanh: NullType;
        /** Sigmoid function */
        readonly sigmoid: NullType;
        /** Leaky ReLU */
        readonly leaky_relu: NullType;
    }>;
    /** Output activation function type */
    readonly TorchOutputActivationType: VariantType<{
        /** No activation (linear output) - default */
        readonly none: NullType;
        /** Softmax (outputs sum to 1, for probability distributions) */
        readonly softmax: NullType;
        /** Sigmoid (each output independently in [0,1]) */
        readonly sigmoid: NullType;
    }>;
    /** Loss function type */
    readonly TorchLossType: VariantType<{
        /** Mean Squared Error (regression) */
        readonly mse: NullType;
        /** Mean Absolute Error (regression) */
        readonly mae: NullType;
        /** Cross Entropy (multi-class classification with integer targets) */
        readonly cross_entropy: NullType;
        /** KL Divergence (distribution matching, use with softmax output) */
        readonly kl_div: NullType;
        /** Binary Cross Entropy (multi-label binary, requires sigmoid output) */
        readonly bce: NullType;
        /** Binary Cross Entropy with Logits (more stable, applies sigmoid internally - do NOT use with sigmoid output_activation) */
        readonly bce_with_logits: NullType;
    }>;
    /** Optimizer type */
    readonly TorchOptimizerType: VariantType<{
        /** Adam optimizer */
        readonly adam: NullType;
        /** Stochastic Gradient Descent */
        readonly sgd: NullType;
        /** AdamW with weight decay */
        readonly adamw: NullType;
        /** RMSprop optimizer */
        readonly rmsprop: NullType;
    }>;
    /** MLP configuration type */
    readonly TorchMLPConfigType: StructType<{
        /** Hidden layer sizes, e.g., [64, 32] */
        readonly hidden_layers: ArrayType<IntegerType>;
        /** Activation function for hidden layers (default relu) */
        readonly activation: OptionType<VariantType<{
            /** Rectified Linear Unit */
            readonly relu: NullType;
            /** Hyperbolic tangent */
            readonly tanh: NullType;
            /** Sigmoid function */
            readonly sigmoid: NullType;
            /** Leaky ReLU */
            readonly leaky_relu: NullType;
        }>>;
        /**
         * Output activation function (default none/linear).
         * Pair softmax with the kl_div loss and sigmoid with the bce loss; do not
         * combine sigmoid with bce_with_logits.
         */
        readonly output_activation: OptionType<VariantType<{
            /** No activation (linear output) - default */
            readonly none: NullType;
            /** Softmax (outputs sum to 1, for probability distributions) */
            readonly softmax: NullType;
            /** Sigmoid (each output independently in [0,1]) */
            readonly sigmoid: NullType;
        }>>;
        /** Dropout rate (default 0.0) */
        readonly dropout: OptionType<FloatType>;
        /** Output dimension (default 1) */
        readonly output_dim: OptionType<IntegerType>;
    }>;
    /** Training configuration type */
    readonly TorchTrainConfigType: StructType<{
        /** Number of epochs (default 100) */
        readonly epochs: OptionType<IntegerType>;
        /** Batch size (default 32) */
        readonly batch_size: OptionType<IntegerType>;
        /** Learning rate (default 0.001) */
        readonly learning_rate: OptionType<FloatType>;
        /** Loss function (default mse) */
        readonly loss: OptionType<VariantType<{
            /** Mean Squared Error (regression) */
            readonly mse: NullType;
            /** Mean Absolute Error (regression) */
            readonly mae: NullType;
            /** Cross Entropy (multi-class classification with integer targets) */
            readonly cross_entropy: NullType;
            /** KL Divergence (distribution matching, use with softmax output) */
            readonly kl_div: NullType;
            /** Binary Cross Entropy (multi-label binary, requires sigmoid output) */
            readonly bce: NullType;
            /** Binary Cross Entropy with Logits (more stable, applies sigmoid internally - do NOT use with sigmoid output_activation) */
            readonly bce_with_logits: NullType;
        }>>;
        /** Optimizer (default adam) */
        readonly optimizer: OptionType<VariantType<{
            /** Adam optimizer */
            readonly adam: NullType;
            /** Stochastic Gradient Descent */
            readonly sgd: NullType;
            /** AdamW with weight decay */
            readonly adamw: NullType;
            /** RMSprop optimizer */
            readonly rmsprop: NullType;
        }>>;
        /** Early stopping patience, 0 = disabled */
        readonly early_stopping: OptionType<IntegerType>;
        /** Validation split fraction (default 0.2) */
        readonly validation_split: OptionType<FloatType>;
        /** Random seed for reproducibility */
        readonly random_state: OptionType<IntegerType>;
    }>;
    /** Training result type */
    readonly TorchTrainResultType: StructType<{
        /** Training loss per epoch */
        readonly train_losses: VectorType<FloatType>;
        /** Validation loss per epoch */
        readonly val_losses: VectorType<FloatType>;
        /** Best epoch (for early stopping) */
        readonly best_epoch: IntegerType;
    }>;
    /** Training output type (model + result) */
    readonly TorchTrainOutputType: StructType<{
        /** Trained model blob (same shape as ModelBlobType) */
        readonly model: VariantType<{
            /** PyTorch MLP model */
            readonly torch_mlp: StructType<{
                /** Cloudpickle serialized model */
                readonly data: BlobType;
                /** Number of input features */
                readonly n_features: IntegerType;
                /** Hidden layer sizes */
                readonly hidden_layers: ArrayType<IntegerType>;
                /** Output dimension */
                readonly output_dim: IntegerType;
            }>;
        }>;
        /** Training result with losses */
        readonly result: StructType<{
            /** Training loss per epoch */
            readonly train_losses: VectorType<FloatType>;
            /** Validation loss per epoch */
            readonly val_losses: VectorType<FloatType>;
            /** Best epoch (for early stopping) */
            readonly best_epoch: IntegerType;
        }>;
    }>;
    /** Model blob type for PyTorch models */
    readonly ModelBlobType: VariantType<{
        /** PyTorch MLP model */
        readonly torch_mlp: StructType<{
            /** Cloudpickle serialized model */
            readonly data: BlobType;
            /** Number of input features */
            readonly n_features: IntegerType;
            /** Hidden layer sizes */
            readonly hidden_layers: ArrayType<IntegerType>;
            /** Output dimension */
            readonly output_dim: IntegerType;
        }>;
    }>;
};
|
|
686
|
+
/**
 * PyTorch neural network models.
 *
 * Provides MLP training and inference using PyTorch.
 *
 * @example
 * ```ts
 * import { East, variant } from "@elaraai/east";
 * import { Torch } from "@elaraai/east-py-datascience";
 *
 * const train = East.function([], Torch.Types.TorchTrainOutputType, $ => {
 *     const X = $.let([[1.0, 2.0], [3.0, 4.0], [5.0, 6.0], [7.0, 8.0]]);
 *     const y = $.let(new Float64Array([1.0, 2.0, 3.0, 4.0]));
 *     const mlp_config = $.let({
 *         hidden_layers: [32n, 16n],
 *         activation: variant('none', null),
 *         output_activation: variant('none', null),
 *         dropout: variant('none', null),
 *         output_dim: variant('none', null),
 *     }, Torch.Types.TorchMLPConfigType);
 *     const train_config = $.let({
 *         epochs: variant('some', 50n),
 *         batch_size: variant('some', 4n),
 *         learning_rate: variant('some', 0.01),
 *         loss: variant('none', null),
 *         optimizer: variant('none', null),
 *         early_stopping: variant('none', null),
 *         validation_split: variant('some', 0.2),
 *         random_state: variant('some', 42n),
 *     }, Torch.Types.TorchTrainConfigType);
 *     return $.return(Torch.mlpTrain(X, y, mlp_config, train_config));
 * });
 * ```
 */
export declare const Torch: {
    /**
     * Train a PyTorch MLP model (single output).
     *
     * @param X - Feature matrix, one row per sample
     * @param y - Target vector, one value per sample
     * @param mlp_config - Network architecture (TorchMLPConfigType)
     * @param train_config - Training hyperparameters (TorchTrainConfigType)
     * @returns The trained model blob and per-epoch training/validation losses
     *
     * @example
     * ```ts
     * import { East, variant } from "@elaraai/east";
     * import { Torch } from "@elaraai/east-py-datascience";
     *
     * const train = East.function([], Torch.Types.TorchTrainOutputType, ($) => {
     *     const X = $.let([[1.0, 2.0], [3.0, 4.0], [5.0, 6.0], [7.0, 8.0]]);
     *     const y = $.let(new Float64Array([1.0, 2.0, 3.0, 4.0]));
     *     const mlp_config = $.let({
     *         hidden_layers: new BigInt64Array([16n, 8n]),
     *         activation: variant("some", variant("relu", null)),
     *         dropout: variant("none", null),
     *         output_dim: variant("none", null),
     *         output_activation: variant("none", null),
     *     }, Torch.Types.TorchMLPConfigType);
     *     const train_config = $.let({
     *         epochs: variant("some", 50n), batch_size: variant("some", 4n),
     *         learning_rate: variant("some", 0.01), loss: variant("none", null),
     *         optimizer: variant("none", null), early_stopping: variant("none", null),
     *         validation_split: variant("some", 0.2), random_state: variant("some", 42n),
     *     }, Torch.Types.TorchTrainConfigType);
     *     return $.return(Torch.mlpTrain(X, y, mlp_config, train_config));
     * });
     * ```
     */
    readonly mlpTrain: import("@elaraai/east").PlatformDefinition<[MatrixType<FloatType>, VectorType<FloatType>, StructType<{
        /** Hidden layer sizes, e.g., [64, 32] */
        readonly hidden_layers: ArrayType<IntegerType>;
        /** Activation function for hidden layers (default relu) */
        readonly activation: OptionType<VariantType<{
            /** Rectified Linear Unit */
            readonly relu: NullType;
            /** Hyperbolic tangent */
            readonly tanh: NullType;
            /** Sigmoid function */
            readonly sigmoid: NullType;
            /** Leaky ReLU */
            readonly leaky_relu: NullType;
        }>>;
        /** Output activation function (default none/linear) */
        readonly output_activation: OptionType<VariantType<{
            /** No activation (linear output) - default */
            readonly none: NullType;
            /** Softmax (outputs sum to 1, for probability distributions) */
            readonly softmax: NullType;
            /** Sigmoid (each output independently in [0,1]) */
            readonly sigmoid: NullType;
        }>>;
        /** Dropout rate (default 0.0) */
        readonly dropout: OptionType<FloatType>;
        /** Output dimension (default 1) */
        readonly output_dim: OptionType<IntegerType>;
    }>, StructType<{
        /** Number of epochs (default 100) */
        readonly epochs: OptionType<IntegerType>;
        /** Batch size (default 32) */
        readonly batch_size: OptionType<IntegerType>;
        /** Learning rate (default 0.001) */
        readonly learning_rate: OptionType<FloatType>;
        /** Loss function (default mse) */
        readonly loss: OptionType<VariantType<{
            /** Mean Squared Error (regression) */
            readonly mse: NullType;
            /** Mean Absolute Error (regression) */
            readonly mae: NullType;
            /** Cross Entropy (multi-class classification with integer targets) */
            readonly cross_entropy: NullType;
            /** KL Divergence (distribution matching, use with softmax output) */
            readonly kl_div: NullType;
            /** Binary Cross Entropy (multi-label binary, requires sigmoid output) */
            readonly bce: NullType;
            /** Binary Cross Entropy with Logits (more stable, applies sigmoid internally - do NOT use with sigmoid output_activation) */
            readonly bce_with_logits: NullType;
        }>>;
        /** Optimizer (default adam) */
        readonly optimizer: OptionType<VariantType<{
            /** Adam optimizer */
            readonly adam: NullType;
            /** Stochastic Gradient Descent */
            readonly sgd: NullType;
            /** AdamW with weight decay */
            readonly adamw: NullType;
            /** RMSprop optimizer */
            readonly rmsprop: NullType;
        }>>;
        /** Early stopping patience, 0 = disabled */
        readonly early_stopping: OptionType<IntegerType>;
        /** Validation split fraction (default 0.2) */
        readonly validation_split: OptionType<FloatType>;
        /** Random seed for reproducibility */
        readonly random_state: OptionType<IntegerType>;
    }>], StructType<{
        /** Trained model blob */
        readonly model: VariantType<{
            /** PyTorch MLP model */
            readonly torch_mlp: StructType<{
                /** Cloudpickle serialized model */
                readonly data: BlobType;
                /** Number of input features */
                readonly n_features: IntegerType;
                /** Hidden layer sizes */
                readonly hidden_layers: ArrayType<IntegerType>;
                /** Output dimension */
                readonly output_dim: IntegerType;
            }>;
        }>;
        /** Training result with losses */
        readonly result: StructType<{
            /** Training loss per epoch */
            readonly train_losses: VectorType<FloatType>;
            /** Validation loss per epoch */
            readonly val_losses: VectorType<FloatType>;
            /** Best epoch (for early stopping) */
            readonly best_epoch: IntegerType;
        }>;
    }>>;
    /**
     * Make predictions with a trained PyTorch MLP (single output).
     *
     * @param model - Trained MLP model blob
     * @param X - Feature matrix, one row per sample
     * @returns One prediction per sample
     *
     * @example
     * ```ts
     * import { East, FloatType, MatrixType, VectorType } from "@elaraai/east";
     * import { Torch } from "@elaraai/east-py-datascience";
     *
     * const predictFn = East.function(
     *     [Torch.Types.ModelBlobType, MatrixType(FloatType)],
     *     VectorType(FloatType),
     *     ($, model, X) => {
     *         return $.return(Torch.mlpPredict(model, X));
     *     }
     * );
     * ```
     */
    readonly mlpPredict: import("@elaraai/east").PlatformDefinition<[VariantType<{
        /** PyTorch MLP model */
        readonly torch_mlp: StructType<{
            /** Cloudpickle serialized model */
            readonly data: BlobType;
            /** Number of input features */
            readonly n_features: IntegerType;
            /** Hidden layer sizes */
            readonly hidden_layers: ArrayType<IntegerType>;
            /** Output dimension */
            readonly output_dim: IntegerType;
        }>;
    }>, MatrixType<FloatType>], VectorType<FloatType>>;
    /**
     * Train a PyTorch MLP model with multi-output support.
     *
     * Supports multi-output regression and autoencoders (X = y).
     *
     * @param X - Feature matrix, one row per sample
     * @param y - Target matrix, one row per sample
     * @param mlp_config - Network architecture (TorchMLPConfigType)
     * @param train_config - Training hyperparameters (TorchTrainConfigType)
     * @returns The trained model blob and per-epoch training/validation losses
     *
     * @example
     * ```ts
     * import { East, variant } from "@elaraai/east";
     * import { Torch } from "@elaraai/east-py-datascience";
     *
     * const train = East.function([], Torch.Types.TorchTrainOutputType, ($) => {
     *     const X = $.let([[1.0, 2.0, 3.0], [4.0, 5.0, 6.0], [7.0, 8.0, 9.0]]);
     *     const y = $.let([[10.0, 20.0], [30.0, 40.0], [50.0, 60.0]]);
     *     const mlp_config = $.let({
     *         hidden_layers: new BigInt64Array([16n, 8n]),
     *         activation: variant("some", variant("relu", null)),
     *         dropout: variant("none", null),
     *         output_dim: variant("none", null),
     *         output_activation: variant("none", null),
     *     }, Torch.Types.TorchMLPConfigType);
     *     const train_config = $.let({
     *         epochs: variant("some", 50n), batch_size: variant("some", 4n),
     *         learning_rate: variant("some", 0.01), loss: variant("none", null),
     *         optimizer: variant("none", null), early_stopping: variant("none", null),
     *         validation_split: variant("some", 0.2), random_state: variant("some", 42n),
     *     }, Torch.Types.TorchTrainConfigType);
     *     return $.return(Torch.mlpTrainMulti(X, y, mlp_config, train_config));
     * });
     * ```
     */
    readonly mlpTrainMulti: import("@elaraai/east").PlatformDefinition<[MatrixType<FloatType>, MatrixType<FloatType>, StructType<{
        /** Hidden layer sizes, e.g., [64, 32] */
        readonly hidden_layers: ArrayType<IntegerType>;
        /** Activation function for hidden layers (default relu) */
        readonly activation: OptionType<VariantType<{
            /** Rectified Linear Unit */
            readonly relu: NullType;
            /** Hyperbolic tangent */
            readonly tanh: NullType;
            /** Sigmoid function */
            readonly sigmoid: NullType;
            /** Leaky ReLU */
            readonly leaky_relu: NullType;
        }>>;
        /** Output activation function (default none/linear) */
        readonly output_activation: OptionType<VariantType<{
            /** No activation (linear output) - default */
            readonly none: NullType;
            /** Softmax (outputs sum to 1, for probability distributions) */
            readonly softmax: NullType;
            /** Sigmoid (each output independently in [0,1]) */
            readonly sigmoid: NullType;
        }>>;
        /** Dropout rate (default 0.0) */
        readonly dropout: OptionType<FloatType>;
        /** Output dimension (default 1) */
        readonly output_dim: OptionType<IntegerType>;
    }>, StructType<{
        /** Number of epochs (default 100) */
        readonly epochs: OptionType<IntegerType>;
        /** Batch size (default 32) */
        readonly batch_size: OptionType<IntegerType>;
        /** Learning rate (default 0.001) */
        readonly learning_rate: OptionType<FloatType>;
        /** Loss function (default mse) */
        readonly loss: OptionType<VariantType<{
            /** Mean Squared Error (regression) */
            readonly mse: NullType;
            /** Mean Absolute Error (regression) */
            readonly mae: NullType;
            /** Cross Entropy (multi-class classification with integer targets) */
            readonly cross_entropy: NullType;
            /** KL Divergence (distribution matching, use with softmax output) */
            readonly kl_div: NullType;
            /** Binary Cross Entropy (multi-label binary, requires sigmoid output) */
            readonly bce: NullType;
            /** Binary Cross Entropy with Logits (more stable, applies sigmoid internally - do NOT use with sigmoid output_activation) */
            readonly bce_with_logits: NullType;
        }>>;
        /** Optimizer (default adam) */
        readonly optimizer: OptionType<VariantType<{
            /** Adam optimizer */
            readonly adam: NullType;
            /** Stochastic Gradient Descent */
            readonly sgd: NullType;
            /** AdamW with weight decay */
            readonly adamw: NullType;
            /** RMSprop optimizer */
            readonly rmsprop: NullType;
        }>>;
        /** Early stopping patience, 0 = disabled */
        readonly early_stopping: OptionType<IntegerType>;
        /** Validation split fraction (default 0.2) */
        readonly validation_split: OptionType<FloatType>;
        /** Random seed for reproducibility */
        readonly random_state: OptionType<IntegerType>;
    }>], StructType<{
        /** Trained model blob */
        readonly model: VariantType<{
            /** PyTorch MLP model */
            readonly torch_mlp: StructType<{
                /** Cloudpickle serialized model */
                readonly data: BlobType;
                /** Number of input features */
                readonly n_features: IntegerType;
                /** Hidden layer sizes */
                readonly hidden_layers: ArrayType<IntegerType>;
                /** Output dimension */
                readonly output_dim: IntegerType;
            }>;
        }>;
        /** Training result with losses */
        readonly result: StructType<{
            /** Training loss per epoch */
            readonly train_losses: VectorType<FloatType>;
            /** Validation loss per epoch */
            readonly val_losses: VectorType<FloatType>;
            /** Best epoch (for early stopping) */
            readonly best_epoch: IntegerType;
        }>;
    }>>;
    /**
     * Make predictions with a trained PyTorch MLP (multi-output).
     *
     * Returns a matrix where each row contains predicted outputs for a sample.
     *
     * @param model - Trained MLP model blob
     * @param X - Feature matrix, one row per sample
     * @returns Prediction matrix (n_samples x output_dim)
     *
     * @example
     * ```ts
     * import { East, FloatType, MatrixType } from "@elaraai/east";
     * import { Torch } from "@elaraai/east-py-datascience";
     *
     * const predictFn = East.function(
     *     [Torch.Types.ModelBlobType, MatrixType(FloatType)],
     *     MatrixType(FloatType),
     *     ($, model, X) => {
     *         // Returns (n_samples x n_outputs) matrix
     *         return $.return(Torch.mlpPredictMulti(model, X));
     *     }
     * );
     * ```
     */
    readonly mlpPredictMulti: import("@elaraai/east").PlatformDefinition<[VariantType<{
        /** PyTorch MLP model */
        readonly torch_mlp: StructType<{
            /** Cloudpickle serialized model */
            readonly data: BlobType;
            /** Number of input features */
            readonly n_features: IntegerType;
            /** Hidden layer sizes */
            readonly hidden_layers: ArrayType<IntegerType>;
            /** Output dimension */
            readonly output_dim: IntegerType;
        }>;
    }>, MatrixType<FloatType>], MatrixType<FloatType>>;
    /**
     * Extract intermediate layer activations (embeddings) from MLP.
     *
     * @param model - Trained MLP model blob
     * @param X - Input matrix, one row per sample
     * @param layer_index - Which hidden layer to read activations from (0-indexed);
     *   pass the same index to `mlpDecode` to decode these embeddings
     * @returns Embedding matrix (n_samples x hidden layer size at layer_index)
     */
    readonly mlpEncode: import("@elaraai/east").PlatformDefinition<[VariantType<{
        /** PyTorch MLP model */
        readonly torch_mlp: StructType<{
            /** Cloudpickle serialized model */
            readonly data: BlobType;
            /** Number of input features */
            readonly n_features: IntegerType;
            /** Hidden layer sizes */
            readonly hidden_layers: ArrayType<IntegerType>;
            /** Output dimension */
            readonly output_dim: IntegerType;
        }>;
    }>, MatrixType<FloatType>, IntegerType], MatrixType<FloatType>>;
    /**
     * Decode embeddings back through decoder portion of MLP.
     *
     * @param model - Trained MLP model blob
     * @param embeddings - Embedding matrix (n_samples x hidden_dim at layer_index)
     * @param layer_index - Which hidden layer the embeddings come from (0-indexed)
     * @returns Decoded output matrix (n_samples x output_dim)
     */
    readonly mlpDecode: import("@elaraai/east").PlatformDefinition<[VariantType<{
        /** PyTorch MLP model */
        readonly torch_mlp: StructType<{
            /** Cloudpickle serialized model */
            readonly data: BlobType;
            /** Number of input features */
            readonly n_features: IntegerType;
            /** Hidden layer sizes */
            readonly hidden_layers: ArrayType<IntegerType>;
            /** Output dimension */
            readonly output_dim: IntegerType;
        }>;
    }>, MatrixType<FloatType>, IntegerType], MatrixType<FloatType>>;
    /** Type definitions */
    readonly Types: {
        /** Activation function type for hidden layers */
        readonly TorchActivationType: VariantType<{
            /** Rectified Linear Unit */
            readonly relu: NullType;
            /** Hyperbolic tangent */
            readonly tanh: NullType;
            /** Sigmoid function */
            readonly sigmoid: NullType;
            /** Leaky ReLU */
            readonly leaky_relu: NullType;
        }>;
        /** Output activation function type */
        readonly TorchOutputActivationType: VariantType<{
            /** No activation (linear output) - default */
            readonly none: NullType;
            /** Softmax (outputs sum to 1, for probability distributions) */
            readonly softmax: NullType;
            /** Sigmoid (each output independently in [0,1]) */
            readonly sigmoid: NullType;
        }>;
        /** Loss function type */
        readonly TorchLossType: VariantType<{
            /** Mean Squared Error (regression) */
            readonly mse: NullType;
            /** Mean Absolute Error (regression) */
            readonly mae: NullType;
            /** Cross Entropy (multi-class classification with integer targets) */
            readonly cross_entropy: NullType;
            /** KL Divergence (distribution matching, use with softmax output) */
            readonly kl_div: NullType;
            /** Binary Cross Entropy (multi-label binary, requires sigmoid output) */
            readonly bce: NullType;
            /** Binary Cross Entropy with Logits (more stable, applies sigmoid internally - do NOT use with sigmoid output_activation) */
            readonly bce_with_logits: NullType;
        }>;
        /** Optimizer type */
        readonly TorchOptimizerType: VariantType<{
            /** Adam optimizer */
            readonly adam: NullType;
            /** Stochastic Gradient Descent */
            readonly sgd: NullType;
            /** AdamW with weight decay */
            readonly adamw: NullType;
            /** RMSprop optimizer */
            readonly rmsprop: NullType;
        }>;
        /** MLP configuration type */
        readonly TorchMLPConfigType: StructType<{
            /** Hidden layer sizes, e.g., [64, 32] */
            readonly hidden_layers: ArrayType<IntegerType>;
            /** Activation function for hidden layers (default relu) */
            readonly activation: OptionType<VariantType<{
                /** Rectified Linear Unit */
                readonly relu: NullType;
                /** Hyperbolic tangent */
                readonly tanh: NullType;
                /** Sigmoid function */
                readonly sigmoid: NullType;
                /** Leaky ReLU */
                readonly leaky_relu: NullType;
            }>>;
            /** Output activation function (default none/linear) */
            readonly output_activation: OptionType<VariantType<{
                /** No activation (linear output) - default */
                readonly none: NullType;
                /** Softmax (outputs sum to 1, for probability distributions) */
                readonly softmax: NullType;
                /** Sigmoid (each output independently in [0,1]) */
                readonly sigmoid: NullType;
            }>>;
            /** Dropout rate (default 0.0) */
            readonly dropout: OptionType<FloatType>;
            /** Output dimension (default 1) */
            readonly output_dim: OptionType<IntegerType>;
        }>;
        /** Training configuration type */
        readonly TorchTrainConfigType: StructType<{
            /** Number of epochs (default 100) */
            readonly epochs: OptionType<IntegerType>;
            /** Batch size (default 32) */
            readonly batch_size: OptionType<IntegerType>;
            /** Learning rate (default 0.001) */
            readonly learning_rate: OptionType<FloatType>;
            /** Loss function (default mse) */
            readonly loss: OptionType<VariantType<{
                /** Mean Squared Error (regression) */
                readonly mse: NullType;
                /** Mean Absolute Error (regression) */
                readonly mae: NullType;
                /** Cross Entropy (multi-class classification with integer targets) */
                readonly cross_entropy: NullType;
                /** KL Divergence (distribution matching, use with softmax output) */
                readonly kl_div: NullType;
                /** Binary Cross Entropy (multi-label binary, requires sigmoid output) */
                readonly bce: NullType;
                /** Binary Cross Entropy with Logits (more stable, applies sigmoid internally - do NOT use with sigmoid output_activation) */
                readonly bce_with_logits: NullType;
            }>>;
            /** Optimizer (default adam) */
            readonly optimizer: OptionType<VariantType<{
                /** Adam optimizer */
                readonly adam: NullType;
                /** Stochastic Gradient Descent */
                readonly sgd: NullType;
                /** AdamW with weight decay */
                readonly adamw: NullType;
                /** RMSprop optimizer */
                readonly rmsprop: NullType;
            }>>;
            /** Early stopping patience, 0 = disabled */
            readonly early_stopping: OptionType<IntegerType>;
            /** Validation split fraction (default 0.2) */
            readonly validation_split: OptionType<FloatType>;
            /** Random seed for reproducibility */
            readonly random_state: OptionType<IntegerType>;
        }>;
        /** Training result type */
        readonly TorchTrainResultType: StructType<{
            /** Training loss per epoch */
            readonly train_losses: VectorType<FloatType>;
            /** Validation loss per epoch */
            readonly val_losses: VectorType<FloatType>;
            /** Best epoch (for early stopping) */
            readonly best_epoch: IntegerType;
        }>;
        /** Training output type (model + result) */
        readonly TorchTrainOutputType: StructType<{
            /** Trained model blob (same shape as ModelBlobType) */
            readonly model: VariantType<{
                /** PyTorch MLP model */
                readonly torch_mlp: StructType<{
                    /** Cloudpickle serialized model */
                    readonly data: BlobType;
                    /** Number of input features */
                    readonly n_features: IntegerType;
                    /** Hidden layer sizes */
                    readonly hidden_layers: ArrayType<IntegerType>;
                    /** Output dimension */
                    readonly output_dim: IntegerType;
                }>;
            }>;
            /** Training result with losses */
            readonly result: StructType<{
                /** Training loss per epoch */
                readonly train_losses: VectorType<FloatType>;
                /** Validation loss per epoch */
                readonly val_losses: VectorType<FloatType>;
                /** Best epoch (for early stopping) */
                readonly best_epoch: IntegerType;
            }>;
        }>;
        /** Model blob type for PyTorch models */
        readonly ModelBlobType: VariantType<{
            /** PyTorch MLP model */
            readonly torch_mlp: StructType<{
                /** Cloudpickle serialized model */
                readonly data: BlobType;
                /** Number of input features */
                readonly n_features: IntegerType;
                /** Hidden layer sizes */
                readonly hidden_layers: ArrayType<IntegerType>;
                /** Output dimension */
                readonly output_dim: IntegerType;
            }>;
        }>;
    };
};
|
|
1205
|
+
//# sourceMappingURL=torch.d.ts.map
|