@elaraai/east-py-datascience 0.0.2-beta.1

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (56) hide show
  1. package/LICENSE.md +18 -0
  2. package/README.md +104 -0
  3. package/dist/gp/gp.d.ts +398 -0
  4. package/dist/gp/gp.d.ts.map +1 -0
  5. package/dist/gp/gp.js +170 -0
  6. package/dist/gp/gp.js.map +1 -0
  7. package/dist/index.d.ts +27 -0
  8. package/dist/index.d.ts.map +1 -0
  9. package/dist/index.js +39 -0
  10. package/dist/index.js.map +1 -0
  11. package/dist/lightgbm/lightgbm.d.ts +494 -0
  12. package/dist/lightgbm/lightgbm.d.ts.map +1 -0
  13. package/dist/lightgbm/lightgbm.js +155 -0
  14. package/dist/lightgbm/lightgbm.js.map +1 -0
  15. package/dist/mads/mads.d.ts +413 -0
  16. package/dist/mads/mads.d.ts.map +1 -0
  17. package/dist/mads/mads.js +221 -0
  18. package/dist/mads/mads.js.map +1 -0
  19. package/dist/ngboost/ngboost.d.ts +433 -0
  20. package/dist/ngboost/ngboost.d.ts.map +1 -0
  21. package/dist/ngboost/ngboost.js +178 -0
  22. package/dist/ngboost/ngboost.js.map +1 -0
  23. package/dist/optuna/optuna.d.ts +797 -0
  24. package/dist/optuna/optuna.d.ts.map +1 -0
  25. package/dist/optuna/optuna.js +268 -0
  26. package/dist/optuna/optuna.js.map +1 -0
  27. package/dist/scipy/scipy.d.ts +954 -0
  28. package/dist/scipy/scipy.d.ts.map +1 -0
  29. package/dist/scipy/scipy.js +287 -0
  30. package/dist/scipy/scipy.js.map +1 -0
  31. package/dist/shap/shap.d.ts +657 -0
  32. package/dist/shap/shap.d.ts.map +1 -0
  33. package/dist/shap/shap.js +241 -0
  34. package/dist/shap/shap.js.map +1 -0
  35. package/dist/simanneal/simanneal.d.ts +531 -0
  36. package/dist/simanneal/simanneal.d.ts.map +1 -0
  37. package/dist/simanneal/simanneal.js +231 -0
  38. package/dist/simanneal/simanneal.js.map +1 -0
  39. package/dist/sklearn/sklearn.d.ts +1272 -0
  40. package/dist/sklearn/sklearn.d.ts.map +1 -0
  41. package/dist/sklearn/sklearn.js +307 -0
  42. package/dist/sklearn/sklearn.js.map +1 -0
  43. package/dist/torch/torch.d.ts +658 -0
  44. package/dist/torch/torch.d.ts.map +1 -0
  45. package/dist/torch/torch.js +233 -0
  46. package/dist/torch/torch.js.map +1 -0
  47. package/dist/tsconfig.tsbuildinfo +1 -0
  48. package/dist/types.d.ts +80 -0
  49. package/dist/types.d.ts.map +1 -0
  50. package/dist/types.js +81 -0
  51. package/dist/types.js.map +1 -0
  52. package/dist/xgboost/xgboost.d.ts +504 -0
  53. package/dist/xgboost/xgboost.d.ts.map +1 -0
  54. package/dist/xgboost/xgboost.js +177 -0
  55. package/dist/xgboost/xgboost.js.map +1 -0
  56. package/package.json +82 -0
@@ -0,0 +1,658 @@
1
+ /**
2
+ * Copyright (c) 2025 Elara AI Pty Ltd
3
+ * Dual-licensed under AGPL-3.0 and commercial license. See LICENSE for details.
4
+ */
5
+ /**
6
+ * PyTorch platform functions for East.
7
+ *
8
+ * Provides neural network models using PyTorch.
9
+ * Uses cloudpickle for model serialization.
10
+ *
11
+ * @packageDocumentation
12
+ */
13
+ import { StructType, VariantType, OptionType, IntegerType, FloatType, BlobType, ArrayType } from "@elaraai/east";
14
+ export { VectorType, MatrixType } from "../types.js";
15
+ /**
16
+ * Activation function variants for neural networks; every case carries an empty struct payload.
17
+ */
18
+ export declare const TorchActivationType: VariantType<{
19
+ /** Rectified Linear Unit */
20
+ relu: StructType<{}>;
21
+ /** Hyperbolic tangent */
22
+ tanh: StructType<{}>;
23
+ /** Sigmoid function */
24
+ sigmoid: StructType<{}>;
25
+ /** Leaky ReLU */
26
+ leaky_relu: StructType<{}>;
27
+ }>;
28
+ /**
29
+ * Loss function variants for training; every case carries an empty struct payload.
30
+ */
31
+ export declare const TorchLossType: VariantType<{
32
+ /** Mean Squared Error (regression) */
33
+ mse: StructType<{}>;
34
+ /** Mean Absolute Error (regression) */
35
+ mae: StructType<{}>;
36
+ /** Cross Entropy (classification) */
37
+ cross_entropy: StructType<{}>;
38
+ }>;
39
+ /**
40
+ * Optimizer variants for training; every case carries an empty struct payload.
41
+ */
42
+ export declare const TorchOptimizerType: VariantType<{
43
+ /** Adam optimizer */
44
+ adam: StructType<{}>;
45
+ /** Stochastic Gradient Descent */
46
+ sgd: StructType<{}>;
47
+ /** AdamW with weight decay */
48
+ adamw: StructType<{}>;
49
+ /** RMSprop optimizer */
50
+ rmsprop: StructType<{}>;
51
+ }>;
52
+ /**
53
+ * Configuration for MLP architecture. Only `hidden_layers` is a plain (non-option) field; the other fields are option types with the defaults noted on each field.
54
+ */
55
+ export declare const TorchMLPConfigType: StructType<{
56
+ /** Hidden layer sizes, e.g., [64, 32] */
57
+ hidden_layers: ArrayType<IntegerType>;
58
+ /** Activation function, same cases as TorchActivationType (default relu) */
59
+ activation: OptionType<VariantType<{
60
+ /** Rectified Linear Unit */
61
+ relu: StructType<{}>;
62
+ /** Hyperbolic tangent */
63
+ tanh: StructType<{}>;
64
+ /** Sigmoid function */
65
+ sigmoid: StructType<{}>;
66
+ /** Leaky ReLU */
67
+ leaky_relu: StructType<{}>;
68
+ }>>;
69
+ /** Dropout rate (default 0.0) */
70
+ dropout: OptionType<FloatType>;
71
+ /** Output dimension (default 1) */
72
+ output_dim: OptionType<IntegerType>;
73
+ }>;
74
+ /**
75
+ * Configuration for training. Every field is an option type; defaults, where documented, are noted on each field.
76
+ */
77
+ export declare const TorchTrainConfigType: StructType<{
78
+ /** Number of epochs (default 100) */
79
+ epochs: OptionType<IntegerType>;
80
+ /** Batch size (default 32) */
81
+ batch_size: OptionType<IntegerType>;
82
+ /** Learning rate (default 0.001) */
83
+ learning_rate: OptionType<FloatType>;
84
+ /** Loss function, same cases as TorchLossType (default mse) */
85
+ loss: OptionType<VariantType<{
86
+ /** Mean Squared Error (regression) */
87
+ mse: StructType<{}>;
88
+ /** Mean Absolute Error (regression) */
89
+ mae: StructType<{}>;
90
+ /** Cross Entropy (classification) */
91
+ cross_entropy: StructType<{}>;
92
+ }>>;
93
+ /** Optimizer, same cases as TorchOptimizerType (default adam) */
94
+ optimizer: OptionType<VariantType<{
95
+ /** Adam optimizer */
96
+ adam: StructType<{}>;
97
+ /** Stochastic Gradient Descent */
98
+ sgd: StructType<{}>;
99
+ /** AdamW with weight decay */
100
+ adamw: StructType<{}>;
101
+ /** RMSprop optimizer */
102
+ rmsprop: StructType<{}>;
103
+ }>>;
104
+ /** Early stopping patience, 0 = disabled */
105
+ early_stopping: OptionType<IntegerType>;
106
+ /** Validation split fraction (default 0.2) */
107
+ validation_split: OptionType<FloatType>;
108
+ /** Random seed for reproducibility */
109
+ random_state: OptionType<IntegerType>;
110
+ }>;
111
+ /**
112
+ * Result type for training: training and validation loss histories plus the best epoch.
113
+ */
114
+ export declare const TorchTrainResultType: StructType<{
115
+ /** Training loss per epoch */
116
+ train_losses: ArrayType<FloatType>;
117
+ /** Validation loss per epoch */
118
+ val_losses: ArrayType<FloatType>;
119
+ /** Best epoch (for early stopping) */
120
+ best_epoch: IntegerType;
121
+ }>;
122
+ /**
123
+ * Combined result from training: a struct with the trained `model` and the training `result` metrics.
124
+ */
125
+ export declare const TorchTrainOutputType: StructType<{
126
+ /** Trained model blob; the torch_mlp case has the same fields as TorchModelBlobType (data, n_features, hidden_layers, output_dim) */
127
+ model: VariantType<{
128
+ torch_mlp: StructType<{
129
+ data: BlobType;
130
+ n_features: IntegerType;
131
+ hidden_layers: ArrayType<IntegerType>;
132
+ output_dim: IntegerType;
133
+ }>;
134
+ }>;
135
+ /** Training result with losses; same fields as TorchTrainResultType */
136
+ result: StructType<{
137
+ /** Training loss per epoch */
138
+ train_losses: ArrayType<FloatType>;
139
+ /** Validation loss per epoch */
140
+ val_losses: ArrayType<FloatType>;
141
+ /** Best epoch (for early stopping) */
142
+ best_epoch: IntegerType;
143
+ }>;
144
+ }>;
145
+ /**
146
+ * Model blob type for serialized PyTorch models; a variant with a single case, `torch_mlp`.
147
+ */
148
+ export declare const TorchModelBlobType: VariantType<{
149
+ /** PyTorch MLP model */
150
+ torch_mlp: StructType<{
151
+ /** Cloudpickle serialized model */
152
+ data: BlobType;
153
+ /** Number of input features */
154
+ n_features: IntegerType;
155
+ /** Hidden layer sizes */
156
+ hidden_layers: ArrayType<IntegerType>;
157
+ /** Output dimension */
158
+ output_dim: IntegerType;
159
+ }>;
160
+ }>;
161
+ /**
162
+ * Train a PyTorch MLP model. Declared as a platform definition taking four positional inputs.
163
+ *
164
+ * @param X - Feature matrix (array of float arrays)
165
+ * @param y - Target vector (array of floats)
166
+ * @param mlp_config - MLP architecture configuration (same shape as TorchMLPConfigType)
167
+ * @param train_config - Training configuration (same shape as TorchTrainConfigType)
168
+ * @returns Struct with `model` (torch_mlp model blob) and `result` (loss histories and best epoch)
169
+ */
170
+ export declare const torch_mlp_train: import("@elaraai/east").PlatformDefinition<[ArrayType<ArrayType<FloatType>>, ArrayType<FloatType>, StructType<{
171
+ /** Hidden layer sizes, e.g., [64, 32] */
172
+ hidden_layers: ArrayType<IntegerType>;
173
+ /** Activation function (default relu) */
174
+ activation: OptionType<VariantType<{
175
+ /** Rectified Linear Unit */
176
+ relu: StructType<{}>;
177
+ /** Hyperbolic tangent */
178
+ tanh: StructType<{}>;
179
+ /** Sigmoid function */
180
+ sigmoid: StructType<{}>;
181
+ /** Leaky ReLU */
182
+ leaky_relu: StructType<{}>;
183
+ }>>;
184
+ /** Dropout rate (default 0.0) */
185
+ dropout: OptionType<FloatType>;
186
+ /** Output dimension (default 1) */
187
+ output_dim: OptionType<IntegerType>;
188
+ }>, StructType<{
189
+ /** Number of epochs (default 100) */
190
+ epochs: OptionType<IntegerType>;
191
+ /** Batch size (default 32) */
192
+ batch_size: OptionType<IntegerType>;
193
+ /** Learning rate (default 0.001) */
194
+ learning_rate: OptionType<FloatType>;
195
+ /** Loss function (default mse) */
196
+ loss: OptionType<VariantType<{
197
+ /** Mean Squared Error (regression) */
198
+ mse: StructType<{}>;
199
+ /** Mean Absolute Error (regression) */
200
+ mae: StructType<{}>;
201
+ /** Cross Entropy (classification) */
202
+ cross_entropy: StructType<{}>;
203
+ }>>;
204
+ /** Optimizer (default adam) */
205
+ optimizer: OptionType<VariantType<{
206
+ /** Adam optimizer */
207
+ adam: StructType<{}>;
208
+ /** Stochastic Gradient Descent */
209
+ sgd: StructType<{}>;
210
+ /** AdamW with weight decay */
211
+ adamw: StructType<{}>;
212
+ /** RMSprop optimizer */
213
+ rmsprop: StructType<{}>;
214
+ }>>;
215
+ /** Early stopping patience, 0 = disabled */
216
+ early_stopping: OptionType<IntegerType>;
217
+ /** Validation split fraction (default 0.2) */
218
+ validation_split: OptionType<FloatType>;
219
+ /** Random seed for reproducibility */
220
+ random_state: OptionType<IntegerType>;
221
+ }>], StructType<{
222
+ /** Trained model blob; same fields as the torch_mlp case of TorchModelBlobType */
223
+ model: VariantType<{
224
+ torch_mlp: StructType<{
225
+ data: BlobType;
226
+ n_features: IntegerType;
227
+ hidden_layers: ArrayType<IntegerType>;
228
+ output_dim: IntegerType;
229
+ }>;
230
+ }>;
231
+ /** Training result with losses */
232
+ result: StructType<{
233
+ /** Training loss per epoch */
234
+ train_losses: ArrayType<FloatType>;
235
+ /** Validation loss per epoch */
236
+ val_losses: ArrayType<FloatType>;
237
+ /** Best epoch (for early stopping) */
238
+ best_epoch: IntegerType;
239
+ }>;
240
+ }>>;
241
+ /**
242
+ * Make predictions with a trained PyTorch MLP. Declared as a platform definition taking two positional inputs.
243
+ *
244
+ * @param model - Trained MLP model blob (torch_mlp variant case)
245
+ * @param X - Feature matrix (array of float arrays)
246
+ * @returns Predicted values as a flat array of floats. NOTE(review): the model records output_dim but the declared return type is one-dimensional; confirm how multi-output predictions are represented.
247
+ */
248
+ export declare const torch_mlp_predict: import("@elaraai/east").PlatformDefinition<[VariantType<{
249
+ /** PyTorch MLP model */
250
+ torch_mlp: StructType<{
251
+ /** Cloudpickle serialized model */
252
+ data: BlobType;
253
+ /** Number of input features */
254
+ n_features: IntegerType;
255
+ /** Hidden layer sizes */
256
+ hidden_layers: ArrayType<IntegerType>;
257
+ /** Output dimension */
258
+ output_dim: IntegerType;
259
+ }>;
260
+ }>, ArrayType<ArrayType<FloatType>>], ArrayType<FloatType>>;
261
+ /**
262
+ * Type definitions for PyTorch functions, grouped in one readonly object. Note: the model blob type is exposed here under the key `ModelBlobType`, while the standalone export is named `TorchModelBlobType`.
263
+ */
264
+ export declare const TorchTypes: {
265
+ /** Vector type (array of floats) */
266
+ readonly VectorType: ArrayType<FloatType>;
267
+ /** Matrix type (2D array of floats) */
268
+ readonly MatrixType: ArrayType<ArrayType<FloatType>>;
269
+ /** Activation function type */
270
+ readonly TorchActivationType: VariantType<{
271
+ /** Rectified Linear Unit */
272
+ relu: StructType<{}>;
273
+ /** Hyperbolic tangent */
274
+ tanh: StructType<{}>;
275
+ /** Sigmoid function */
276
+ sigmoid: StructType<{}>;
277
+ /** Leaky ReLU */
278
+ leaky_relu: StructType<{}>;
279
+ }>;
280
+ /** Loss function type */
281
+ readonly TorchLossType: VariantType<{
282
+ /** Mean Squared Error (regression) */
283
+ mse: StructType<{}>;
284
+ /** Mean Absolute Error (regression) */
285
+ mae: StructType<{}>;
286
+ /** Cross Entropy (classification) */
287
+ cross_entropy: StructType<{}>;
288
+ }>;
289
+ /** Optimizer type */
290
+ readonly TorchOptimizerType: VariantType<{
291
+ /** Adam optimizer */
292
+ adam: StructType<{}>;
293
+ /** Stochastic Gradient Descent */
294
+ sgd: StructType<{}>;
295
+ /** AdamW with weight decay */
296
+ adamw: StructType<{}>;
297
+ /** RMSprop optimizer */
298
+ rmsprop: StructType<{}>;
299
+ }>;
300
+ /** MLP configuration type */
301
+ readonly TorchMLPConfigType: StructType<{
302
+ /** Hidden layer sizes, e.g., [64, 32] */
303
+ hidden_layers: ArrayType<IntegerType>;
304
+ /** Activation function (default relu) */
305
+ activation: OptionType<VariantType<{
306
+ /** Rectified Linear Unit */
307
+ relu: StructType<{}>;
308
+ /** Hyperbolic tangent */
309
+ tanh: StructType<{}>;
310
+ /** Sigmoid function */
311
+ sigmoid: StructType<{}>;
312
+ /** Leaky ReLU */
313
+ leaky_relu: StructType<{}>;
314
+ }>>;
315
+ /** Dropout rate (default 0.0) */
316
+ dropout: OptionType<FloatType>;
317
+ /** Output dimension (default 1) */
318
+ output_dim: OptionType<IntegerType>;
319
+ }>;
320
+ /** Training configuration type */
321
+ readonly TorchTrainConfigType: StructType<{
322
+ /** Number of epochs (default 100) */
323
+ epochs: OptionType<IntegerType>;
324
+ /** Batch size (default 32) */
325
+ batch_size: OptionType<IntegerType>;
326
+ /** Learning rate (default 0.001) */
327
+ learning_rate: OptionType<FloatType>;
328
+ /** Loss function (default mse) */
329
+ loss: OptionType<VariantType<{
330
+ /** Mean Squared Error (regression) */
331
+ mse: StructType<{}>;
332
+ /** Mean Absolute Error (regression) */
333
+ mae: StructType<{}>;
334
+ /** Cross Entropy (classification) */
335
+ cross_entropy: StructType<{}>;
336
+ }>>;
337
+ /** Optimizer (default adam) */
338
+ optimizer: OptionType<VariantType<{
339
+ /** Adam optimizer */
340
+ adam: StructType<{}>;
341
+ /** Stochastic Gradient Descent */
342
+ sgd: StructType<{}>;
343
+ /** AdamW with weight decay */
344
+ adamw: StructType<{}>;
345
+ /** RMSprop optimizer */
346
+ rmsprop: StructType<{}>;
347
+ }>>;
348
+ /** Early stopping patience, 0 = disabled */
349
+ early_stopping: OptionType<IntegerType>;
350
+ /** Validation split fraction (default 0.2) */
351
+ validation_split: OptionType<FloatType>;
352
+ /** Random seed for reproducibility */
353
+ random_state: OptionType<IntegerType>;
354
+ }>;
355
+ /** Training result type */
356
+ readonly TorchTrainResultType: StructType<{
357
+ /** Training loss per epoch */
358
+ train_losses: ArrayType<FloatType>;
359
+ /** Validation loss per epoch */
360
+ val_losses: ArrayType<FloatType>;
361
+ /** Best epoch (for early stopping) */
362
+ best_epoch: IntegerType;
363
+ }>;
364
+ /** Training output type (model + result) */
365
+ readonly TorchTrainOutputType: StructType<{
366
+ /** Trained model blob */
367
+ model: VariantType<{
368
+ torch_mlp: StructType<{
369
+ data: BlobType;
370
+ n_features: IntegerType;
371
+ hidden_layers: ArrayType<IntegerType>;
372
+ output_dim: IntegerType;
373
+ }>;
374
+ }>;
375
+ /** Training result with losses */
376
+ result: StructType<{
377
+ /** Training loss per epoch */
378
+ train_losses: ArrayType<FloatType>;
379
+ /** Validation loss per epoch */
380
+ val_losses: ArrayType<FloatType>;
381
+ /** Best epoch (for early stopping) */
382
+ best_epoch: IntegerType;
383
+ }>;
384
+ }>;
385
+ /** Model blob type for PyTorch models (same shape as the standalone TorchModelBlobType export) */
386
+ readonly ModelBlobType: VariantType<{
387
+ /** PyTorch MLP model */
388
+ torch_mlp: StructType<{
389
+ /** Cloudpickle serialized model */
390
+ data: BlobType;
391
+ /** Number of input features */
392
+ n_features: IntegerType;
393
+ /** Hidden layer sizes */
394
+ hidden_layers: ArrayType<IntegerType>;
395
+ /** Output dimension */
396
+ output_dim: IntegerType;
397
+ }>;
398
+ }>;
399
+ };
400
+ /**
401
+ * PyTorch neural network models: a readonly object exposing `mlpTrain`, `mlpPredict`, and `Types`.
402
+ *
403
+ * Provides MLP training and inference using PyTorch. The example below builds an East function that trains an MLP on four samples.
404
+ *
405
+ * @example
406
+ * ```ts
407
+ * import { East, variant } from "@elaraai/east";
408
+ * import { Torch } from "@elaraai/east-py-datascience";
409
+ *
410
+ * const train = East.function([], Torch.Types.TorchTrainOutputType, $ => {
411
+ * const X = $.let([[1.0, 2.0], [3.0, 4.0], [5.0, 6.0], [7.0, 8.0]]);
412
+ * const y = $.let([1.0, 2.0, 3.0, 4.0]);
413
+ * const mlp_config = $.let({
414
+ * hidden_layers: [32n, 16n],
415
+ * activation: variant('none', null),
416
+ * dropout: variant('none', null),
417
+ * output_dim: variant('none', null),
418
+ * });
419
+ * const train_config = $.let({
420
+ * epochs: variant('some', 50n),
421
+ * batch_size: variant('some', 4n),
422
+ * learning_rate: variant('some', 0.01),
423
+ * loss: variant('none', null),
424
+ * optimizer: variant('none', null),
425
+ * early_stopping: variant('none', null),
426
+ * validation_split: variant('some', 0.2),
427
+ * random_state: variant('some', 42n),
428
+ * });
429
+ * return $.return(Torch.mlpTrain(X, y, mlp_config, train_config));
430
+ * });
431
+ * ```
432
+ */
433
+ export declare const Torch: {
434
+ /** Train MLP model (same declared type as the standalone torch_mlp_train export) */
435
+ readonly mlpTrain: import("@elaraai/east").PlatformDefinition<[ArrayType<ArrayType<FloatType>>, ArrayType<FloatType>, StructType<{
436
+ /** Hidden layer sizes, e.g., [64, 32] */
437
+ hidden_layers: ArrayType<IntegerType>;
438
+ /** Activation function (default relu) */
439
+ activation: OptionType<VariantType<{
440
+ /** Rectified Linear Unit */
441
+ relu: StructType<{}>;
442
+ /** Hyperbolic tangent */
443
+ tanh: StructType<{}>;
444
+ /** Sigmoid function */
445
+ sigmoid: StructType<{}>;
446
+ /** Leaky ReLU */
447
+ leaky_relu: StructType<{}>;
448
+ }>>;
449
+ /** Dropout rate (default 0.0) */
450
+ dropout: OptionType<FloatType>;
451
+ /** Output dimension (default 1) */
452
+ output_dim: OptionType<IntegerType>;
453
+ }>, StructType<{
454
+ /** Number of epochs (default 100) */
455
+ epochs: OptionType<IntegerType>;
456
+ /** Batch size (default 32) */
457
+ batch_size: OptionType<IntegerType>;
458
+ /** Learning rate (default 0.001) */
459
+ learning_rate: OptionType<FloatType>;
460
+ /** Loss function (default mse) */
461
+ loss: OptionType<VariantType<{
462
+ /** Mean Squared Error (regression) */
463
+ mse: StructType<{}>;
464
+ /** Mean Absolute Error (regression) */
465
+ mae: StructType<{}>;
466
+ /** Cross Entropy (classification) */
467
+ cross_entropy: StructType<{}>;
468
+ }>>;
469
+ /** Optimizer (default adam) */
470
+ optimizer: OptionType<VariantType<{
471
+ /** Adam optimizer */
472
+ adam: StructType<{}>;
473
+ /** Stochastic Gradient Descent */
474
+ sgd: StructType<{}>;
475
+ /** AdamW with weight decay */
476
+ adamw: StructType<{}>;
477
+ /** RMSprop optimizer */
478
+ rmsprop: StructType<{}>;
479
+ }>>;
480
+ /** Early stopping patience, 0 = disabled */
481
+ early_stopping: OptionType<IntegerType>;
482
+ /** Validation split fraction (default 0.2) */
483
+ validation_split: OptionType<FloatType>;
484
+ /** Random seed for reproducibility */
485
+ random_state: OptionType<IntegerType>;
486
+ }>], StructType<{
487
+ /** Trained model blob */
488
+ model: VariantType<{
489
+ torch_mlp: StructType<{
490
+ data: BlobType;
491
+ n_features: IntegerType;
492
+ hidden_layers: ArrayType<IntegerType>;
493
+ output_dim: IntegerType;
494
+ }>;
495
+ }>;
496
+ /** Training result with losses */
497
+ result: StructType<{
498
+ /** Training loss per epoch */
499
+ train_losses: ArrayType<FloatType>;
500
+ /** Validation loss per epoch */
501
+ val_losses: ArrayType<FloatType>;
502
+ /** Best epoch (for early stopping) */
503
+ best_epoch: IntegerType;
504
+ }>;
505
+ }>>;
506
+ /** Make predictions with MLP (same declared type as the standalone torch_mlp_predict export) */
507
+ readonly mlpPredict: import("@elaraai/east").PlatformDefinition<[VariantType<{
508
+ /** PyTorch MLP model */
509
+ torch_mlp: StructType<{
510
+ /** Cloudpickle serialized model */
511
+ data: BlobType;
512
+ /** Number of input features */
513
+ n_features: IntegerType;
514
+ /** Hidden layer sizes */
515
+ hidden_layers: ArrayType<IntegerType>;
516
+ /** Output dimension */
517
+ output_dim: IntegerType;
518
+ }>;
519
+ }>, ArrayType<ArrayType<FloatType>>], ArrayType<FloatType>>;
520
+ /** Type definitions (same members as the standalone TorchTypes export) */
521
+ readonly Types: {
522
+ /** Vector type (array of floats) */
523
+ readonly VectorType: ArrayType<FloatType>;
524
+ /** Matrix type (2D array of floats) */
525
+ readonly MatrixType: ArrayType<ArrayType<FloatType>>;
526
+ /** Activation function type */
527
+ readonly TorchActivationType: VariantType<{
528
+ /** Rectified Linear Unit */
529
+ relu: StructType<{}>;
530
+ /** Hyperbolic tangent */
531
+ tanh: StructType<{}>;
532
+ /** Sigmoid function */
533
+ sigmoid: StructType<{}>;
534
+ /** Leaky ReLU */
535
+ leaky_relu: StructType<{}>;
536
+ }>;
537
+ /** Loss function type */
538
+ readonly TorchLossType: VariantType<{
539
+ /** Mean Squared Error (regression) */
540
+ mse: StructType<{}>;
541
+ /** Mean Absolute Error (regression) */
542
+ mae: StructType<{}>;
543
+ /** Cross Entropy (classification) */
544
+ cross_entropy: StructType<{}>;
545
+ }>;
546
+ /** Optimizer type */
547
+ readonly TorchOptimizerType: VariantType<{
548
+ /** Adam optimizer */
549
+ adam: StructType<{}>;
550
+ /** Stochastic Gradient Descent */
551
+ sgd: StructType<{}>;
552
+ /** AdamW with weight decay */
553
+ adamw: StructType<{}>;
554
+ /** RMSprop optimizer */
555
+ rmsprop: StructType<{}>;
556
+ }>;
557
+ /** MLP configuration type */
558
+ readonly TorchMLPConfigType: StructType<{
559
+ /** Hidden layer sizes, e.g., [64, 32] */
560
+ hidden_layers: ArrayType<IntegerType>;
561
+ /** Activation function (default relu) */
562
+ activation: OptionType<VariantType<{
563
+ /** Rectified Linear Unit */
564
+ relu: StructType<{}>;
565
+ /** Hyperbolic tangent */
566
+ tanh: StructType<{}>;
567
+ /** Sigmoid function */
568
+ sigmoid: StructType<{}>;
569
+ /** Leaky ReLU */
570
+ leaky_relu: StructType<{}>;
571
+ }>>;
572
+ /** Dropout rate (default 0.0) */
573
+ dropout: OptionType<FloatType>;
574
+ /** Output dimension (default 1) */
575
+ output_dim: OptionType<IntegerType>;
576
+ }>;
577
+ /** Training configuration type */
578
+ readonly TorchTrainConfigType: StructType<{
579
+ /** Number of epochs (default 100) */
580
+ epochs: OptionType<IntegerType>;
581
+ /** Batch size (default 32) */
582
+ batch_size: OptionType<IntegerType>;
583
+ /** Learning rate (default 0.001) */
584
+ learning_rate: OptionType<FloatType>;
585
+ /** Loss function (default mse) */
586
+ loss: OptionType<VariantType<{
587
+ /** Mean Squared Error (regression) */
588
+ mse: StructType<{}>;
589
+ /** Mean Absolute Error (regression) */
590
+ mae: StructType<{}>;
591
+ /** Cross Entropy (classification) */
592
+ cross_entropy: StructType<{}>;
593
+ }>>;
594
+ /** Optimizer (default adam) */
595
+ optimizer: OptionType<VariantType<{
596
+ /** Adam optimizer */
597
+ adam: StructType<{}>;
598
+ /** Stochastic Gradient Descent */
599
+ sgd: StructType<{}>;
600
+ /** AdamW with weight decay */
601
+ adamw: StructType<{}>;
602
+ /** RMSprop optimizer */
603
+ rmsprop: StructType<{}>;
604
+ }>>;
605
+ /** Early stopping patience, 0 = disabled */
606
+ early_stopping: OptionType<IntegerType>;
607
+ /** Validation split fraction (default 0.2) */
608
+ validation_split: OptionType<FloatType>;
609
+ /** Random seed for reproducibility */
610
+ random_state: OptionType<IntegerType>;
611
+ }>;
612
+ /** Training result type */
613
+ readonly TorchTrainResultType: StructType<{
614
+ /** Training loss per epoch */
615
+ train_losses: ArrayType<FloatType>;
616
+ /** Validation loss per epoch */
617
+ val_losses: ArrayType<FloatType>;
618
+ /** Best epoch (for early stopping) */
619
+ best_epoch: IntegerType;
620
+ }>;
621
+ /** Training output type (model + result) */
622
+ readonly TorchTrainOutputType: StructType<{
623
+ /** Trained model blob */
624
+ model: VariantType<{
625
+ torch_mlp: StructType<{
626
+ data: BlobType;
627
+ n_features: IntegerType;
628
+ hidden_layers: ArrayType<IntegerType>;
629
+ output_dim: IntegerType;
630
+ }>;
631
+ }>;
632
+ /** Training result with losses */
633
+ result: StructType<{
634
+ /** Training loss per epoch */
635
+ train_losses: ArrayType<FloatType>;
636
+ /** Validation loss per epoch */
637
+ val_losses: ArrayType<FloatType>;
638
+ /** Best epoch (for early stopping) */
639
+ best_epoch: IntegerType;
640
+ }>;
641
+ }>;
642
+ /** Model blob type for PyTorch models (keyed `ModelBlobType` here; the standalone export is `TorchModelBlobType`) */
643
+ readonly ModelBlobType: VariantType<{
644
+ /** PyTorch MLP model */
645
+ torch_mlp: StructType<{
646
+ /** Cloudpickle serialized model */
647
+ data: BlobType;
648
+ /** Number of input features */
649
+ n_features: IntegerType;
650
+ /** Hidden layer sizes */
651
+ hidden_layers: ArrayType<IntegerType>;
652
+ /** Output dimension */
653
+ output_dim: IntegerType;
654
+ }>;
655
+ }>;
656
+ };
657
+ };
658
+ //# sourceMappingURL=torch.d.ts.map
@@ -0,0 +1 @@
1
+ {"version":3,"file":"torch.d.ts","sourceRoot":"","sources":["../../src/torch/torch.ts"],"names":[],"mappings":"AAAA;;;GAGG;AAEH;;;;;;;GAOG;AAEH,OAAO,EAEH,UAAU,EACV,WAAW,EACX,UAAU,EACV,WAAW,EACX,SAAS,EACT,QAAQ,EACR,SAAS,EACZ,MAAM,eAAe,CAAC;AAIvB,OAAO,EAAE,UAAU,EAAE,UAAU,EAAE,MAAM,aAAa,CAAC;AAMrD;;GAEG;AACH,eAAO,MAAM,mBAAmB;IAC5B,4BAA4B;;IAE5B,yBAAyB;;IAEzB,uBAAuB;;IAEvB,iBAAiB;;EAEnB,CAAC;AAEH;;GAEG;AACH,eAAO,MAAM,aAAa;IACtB,sCAAsC;;IAEtC,uCAAuC;;IAEvC,qCAAqC;;EAEvC,CAAC;AAEH;;GAEG;AACH,eAAO,MAAM,kBAAkB;IAC3B,qBAAqB;;IAErB,kCAAkC;;IAElC,8BAA8B;;IAE9B,wBAAwB;;EAE1B,CAAC;AAMH;;GAEG;AACH,eAAO,MAAM,kBAAkB;IAC3B,yCAAyC;;IAEzC,yCAAyC;;QA9CzC,4BAA4B;;QAE5B,yBAAyB;;QAEzB,uBAAuB;;QAEvB,iBAAiB;;;IA0CjB,iCAAiC;;IAEjC,mCAAmC;;EAErC,CAAC;AAEH;;GAEG;AACH,eAAO,MAAM,oBAAoB;IAC7B,qCAAqC;;IAErC,8BAA8B;;IAE9B,oCAAoC;;IAEpC,kCAAkC;;QAlDlC,sCAAsC;;QAEtC,uCAAuC;;QAEvC,qCAAqC;;;IAgDrC,+BAA+B;;QAxC/B,qBAAqB;;QAErB,kCAAkC;;QAElC,8BAA8B;;QAE9B,wBAAwB;;;IAoCxB,4CAA4C;;IAE5C,8CAA8C;;IAE9C,sCAAsC;;EAExC,CAAC;AAMH;;GAEG;AACH,eAAO,MAAM,oBAAoB;IAC7B,8BAA8B;;IAE9B,gCAAgC;;IAEhC,sCAAsC;;EAExC,CAAC;AAEH;;GAEG;AACH,eAAO,MAAM,oBAAoB;IAC7B,yBAAyB;;;;;;;;;IASzB,kCAAkC;;QArBlC,8BAA8B;;QAE9B,gCAAgC;;QAEhC,sCAAsC;;;EAmBxC,CAAC;AAMH;;GAEG;AACH,eAAO,MAAM,kBAAkB;IAC3B,wBAAwB;;QAEpB,mCAAmC;;QAEnC,+BAA+B;;QAE/B,yBAAyB;;QAEzB,uBAAuB;;;EAG7B,CAAC;AAMH;;;;;;;;GAQG;AACH,eAAO,MAAM,eAAe;IAnGxB,yCAAyC;;IAEzC,yCAAyC;;QA9CzC,4BAA4B;;QAE5B,yBAAyB;;QAEzB,uBAAuB;;QAEvB,iBAAiB;;;IA0CjB,iCAAiC;;IAEjC,mCAAmC;;;IAQnC,qCAAqC;;IAErC,8BAA8B;;IAE9B,oCAAoC;;IAEpC,kCAAkC;;QAlDlC,sCAAsC;;QAEtC,uCAAuC;;QAEvC,qCAAqC;;;IAgDrC,+BAA+B;;QAxC/B,qBAAqB;;QAErB,kCAAkC;;QAElC,8BAA8B;;QAE9B,wBAAwB;;;IAoCxB,4CAA4C;;IAE5C,8CAA8C;;IAE9C,sCAAsC;;;IAwBtC,yBAAyB;;;;;;;;;IASzB,kCAAkC;;QArBlC,8BAA8B;;QAE9B,gCAAgC;;QAEhC,sCAAsC;;;GA2DzC,CAAC;AAEF;;;;;;GAMG;AACH,eAAO,MAAM,iBAAiB;IAvC1B,wBAAwB;;QAEpB,mCAAmC;;QAEnC,+BAA+B;;QAE/B,yBAAyB;;QAEzB,uBAAuB;;;2DAmC9B,CAAC;AAMF;;GAEG;AACH,eAAO,MAAM,UAAU;IACnB,oCAAoC;;IAEpC,uCAAuC;;IAEvC,+BAA+B;;QA9K/B,4BAA4B;;QAE5B,
yBAAyB;;QAEzB,uBAAuB;;QAEvB,iBAAiB;;;IA0KjB,yBAAyB;;QAlKzB,sCAAsC;;QAEtC,uCAAuC;;QAEvC,qCAAqC;;;IAgKrC,qBAAqB;;QAxJrB,qBAAqB;;QAErB,kCAAkC;;QAElC,8BAA8B;;QAE9B,wBAAwB;;;IAoJxB,6BAA6B;;QAxI7B,yCAAyC;;QAEzC,yCAAyC;;YA9CzC,4BAA4B;;YAE5B,yBAAyB;;YAEzB,uBAAuB;;YAEvB,iBAAiB;;;QA0CjB,iCAAiC;;QAEjC,mCAAmC;;;IAoInC,kCAAkC;;QA5HlC,qCAAqC;;QAErC,8BAA8B;;QAE9B,oCAAoC;;QAEpC,kCAAkC;;YAlDlC,sCAAsC;;YAEtC,uCAAuC;;YAEvC,qCAAqC;;;QAgDrC,+BAA+B;;YAxC/B,qBAAqB;;YAErB,kCAAkC;;YAElC,8BAA8B;;YAE9B,wBAAwB;;;QAoCxB,4CAA4C;;QAE5C,8CAA8C;;QAE9C,sCAAsC;;;IAgHtC,2BAA2B;;QApG3B,8BAA8B;;QAE9B,gCAAgC;;QAEhC,sCAAsC;;;IAkGtC,4CAA4C;;QA1F5C,yBAAyB;;;;;;;;;QASzB,kCAAkC;;YArBlC,8BAA8B;;YAE9B,gCAAgC;;YAEhC,sCAAsC;;;;IAoGtC,yCAAyC;;QAvEzC,wBAAwB;;YAEpB,mCAAmC;;YAEnC,+BAA+B;;YAE/B,yBAAyB;;YAEzB,uBAAuB;;;;CAiErB,CAAC;AAEX;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;GAgCG;AACH,eAAO,MAAM,KAAK;IACd,sBAAsB;;QAtLtB,yCAAyC;;QAEzC,yCAAyC;;YA9CzC,4BAA4B;;YAE5B,yBAAyB;;YAEzB,uBAAuB;;YAEvB,iBAAiB;;;QA0CjB,iCAAiC;;QAEjC,mCAAmC;;;QAQnC,qCAAqC;;QAErC,8BAA8B;;QAE9B,oCAAoC;;QAEpC,kCAAkC;;YAlDlC,sCAAsC;;YAEtC,uCAAuC;;YAEvC,qCAAqC;;;QAgDrC,+BAA+B;;YAxC/B,qBAAqB;;YAErB,kCAAkC;;YAElC,8BAA8B;;YAE9B,wBAAwB;;;QAoCxB,4CAA4C;;QAE5C,8CAA8C;;QAE9C,sCAAsC;;;QAwBtC,yBAAyB;;;;;;;;;QASzB,kCAAkC;;YArBlC,8BAA8B;;YAE9B,gCAAgC;;YAEhC,sCAAsC;;;;IA4ItC,gCAAgC;;QA/GhC,wBAAwB;;YAEpB,mCAAmC;;YAEnC,+BAA+B;;YAE/B,yBAAyB;;YAEzB,uBAAuB;;;;IAyG3B,uBAAuB;;QA5DvB,oCAAoC;;QAEpC,uCAAuC;;QAEvC,+BAA+B;;YA9K/B,4BAA4B;;YAE5B,yBAAyB;;YAEzB,uBAAuB;;YAEvB,iBAAiB;;;QA0KjB,yBAAyB;;YAlKzB,sCAAsC;;YAEtC,uCAAuC;;YAEvC,qCAAqC;;;QAgKrC,qBAAqB;;YAxJrB,qBAAqB;;YAErB,kCAAkC;;YAElC,8BAA8B;;YAE9B,wBAAwB;;;QAoJxB,6BAA6B;;YAxI7B,yCAAyC;;YAEzC,yCAAyC;;gBA9CzC,4BAA4B;;gBAE5B,yBAAyB;;gBAEzB,uBAAuB;;gBAEvB,iBAAiB;;;YA0CjB,iCAAiC;;YAEjC,mCAAmC;;;QAoInC,kCAAkC;;YA5HlC,qCAAqC;;YAErC,8BAA8B;;YAE9B,oCAAoC;;YAEpC,kCAAkC;;gBAlDlC,sCAAsC;;gBAEtC,uCAAuC;;gBAEvC,qCAAqC;;;YAgDrC,+BAA+B;;gBAxC/B,qBAAqB;;gBAErB,kCAAkC;;gBAElC,8BAA8B;;gBAE9B,wBAAwB;;;YAoCxB,4CAA4C;;YAE5C,8CAA8C;;YAE9C,sCAAsC;;;QAgHtC,2B
AA2B;;YApG3B,8BAA8B;;YAE9B,gCAAgC;;YAEhC,sCAAsC;;;QAkGtC,4CAA4C;;YA1F5C,yBAAyB;;;;;;;;;YASzB,kCAAkC;;gBArBlC,8BAA8B;;gBAE9B,gCAAgC;;gBAEhC,sCAAsC;;;;QAoGtC,yCAAyC;;YAvEzC,wBAAwB;;gBAEpB,mCAAmC;;gBAEnC,+BAA+B;;gBAE/B,yBAAyB;;gBAEzB,uBAAuB;;;;;CA2GrB,CAAC"}