deepbox 0.1.0 → 0.2.0
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- package/LICENSE +1 -1
- package/README.md +39 -37
- package/dist/{CSRMatrix-KzNt6QpS.d.ts → CSRMatrix-B7XtUAZO.d.cts} +3 -3
- package/dist/{CSRMatrix-CwGwQRea.d.cts → CSRMatrix-CtD23fRM.d.ts} +3 -3
- package/dist/{Tensor-BQLk1ltW.d.cts → Tensor-BORFp_zt.d.ts} +27 -7
- package/dist/{Tensor-g8mUClel.d.ts → Tensor-fxBg-TFZ.d.cts} +27 -7
- package/dist/{chunk-FJYLIGJX.js → chunk-3AX37GPK.js} +33 -7
- package/dist/chunk-3AX37GPK.js.map +1 -0
- package/dist/{chunk-PR647I7R.js → chunk-3YFEYDHN.js} +21 -4
- package/dist/chunk-3YFEYDHN.js.map +1 -0
- package/dist/{chunk-XMWVME2W.js → chunk-6SX26MAJ.js} +4 -4
- package/dist/{chunk-XMWVME2W.js.map → chunk-6SX26MAJ.js.map} +1 -1
- package/dist/{chunk-C4PKXY74.cjs → chunk-6X7XFNDO.cjs} +94 -77
- package/dist/chunk-6X7XFNDO.cjs.map +1 -0
- package/dist/{chunk-6AE5FKKQ.cjs → chunk-724CXHFH.cjs} +1211 -919
- package/dist/chunk-724CXHFH.cjs.map +1 -0
- package/dist/{chunk-AU7XHGKJ.js → chunk-AJTKVBY5.js} +4 -4
- package/dist/{chunk-AU7XHGKJ.js.map → chunk-AJTKVBY5.js.map} +1 -1
- package/dist/{chunk-ZB75FESB.cjs → chunk-AV6WGSYX.cjs} +130 -104
- package/dist/chunk-AV6WGSYX.cjs.map +1 -0
- package/dist/{chunk-ZLW62TJG.cjs → chunk-BWOSU234.cjs} +142 -141
- package/dist/chunk-BWOSU234.cjs.map +1 -0
- package/dist/{chunk-4S73VUBD.js → chunk-CZOMBUI7.js} +3 -3
- package/dist/chunk-CZOMBUI7.js.map +1 -0
- package/dist/{chunk-QERHVCHC.cjs → chunk-EUZHJDZ6.cjs} +419 -364
- package/dist/chunk-EUZHJDZ6.cjs.map +1 -0
- package/dist/{chunk-AD436M45.js → chunk-G2G55ATL.js} +120 -58
- package/dist/chunk-G2G55ATL.js.map +1 -0
- package/dist/{chunk-5R4S63PF.js → chunk-G3WNLNYS.js} +119 -64
- package/dist/chunk-G3WNLNYS.js.map +1 -0
- package/dist/{chunk-XEG44RF6.cjs → chunk-G7KXZHG6.cjs} +105 -95
- package/dist/chunk-G7KXZHG6.cjs.map +1 -0
- package/dist/{chunk-MLBMYKCG.js → chunk-H3JR7SV2.js} +255 -113
- package/dist/chunk-H3JR7SV2.js.map +1 -0
- package/dist/{chunk-PHV2DKRS.cjs → chunk-HDKMIG6E.cjs} +107 -107
- package/dist/{chunk-PHV2DKRS.cjs.map → chunk-HDKMIG6E.cjs.map} +1 -1
- package/dist/{chunk-ALS7ETWZ.cjs → chunk-HI2EZHCJ.cjs} +111 -102
- package/dist/chunk-HI2EZHCJ.cjs.map +1 -0
- package/dist/{chunk-OX6QXFMV.cjs → chunk-IT4BZUYE.cjs} +490 -428
- package/dist/chunk-IT4BZUYE.cjs.map +1 -0
- package/dist/{chunk-E3EU5FZO.cjs → chunk-JTZPRV6E.cjs} +123 -123
- package/dist/{chunk-E3EU5FZO.cjs.map → chunk-JTZPRV6E.cjs.map} +1 -1
- package/dist/{chunk-PL7TAYKI.js → chunk-K2L5C5YH.js} +8 -7
- package/dist/chunk-K2L5C5YH.js.map +1 -0
- package/dist/{chunk-BCR7G3A6.js → chunk-KCF6P34A.js} +356 -64
- package/dist/chunk-KCF6P34A.js.map +1 -0
- package/dist/{chunk-ZXKBDFP3.js → chunk-LZHVHD62.js} +15 -6
- package/dist/chunk-LZHVHD62.js.map +1 -0
- package/dist/{chunk-LWECRCW2.cjs → chunk-MTJF52AJ.cjs} +141 -141
- package/dist/{chunk-LWECRCW2.cjs.map → chunk-MTJF52AJ.cjs.map} +1 -1
- package/dist/{chunk-B5TNKUEY.js → chunk-NDDTUFKK.js} +16 -6
- package/dist/chunk-NDDTUFKK.js.map +1 -0
- package/dist/{chunk-DWZY6PIP.cjs → chunk-NOQI6OFL.cjs} +615 -473
- package/dist/chunk-NOQI6OFL.cjs.map +1 -0
- package/dist/{chunk-F3JWBINJ.js → chunk-OEXDJFHA.js} +4 -4
- package/dist/{chunk-F3JWBINJ.js.map → chunk-OEXDJFHA.js.map} +1 -1
- package/dist/{chunk-JSCDE774.cjs → chunk-Z6BGACIH.cjs} +3 -3
- package/dist/chunk-Z6BGACIH.cjs.map +1 -0
- package/dist/core/index.cjs +50 -50
- package/dist/core/index.d.cts +2 -2
- package/dist/core/index.d.ts +2 -2
- package/dist/core/index.js +1 -1
- package/dist/dataframe/index.cjs +6 -6
- package/dist/dataframe/index.d.cts +3 -3
- package/dist/dataframe/index.d.ts +3 -3
- package/dist/dataframe/index.js +3 -3
- package/dist/datasets/index.cjs +34 -34
- package/dist/datasets/index.d.cts +3 -3
- package/dist/datasets/index.d.ts +3 -3
- package/dist/datasets/index.js +3 -3
- package/dist/{index-C1mfVYoo.d.ts → index-B18dHc8q.d.ts} +81 -46
- package/dist/{index-GFAVyOWO.d.ts → index-BHHX0qTY.d.cts} +14 -12
- package/dist/{index-tk4lSYod.d.ts → index-BI6QOUvV.d.ts} +106 -80
- package/dist/{index-DIp_RrRt.d.ts → index-BKvK21lf.d.ts} +13 -35
- package/dist/{index-BJY2SI4i.d.ts → index-BL8jLf3K.d.cts} +12 -11
- package/dist/{index-Cn3SdB0O.d.ts → index-BNbX167d.d.cts} +16 -10
- package/dist/{index-BWGhrDlr.d.ts → index-BT2ofL7Z.d.cts} +35 -35
- package/dist/{index-BbA2Gxfl.d.ts → index-BqcfIcL4.d.ts} +15 -15
- package/dist/{index-ZtI1Iy4L.d.ts → index-BrgrECM2.d.ts} +41 -38
- package/dist/{index-CDw5CnOU.d.ts → index-BtYKI9yJ.d.ts} +10 -8
- package/dist/{index-DIT_OO9C.d.cts → index-C7nLsAOC.d.cts} +10 -8
- package/dist/{index-D9Loo1_A.d.cts → index-CNj2Mxwf.d.cts} +81 -46
- package/dist/{index-DmEg_LCm.d.cts → index-CYlxeNW1.d.cts} +5 -3
- package/dist/{index-D61yaSMY.d.cts → index-CiTd61a5.d.ts} +12 -11
- package/dist/{index-BndMbqsM.d.ts → index-Cjnn0KeN.d.cts} +35 -21
- package/dist/{index-9oQx1HgV.d.cts → index-CkGGAn69.d.cts} +41 -38
- package/dist/{index-74AB8Cyh.d.cts → index-D4URSgqA.d.ts} +16 -10
- package/dist/{index-DoPWVxPo.d.cts → index-D4pn5zLT.d.ts} +35 -21
- package/dist/{index-DuCxd-8d.d.ts → index-D9ztTlDr.d.ts} +60 -42
- package/dist/{index-BgHYAoSS.d.cts → index-DF28ZPB5.d.cts} +60 -42
- package/dist/{index-eJgeni9c.d.cts → index-DLdiQzf0.d.cts} +106 -80
- package/dist/{index-WHQLn0e8.d.cts → index-DN4omPQw.d.ts} +35 -35
- package/dist/{index-CrqLlS-a.d.ts → index-DUnFq1WV.d.ts} +5 -3
- package/dist/{index-DbultU6X.d.cts → index-DgaYshkF.d.ts} +14 -12
- package/dist/{index-B_DK4FKY.d.cts → index-GUHYEhxs.d.cts} +13 -35
- package/dist/{index-CCvlwAmL.d.cts → index-TP--4irE.d.cts} +16 -14
- package/dist/{index-Dx42TZaY.d.ts → index-x0z_sanT.d.ts} +16 -14
- package/dist/{index-DyZ4QQf5.d.cts → index-xWH7ujWa.d.cts} +15 -15
- package/dist/index.cjs +26 -26
- package/dist/index.d.cts +17 -17
- package/dist/index.d.ts +17 -17
- package/dist/index.js +13 -13
- package/dist/linalg/index.cjs +22 -22
- package/dist/linalg/index.d.cts +3 -3
- package/dist/linalg/index.d.ts +3 -3
- package/dist/linalg/index.js +3 -3
- package/dist/metrics/index.cjs +40 -40
- package/dist/metrics/index.d.cts +3 -3
- package/dist/metrics/index.d.ts +3 -3
- package/dist/metrics/index.js +3 -3
- package/dist/ml/index.cjs +23 -23
- package/dist/ml/index.d.cts +3 -3
- package/dist/ml/index.d.ts +3 -3
- package/dist/ml/index.js +4 -4
- package/dist/ndarray/index.cjs +125 -125
- package/dist/ndarray/index.d.cts +5 -5
- package/dist/ndarray/index.d.ts +5 -5
- package/dist/ndarray/index.js +2 -2
- package/dist/nn/index.cjs +36 -36
- package/dist/nn/index.d.cts +6 -6
- package/dist/nn/index.d.ts +6 -6
- package/dist/nn/index.js +3 -3
- package/dist/optim/index.cjs +19 -19
- package/dist/optim/index.d.cts +4 -4
- package/dist/optim/index.d.ts +4 -4
- package/dist/optim/index.js +2 -2
- package/dist/plot/index.cjs +29 -29
- package/dist/plot/index.d.cts +6 -6
- package/dist/plot/index.d.ts +6 -6
- package/dist/plot/index.js +3 -3
- package/dist/preprocess/index.cjs +21 -21
- package/dist/preprocess/index.d.cts +4 -4
- package/dist/preprocess/index.d.ts +4 -4
- package/dist/preprocess/index.js +3 -3
- package/dist/random/index.cjs +19 -19
- package/dist/random/index.d.cts +3 -3
- package/dist/random/index.d.ts +3 -3
- package/dist/random/index.js +3 -3
- package/dist/stats/index.cjs +36 -36
- package/dist/stats/index.d.cts +3 -3
- package/dist/stats/index.d.ts +3 -3
- package/dist/stats/index.js +3 -3
- package/dist/{tensor-B96jjJLQ.d.cts → tensor-IlVTF0bz.d.cts} +16 -3
- package/dist/{tensor-B96jjJLQ.d.ts → tensor-IlVTF0bz.d.ts} +16 -3
- package/package.json +3 -2
- package/dist/chunk-4S73VUBD.js.map +0 -1
- package/dist/chunk-5R4S63PF.js.map +0 -1
- package/dist/chunk-6AE5FKKQ.cjs.map +0 -1
- package/dist/chunk-AD436M45.js.map +0 -1
- package/dist/chunk-ALS7ETWZ.cjs.map +0 -1
- package/dist/chunk-B5TNKUEY.js.map +0 -1
- package/dist/chunk-BCR7G3A6.js.map +0 -1
- package/dist/chunk-C4PKXY74.cjs.map +0 -1
- package/dist/chunk-DWZY6PIP.cjs.map +0 -1
- package/dist/chunk-FJYLIGJX.js.map +0 -1
- package/dist/chunk-JSCDE774.cjs.map +0 -1
- package/dist/chunk-MLBMYKCG.js.map +0 -1
- package/dist/chunk-OX6QXFMV.cjs.map +0 -1
- package/dist/chunk-PL7TAYKI.js.map +0 -1
- package/dist/chunk-PR647I7R.js.map +0 -1
- package/dist/chunk-QERHVCHC.cjs.map +0 -1
- package/dist/chunk-XEG44RF6.cjs.map +0 -1
- package/dist/chunk-ZB75FESB.cjs.map +0 -1
- package/dist/chunk-ZLW62TJG.cjs.map +0 -1
- package/dist/chunk-ZXKBDFP3.js.map +0 -1
|
@@ -1,14 +1,15 @@
|
|
|
1
|
-
import { T as Tensor } from './Tensor-
|
|
1
|
+
import { T as Tensor } from './Tensor-fxBg-TFZ.cjs';
|
|
2
|
+
import { S as Shape, c as ScalarDType } from './tensor-IlVTF0bz.cjs';
|
|
2
3
|
|
|
3
4
|
/**
|
|
4
5
|
* Base type for all estimators (models) in Deepbox.
|
|
5
6
|
*
|
|
6
|
-
*
|
|
7
|
+
* Base estimator type for all ML models.
|
|
7
8
|
*
|
|
8
9
|
* @template FitParams - Type of parameters passed to fit method
|
|
9
10
|
*
|
|
10
11
|
* References:
|
|
11
|
-
* -
|
|
12
|
+
* - Deepbox ML: https://deepbox.dev/docs/ml-linear
|
|
12
13
|
*/
|
|
13
14
|
type Estimator<FitParams = void> = {
|
|
14
15
|
/**
|
|
@@ -238,7 +239,7 @@ type OutlierDetector = Estimator<void> & {
|
|
|
238
239
|
* // labels: [0, 0, 0, 1, 1, -1] (-1 = noise)
|
|
239
240
|
* ```
|
|
240
241
|
*
|
|
241
|
-
* @see {@link https://
|
|
242
|
+
* @see {@link https://deepbox.dev/docs/ml-clustering | Deepbox Clustering}
|
|
242
243
|
*/
|
|
243
244
|
declare class DBSCAN implements Clusterer {
|
|
244
245
|
private eps;
|
|
@@ -295,6 +296,13 @@ declare class DBSCAN implements Clusterer {
|
|
|
295
296
|
* @throws {NotFittedError} If the model has not been fitted
|
|
296
297
|
*/
|
|
297
298
|
get labels(): Tensor;
|
|
299
|
+
/**
|
|
300
|
+
* Number of clusters found (excluding noise).
|
|
301
|
+
*
|
|
302
|
+
* @returns Number of distinct clusters (labels >= 0)
|
|
303
|
+
* @throws {NotFittedError} If the model has not been fitted
|
|
304
|
+
*/
|
|
305
|
+
get nClusters(): number;
|
|
298
306
|
/**
|
|
299
307
|
* Get indices of core samples discovered during fitting.
|
|
300
308
|
*
|
|
@@ -352,8 +360,8 @@ declare class DBSCAN implements Clusterer {
|
|
|
352
360
|
* console.log('Centroids:', kmeans.clusterCenters);
|
|
353
361
|
* ```
|
|
354
362
|
*
|
|
355
|
-
* @see {@link https://
|
|
356
|
-
* @see {@link https://
|
|
363
|
+
* @see {@link https://deepbox.dev/docs/ml-clustering | Deepbox Clustering}
|
|
364
|
+
* @see {@link https://deepbox.dev/docs/ml-clustering | Deepbox Clustering}
|
|
357
365
|
*/
|
|
358
366
|
declare class KMeans implements Clusterer {
|
|
359
367
|
private nClusters;
|
|
@@ -482,8 +490,8 @@ declare class KMeans implements Clusterer {
|
|
|
482
490
|
* console.log('Explained variance ratio:', pca.explainedVarianceRatio);
|
|
483
491
|
* ```
|
|
484
492
|
*
|
|
485
|
-
* @see {@link https://
|
|
486
|
-
* @see {@link https://
|
|
493
|
+
* @see {@link https://deepbox.dev/docs/ml-decomposition | Deepbox Dimensionality Reduction}
|
|
494
|
+
* @see {@link https://deepbox.dev/docs/ml-decomposition | Deepbox Dimensionality Reduction}
|
|
487
495
|
*/
|
|
488
496
|
declare class PCA implements Transformer {
|
|
489
497
|
private readonly nComponents?;
|
|
@@ -590,7 +598,7 @@ declare class PCA implements Transformer {
|
|
|
590
598
|
* const predictions = gbr.predict(X);
|
|
591
599
|
* ```
|
|
592
600
|
*
|
|
593
|
-
* @see {@link https://
|
|
601
|
+
* @see {@link https://deepbox.dev/docs/ml-ensemble | Deepbox Ensemble Methods}
|
|
594
602
|
*/
|
|
595
603
|
declare class GradientBoostingRegressor implements Regressor {
|
|
596
604
|
/** Number of boosting stages (trees) */
|
|
@@ -669,8 +677,10 @@ declare class GradientBoostingRegressor implements Regressor {
|
|
|
669
677
|
/**
|
|
670
678
|
* Gradient Boosting Classifier.
|
|
671
679
|
*
|
|
672
|
-
* Uses gradient boosting with shallow regression trees for
|
|
673
|
-
*
|
|
680
|
+
* Uses gradient boosting with shallow regression trees for classification.
|
|
681
|
+
* Supports both binary and multiclass classification.
|
|
682
|
+
* - Binary: optimizes log loss using sigmoid function.
|
|
683
|
+
* - Multiclass: uses One-vs-Rest (OvR) strategy, training one binary model per class.
|
|
674
684
|
*
|
|
675
685
|
* @example
|
|
676
686
|
* ```ts
|
|
@@ -685,7 +695,7 @@ declare class GradientBoostingRegressor implements Regressor {
|
|
|
685
695
|
* const predictions = gbc.predict(X);
|
|
686
696
|
* ```
|
|
687
697
|
*
|
|
688
|
-
* @see {@link https://
|
|
698
|
+
* @see {@link https://deepbox.dev/docs/ml-ensemble | Deepbox Ensemble Methods}
|
|
689
699
|
*/
|
|
690
700
|
declare class GradientBoostingClassifier implements Classifier {
|
|
691
701
|
/** Number of boosting stages */
|
|
@@ -696,10 +706,10 @@ declare class GradientBoostingClassifier implements Classifier {
|
|
|
696
706
|
private maxDepth;
|
|
697
707
|
/** Minimum samples to split */
|
|
698
708
|
private minSamplesSplit;
|
|
699
|
-
/**
|
|
700
|
-
private
|
|
701
|
-
/**
|
|
702
|
-
private
|
|
709
|
+
/** Per-class arrays of weak learners (OvR for multiclass, single for binary) */
|
|
710
|
+
private estimatorsPerClass;
|
|
711
|
+
/** Per-class initial log-odds predictions */
|
|
712
|
+
private initPredictions;
|
|
703
713
|
/** Number of features */
|
|
704
714
|
private nFeatures;
|
|
705
715
|
/** Unique class labels */
|
|
@@ -712,19 +722,29 @@ declare class GradientBoostingClassifier implements Classifier {
|
|
|
712
722
|
readonly maxDepth?: number;
|
|
713
723
|
readonly minSamplesSplit?: number;
|
|
714
724
|
});
|
|
725
|
+
/**
|
|
726
|
+
* Fit a single binary boosting ensemble.
|
|
727
|
+
* Trains nEstimators regression trees to optimize log loss for a binary target.
|
|
728
|
+
*/
|
|
729
|
+
private fitBinary;
|
|
730
|
+
/**
|
|
731
|
+
* Compute raw scores for a single binary ensemble.
|
|
732
|
+
*/
|
|
733
|
+
private predictRawBinary;
|
|
715
734
|
/**
|
|
716
735
|
* Fit the gradient boosting classifier on training data.
|
|
717
736
|
*
|
|
718
737
|
* Builds an additive model by sequentially fitting regression trees
|
|
719
738
|
* to the pseudo-residuals (gradient of log loss).
|
|
739
|
+
* Supports binary (2 classes) and multiclass (>2 classes via OvR).
|
|
720
740
|
*
|
|
721
741
|
* @param X - Training data of shape (n_samples, n_features)
|
|
722
|
-
* @param y - Target class labels of shape (n_samples,). Must contain
|
|
742
|
+
* @param y - Target class labels of shape (n_samples,). Must contain at least 2 classes.
|
|
723
743
|
* @returns this - The fitted estimator
|
|
724
744
|
* @throws {ShapeError} If X is not 2D or y is not 1D
|
|
725
745
|
* @throws {ShapeError} If X and y have different number of samples
|
|
726
746
|
* @throws {DataValidationError} If X or y contain NaN/Inf values
|
|
727
|
-
* @throws {InvalidParameterError} If y does not contain
|
|
747
|
+
* @throws {InvalidParameterError} If y does not contain at least 2 classes
|
|
728
748
|
*/
|
|
729
749
|
fit(X: Tensor, y: Tensor): this;
|
|
730
750
|
/**
|
|
@@ -740,11 +760,10 @@ declare class GradientBoostingClassifier implements Classifier {
|
|
|
740
760
|
/**
|
|
741
761
|
* Predict class probabilities for samples in X.
|
|
742
762
|
*
|
|
743
|
-
* Returns a matrix of shape (n_samples,
|
|
744
|
-
* [P(class_0), P(class_1)].
|
|
763
|
+
* Returns a matrix of shape (n_samples, n_classes).
|
|
745
764
|
*
|
|
746
765
|
* @param X - Samples of shape (n_samples, n_features)
|
|
747
|
-
* @returns Class probability matrix of shape (n_samples,
|
|
766
|
+
* @returns Class probability matrix of shape (n_samples, n_classes)
|
|
748
767
|
* @throws {NotFittedError} If the model has not been fitted
|
|
749
768
|
* @throws {ShapeError} If X has wrong dimensions or feature count
|
|
750
769
|
* @throws {DataValidationError} If X contains NaN/Inf values
|
|
@@ -1041,7 +1060,7 @@ declare class LinearRegression implements Regressor {
|
|
|
1041
1060
|
* @throws {ShapeError} If X has wrong dimensions or feature count
|
|
1042
1061
|
* @throws {DataValidationError} If X contains NaN/Inf values
|
|
1043
1062
|
*/
|
|
1044
|
-
predict(X: Tensor): Tensor
|
|
1063
|
+
predict(X: Tensor): Tensor<Shape, ScalarDType>;
|
|
1045
1064
|
/**
|
|
1046
1065
|
* Return the coefficient of determination R^2 of the prediction.
|
|
1047
1066
|
*
|
|
@@ -1434,7 +1453,7 @@ declare class Ridge implements Regressor {
|
|
|
1434
1453
|
* const embedding = tsne.fitTransform(X);
|
|
1435
1454
|
* ```
|
|
1436
1455
|
*
|
|
1437
|
-
* @see {@link https://
|
|
1456
|
+
* @see {@link https://deepbox.dev/docs/ml-manifold | Deepbox Manifold Learning}
|
|
1438
1457
|
* @see van der Maaten, L.J.P.; Hinton, G.E. (2008). "Visualizing High-Dimensional Data Using t-SNE"
|
|
1439
1458
|
*/
|
|
1440
1459
|
declare class TSNE {
|
|
@@ -1543,6 +1562,14 @@ declare class TSNE {
|
|
|
1543
1562
|
* Fit the model (same as fitTransform for t-SNE).
|
|
1544
1563
|
*/
|
|
1545
1564
|
fit(X: Tensor): this;
|
|
1565
|
+
/**
|
|
1566
|
+
* Return the fitted embedding. For t-SNE, transform is equivalent to
|
|
1567
|
+
* returning the already-computed embedding (t-SNE is non-parametric).
|
|
1568
|
+
*
|
|
1569
|
+
* @param _X - Ignored, present for API compatibility
|
|
1570
|
+
* @returns Low-dimensional embedding of shape (n_samples, n_components)
|
|
1571
|
+
*/
|
|
1572
|
+
transform(_X?: Tensor): Tensor;
|
|
1546
1573
|
/**
|
|
1547
1574
|
* Get the embedding.
|
|
1548
1575
|
*/
|
|
@@ -1585,8 +1612,8 @@ declare class TSNE {
|
|
|
1585
1612
|
* const predictions = nb.predict(tensor([[2.5, 3.5]]));
|
|
1586
1613
|
* ```
|
|
1587
1614
|
*
|
|
1588
|
-
* @see {@link https://
|
|
1589
|
-
* @see {@link https://
|
|
1615
|
+
* @see {@link https://deepbox.dev/docs/ml-naive-bayes | Deepbox Naive Bayes}
|
|
1616
|
+
* @see {@link https://deepbox.dev/docs/ml-naive-bayes | Deepbox Naive Bayes}
|
|
1590
1617
|
*/
|
|
1591
1618
|
declare class GaussianNB implements Classifier {
|
|
1592
1619
|
private readonly varSmoothing;
|
|
@@ -1737,8 +1764,8 @@ declare abstract class KNeighborsBase {
|
|
|
1737
1764
|
* const predictions = knn.predict(tensor([[1.5, 1.5]]));
|
|
1738
1765
|
* ```
|
|
1739
1766
|
*
|
|
1740
|
-
* @see {@link https://
|
|
1741
|
-
* @see {@link https://
|
|
1767
|
+
* @see {@link https://deepbox.dev/docs/ml-neighbors | Deepbox Nearest Neighbors}
|
|
1768
|
+
* @see {@link https://deepbox.dev/docs/ml-neighbors | Deepbox Nearest Neighbors}
|
|
1742
1769
|
*/
|
|
1743
1770
|
declare class KNeighborsClassifier extends KNeighborsBase implements Classifier {
|
|
1744
1771
|
/**
|
|
@@ -1805,7 +1832,7 @@ declare class KNeighborsClassifier extends KNeighborsBase implements Classifier
|
|
|
1805
1832
|
* const predictions = knn.predict(tensor([[1.5]]));
|
|
1806
1833
|
* ```
|
|
1807
1834
|
*
|
|
1808
|
-
* @see {@link https://
|
|
1835
|
+
* @see {@link https://deepbox.dev/docs/ml-neighbors | Deepbox Nearest Neighbors}
|
|
1809
1836
|
*/
|
|
1810
1837
|
declare class KNeighborsRegressor extends KNeighborsBase implements Regressor {
|
|
1811
1838
|
/**
|
|
@@ -1868,7 +1895,7 @@ declare class KNeighborsRegressor extends KNeighborsBase implements Regressor {
|
|
|
1868
1895
|
* const predictions = svm.predict(X);
|
|
1869
1896
|
* ```
|
|
1870
1897
|
*
|
|
1871
|
-
* @see {@link https://
|
|
1898
|
+
* @see {@link https://deepbox.dev/docs/ml-svm | Deepbox SVM}
|
|
1872
1899
|
*/
|
|
1873
1900
|
declare class LinearSVC implements Classifier {
|
|
1874
1901
|
/** Regularization parameter (inverse of regularization strength) */
|
|
@@ -1877,13 +1904,13 @@ declare class LinearSVC implements Classifier {
|
|
|
1877
1904
|
private readonly maxIter;
|
|
1878
1905
|
/** Tolerance for stopping criterion */
|
|
1879
1906
|
private readonly tol;
|
|
1880
|
-
/**
|
|
1881
|
-
private
|
|
1882
|
-
/**
|
|
1883
|
-
private
|
|
1907
|
+
/** Per-class weight vectors (OvR for multiclass, single for binary) */
|
|
1908
|
+
private weightsPerClass;
|
|
1909
|
+
/** Per-class bias terms */
|
|
1910
|
+
private biasPerClass;
|
|
1884
1911
|
/** Number of features seen during fit */
|
|
1885
1912
|
private nFeatures;
|
|
1886
|
-
/** Unique class labels
|
|
1913
|
+
/** Unique class labels */
|
|
1887
1914
|
private classLabels;
|
|
1888
1915
|
/** Whether the model has been fitted */
|
|
1889
1916
|
private fitted;
|
|
@@ -1900,19 +1927,27 @@ declare class LinearSVC implements Classifier {
|
|
|
1900
1927
|
readonly maxIter?: number;
|
|
1901
1928
|
readonly tol?: number;
|
|
1902
1929
|
});
|
|
1930
|
+
/**
|
|
1931
|
+
* Fit a single binary SVM using sub-gradient descent on hinge loss.
|
|
1932
|
+
* Maps labels to {-1, +1} and returns learned weights + bias.
|
|
1933
|
+
*/
|
|
1934
|
+
private fitBinary;
|
|
1935
|
+
/**
|
|
1936
|
+
* Compute decision value for a single binary classifier.
|
|
1937
|
+
*/
|
|
1938
|
+
private decisionBinary;
|
|
1903
1939
|
/**
|
|
1904
1940
|
* Fit the SVM classifier using sub-gradient descent.
|
|
1905
1941
|
*
|
|
1906
|
-
*
|
|
1907
|
-
* Objective: minimize (1/2)||w||² + C * Σmax(0, 1 - y_i(w · x_i + b))
|
|
1942
|
+
* Supports both binary and multiclass classification (via OvR).
|
|
1908
1943
|
*
|
|
1909
1944
|
* @param X - Training data of shape (n_samples, n_features)
|
|
1910
|
-
* @param y - Target labels of shape (n_samples,). Must contain
|
|
1945
|
+
* @param y - Target labels of shape (n_samples,). Must contain at least 2 classes.
|
|
1911
1946
|
* @returns this - The fitted estimator
|
|
1912
1947
|
* @throws {ShapeError} If X is not 2D or y is not 1D
|
|
1913
1948
|
* @throws {ShapeError} If X and y have different number of samples
|
|
1914
1949
|
* @throws {DataValidationError} If X or y contain NaN/Inf values
|
|
1915
|
-
* @throws {InvalidParameterError} If y does not contain
|
|
1950
|
+
* @throws {InvalidParameterError} If y does not contain at least 2 classes
|
|
1916
1951
|
*/
|
|
1917
1952
|
fit(X: Tensor, y: Tensor): this;
|
|
1918
1953
|
/**
|
|
@@ -1954,12 +1989,12 @@ declare class LinearSVC implements Classifier {
|
|
|
1954
1989
|
*/
|
|
1955
1990
|
get coef(): Tensor;
|
|
1956
1991
|
/**
|
|
1957
|
-
* Get the bias
|
|
1992
|
+
* Get the bias terms.
|
|
1958
1993
|
*
|
|
1959
|
-
* @returns Bias
|
|
1994
|
+
* @returns Bias values as tensor
|
|
1960
1995
|
* @throws {NotFittedError} If the model has not been fitted
|
|
1961
1996
|
*/
|
|
1962
|
-
get intercept():
|
|
1997
|
+
get intercept(): Tensor;
|
|
1963
1998
|
/**
|
|
1964
1999
|
* Get hyperparameters for this estimator.
|
|
1965
2000
|
*
|
|
@@ -1992,7 +2027,7 @@ declare class LinearSVC implements Classifier {
|
|
|
1992
2027
|
* const predictions = svr.predict(X);
|
|
1993
2028
|
* ```
|
|
1994
2029
|
*
|
|
1995
|
-
* @see {@link https://
|
|
2030
|
+
* @see {@link https://deepbox.dev/docs/ml-svm | Deepbox SVM}
|
|
1996
2031
|
*/
|
|
1997
2032
|
declare class LinearSVR implements Regressor {
|
|
1998
2033
|
/** Regularization parameter */
|
|
@@ -2088,7 +2123,7 @@ declare class LinearSVR implements Regressor {
|
|
|
2088
2123
|
* const predictions = clf.predict(X);
|
|
2089
2124
|
* ```
|
|
2090
2125
|
*
|
|
2091
|
-
* @see {@link https://
|
|
2126
|
+
* @see {@link https://deepbox.dev/docs/ml-tree | Deepbox Decision Trees}
|
|
2092
2127
|
*/
|
|
2093
2128
|
declare class DecisionTreeClassifier implements Classifier {
|
|
2094
2129
|
private maxDepth;
|
|
@@ -2182,7 +2217,7 @@ declare class DecisionTreeClassifier implements Classifier {
|
|
|
2182
2217
|
*
|
|
2183
2218
|
* Uses MSE reduction to find optimal splits for regression tasks.
|
|
2184
2219
|
*
|
|
2185
|
-
* @see {@link https://
|
|
2220
|
+
* @see {@link https://deepbox.dev/docs/ml-tree | Deepbox Decision Trees}
|
|
2186
2221
|
*/
|
|
2187
2222
|
declare class DecisionTreeRegressor implements Regressor {
|
|
2188
2223
|
private maxDepth;
|
|
@@ -2275,7 +2310,7 @@ declare class DecisionTreeRegressor implements Regressor {
|
|
|
2275
2310
|
* const predictions = clf.predict(X_test);
|
|
2276
2311
|
* ```
|
|
2277
2312
|
*
|
|
2278
|
-
* @see {@link https://
|
|
2313
|
+
* @see {@link https://deepbox.dev/docs/ml-tree | Deepbox Decision Trees}
|
|
2279
2314
|
*/
|
|
2280
2315
|
declare class RandomForestClassifier implements Classifier {
|
|
2281
2316
|
private readonly nEstimators;
|
|
@@ -2389,7 +2424,7 @@ declare class RandomForestClassifier implements Classifier {
|
|
|
2389
2424
|
* const predictions = reg.predict(X_test);
|
|
2390
2425
|
* ```
|
|
2391
2426
|
*
|
|
2392
|
-
* @see {@link https://
|
|
2427
|
+
* @see {@link https://deepbox.dev/docs/ml-tree | Deepbox Decision Trees}
|
|
2393
2428
|
*/
|
|
2394
2429
|
declare class RandomForestRegressor implements Regressor {
|
|
2395
2430
|
private readonly nEstimators;
|
|
@@ -1,5 +1,5 @@
|
|
|
1
|
-
import { A as AnyTensor } from './index-
|
|
2
|
-
import { T as Tensor } from './Tensor-
|
|
1
|
+
import { A as AnyTensor } from './index-DLdiQzf0.cjs';
|
|
2
|
+
import { T as Tensor } from './Tensor-fxBg-TFZ.cjs';
|
|
3
3
|
|
|
4
4
|
/**
|
|
5
5
|
* @internal
|
|
@@ -676,7 +676,9 @@ declare function barh(y: Tensor, width: Tensor, options?: PlotOptions): void;
|
|
|
676
676
|
/**
|
|
677
677
|
* Plot a histogram on the current axes.
|
|
678
678
|
*/
|
|
679
|
-
declare function hist(x: Tensor, bins?: number
|
|
679
|
+
declare function hist(x: Tensor, bins?: number | (PlotOptions & {
|
|
680
|
+
bins?: number;
|
|
681
|
+
}), options?: PlotOptions): void;
|
|
680
682
|
/**
|
|
681
683
|
* Plot a box-and-whisker summary on the current axes.
|
|
682
684
|
*/
|
|
@@ -1,4 +1,4 @@
|
|
|
1
|
-
import { T as Tensor } from './Tensor-
|
|
1
|
+
import { T as Tensor } from './Tensor-BORFp_zt.js';
|
|
2
2
|
|
|
3
3
|
type DataLoaderOptions = {
|
|
4
4
|
batchSize?: number;
|
|
@@ -9,7 +9,7 @@ type DataLoaderOptions = {
|
|
|
9
9
|
/**
|
|
10
10
|
* Data loader for batching and shuffling datasets.
|
|
11
11
|
*
|
|
12
|
-
*
|
|
12
|
+
* Provides efficient iteration over datasets with support for
|
|
13
13
|
* batching, shuffling, and deterministic reproducibility.
|
|
14
14
|
*
|
|
15
15
|
* @remarks
|
|
@@ -58,7 +58,7 @@ type DataLoaderOptions = {
|
|
|
58
58
|
* }
|
|
59
59
|
* ```
|
|
60
60
|
*
|
|
61
|
-
* @see {@link https://
|
|
61
|
+
* @see {@link https://deepbox.dev/docs/datasets-dataloader | Deepbox DataLoader}
|
|
62
62
|
*/
|
|
63
63
|
declare class DataLoader<TTarget extends Tensor | undefined = undefined> {
|
|
64
64
|
private X;
|
|
@@ -75,8 +75,7 @@ declare class DataLoader<TTarget extends Tensor | undefined = undefined> {
|
|
|
75
75
|
* Number of batches in the data loader.
|
|
76
76
|
*/
|
|
77
77
|
get length(): number;
|
|
78
|
-
[Symbol.iterator](
|
|
79
|
-
[Symbol.iterator](this: DataLoader<undefined>): IterableIterator<[Tensor]>;
|
|
78
|
+
[Symbol.iterator](): IterableIterator<TTarget extends Tensor ? [Tensor, Tensor] : [Tensor]>;
|
|
80
79
|
private prepareIteration;
|
|
81
80
|
private iterateX;
|
|
82
81
|
private iterateXY;
|
|
@@ -98,7 +97,7 @@ declare class DataLoader<TTarget extends Tensor | undefined = undefined> {
|
|
|
98
97
|
* @param options.randomState - Seed for reproducibility.
|
|
99
98
|
* @returns A tuple `[X, y]` where X has shape `[nSamples, nFeatures]` and y has shape `[nSamples]` with dtype `int32`.
|
|
100
99
|
*
|
|
101
|
-
* @see {@link https://
|
|
100
|
+
* @see {@link https://deepbox.dev/docs/datasets-synthetic | Deepbox Synthetic Datasets}
|
|
102
101
|
*/
|
|
103
102
|
declare function makeClassification(options?: {
|
|
104
103
|
nSamples?: number;
|
|
@@ -106,6 +105,7 @@ declare function makeClassification(options?: {
|
|
|
106
105
|
nInformative?: number;
|
|
107
106
|
nRedundant?: number;
|
|
108
107
|
nClasses?: number;
|
|
108
|
+
flipY?: number;
|
|
109
109
|
randomState?: number;
|
|
110
110
|
}): [Tensor, Tensor];
|
|
111
111
|
/**
|
|
@@ -121,7 +121,7 @@ declare function makeClassification(options?: {
|
|
|
121
121
|
* @param options.randomState - Seed for reproducibility.
|
|
122
122
|
* @returns A tuple `[X, y]` where X has shape `[nSamples, nFeatures]` and y has shape `[nSamples]`.
|
|
123
123
|
*
|
|
124
|
-
* @see {@link https://
|
|
124
|
+
* @see {@link https://deepbox.dev/docs/datasets-synthetic | Deepbox Synthetic Datasets}
|
|
125
125
|
*/
|
|
126
126
|
declare function makeRegression(options?: {
|
|
127
127
|
nSamples?: number;
|
|
@@ -144,7 +144,7 @@ declare function makeRegression(options?: {
|
|
|
144
144
|
* @param options.randomState - Seed for reproducibility.
|
|
145
145
|
* @returns A tuple `[X, y]` where X has shape `[nSamples, nFeatures]` and y has shape `[nSamples]` with dtype `int32`.
|
|
146
146
|
*
|
|
147
|
-
* @see {@link https://
|
|
147
|
+
* @see {@link https://deepbox.dev/docs/datasets-synthetic | Deepbox Synthetic Datasets}
|
|
148
148
|
*/
|
|
149
149
|
declare function makeBlobs(options?: {
|
|
150
150
|
nSamples?: number;
|
|
@@ -166,7 +166,7 @@ declare function makeBlobs(options?: {
|
|
|
166
166
|
* @param options.randomState - Seed for reproducibility.
|
|
167
167
|
* @returns A tuple `[X, y]` where X has shape `[nSamples, 2]` and y has shape `[nSamples]` with dtype `int32`.
|
|
168
168
|
*
|
|
169
|
-
* @see {@link https://
|
|
169
|
+
* @see {@link https://deepbox.dev/docs/datasets-synthetic | Deepbox Synthetic Datasets}
|
|
170
170
|
*/
|
|
171
171
|
declare function makeMoons(options?: {
|
|
172
172
|
nSamples?: number;
|
|
@@ -187,7 +187,7 @@ declare function makeMoons(options?: {
|
|
|
187
187
|
* @param options.randomState - Seed for reproducibility.
|
|
188
188
|
* @returns A tuple `[X, y]` where X has shape `[nSamples, 2]` and y has shape `[nSamples]` with dtype `int32`.
|
|
189
189
|
*
|
|
190
|
-
* @see {@link https://
|
|
190
|
+
* @see {@link https://deepbox.dev/docs/datasets-synthetic | Deepbox Synthetic Datasets}
|
|
191
191
|
*/
|
|
192
192
|
declare function makeCircles(options?: {
|
|
193
193
|
nSamples?: number;
|
|
@@ -209,7 +209,7 @@ declare function makeCircles(options?: {
|
|
|
209
209
|
* @param options.randomState - Seed for reproducibility.
|
|
210
210
|
* @returns A tuple `[X, y]` where X has shape `[nSamples, nFeatures]` and y has shape `[nSamples]` with dtype `int32`.
|
|
211
211
|
*
|
|
212
|
-
* @see {@link https://
|
|
212
|
+
* @see {@link https://deepbox.dev/docs/datasets-synthetic | Deepbox Synthetic Datasets}
|
|
213
213
|
*/
|
|
214
214
|
declare function makeGaussianQuantiles(options?: {
|
|
215
215
|
nSamples?: number;
|
|
@@ -224,6 +224,7 @@ type Dataset = {
|
|
|
224
224
|
featureNames: string[];
|
|
225
225
|
targetNames?: string[];
|
|
226
226
|
description: string;
|
|
227
|
+
images?: Tensor;
|
|
227
228
|
};
|
|
228
229
|
/**
|
|
229
230
|
* Load the synthetic Iris dataset.
|
|
@@ -1,7 +1,7 @@
|
|
|
1
|
-
import { A as AnyTensor } from './index-
|
|
2
|
-
import { D as DType, a as Device, A as Axis } from './tensor-
|
|
3
|
-
import { G as GradTensor } from './index-
|
|
4
|
-
import { T as Tensor } from './Tensor-
|
|
1
|
+
import { A as AnyTensor } from './index-DLdiQzf0.cjs';
|
|
2
|
+
import { D as DType, a as Device, A as Axis } from './tensor-IlVTF0bz.cjs';
|
|
3
|
+
import { G as GradTensor } from './index-GUHYEhxs.cjs';
|
|
4
|
+
import { T as Tensor } from './Tensor-fxBg-TFZ.cjs';
|
|
5
5
|
|
|
6
6
|
type StateEntry = {
|
|
7
7
|
data: Array<number | string | bigint>;
|
|
@@ -31,7 +31,7 @@ type ForwardHook = (module: Module, inputs: AnyTensor[], output: AnyTensor) => A
|
|
|
31
31
|
* All models should subclass this class. Modules can contain other modules,
|
|
32
32
|
* allowing to nest them in a tree structure.
|
|
33
33
|
*
|
|
34
|
-
*
|
|
34
|
+
* { https://deepbox.dev/docs/nn-module | Deepbox Module & Sequential}
|
|
35
35
|
*
|
|
36
36
|
* @example
|
|
37
37
|
* ```ts
|
|
@@ -63,7 +63,7 @@ type ForwardHook = (module: Module, inputs: AnyTensor[], output: AnyTensor) => A
|
|
|
63
63
|
* ```
|
|
64
64
|
*
|
|
65
65
|
* References:
|
|
66
|
-
* -
|
|
66
|
+
* - Deepbox Module: https://deepbox.dev/docs/nn-module
|
|
67
67
|
*
|
|
68
68
|
* @category Neural Networks
|
|
69
69
|
*/
|
|
@@ -362,7 +362,7 @@ declare abstract class Module {
|
|
|
362
362
|
* ```
|
|
363
363
|
*
|
|
364
364
|
* References:
|
|
365
|
-
* -
|
|
365
|
+
* - Deepbox Sequential: https://deepbox.dev/docs/nn-module
|
|
366
366
|
* - Keras Sequential: https://keras.io/guides/sequential_model/
|
|
367
367
|
*
|
|
368
368
|
* @category Neural Network Containers
|
|
@@ -580,7 +580,7 @@ declare class Mish extends Module {
|
|
|
580
580
|
* const output = mha.forward(x, x, x);
|
|
581
581
|
* ```
|
|
582
582
|
*
|
|
583
|
-
* @see {@link https://
|
|
583
|
+
* @see {@link https://deepbox.dev/docs/nn-attention | Deepbox Attention}
|
|
584
584
|
* @see Vaswani et al. (2017) "Attention Is All You Need"
|
|
585
585
|
*/
|
|
586
586
|
declare class MultiheadAttention extends Module {
|
|
@@ -651,7 +651,7 @@ declare class MultiheadAttention extends Module {
|
|
|
651
651
|
* const output = layer.forward(x);
|
|
652
652
|
* ```
|
|
653
653
|
*
|
|
654
|
-
* @see {@link https://
|
|
654
|
+
* @see {@link https://deepbox.dev/docs/nn-attention | Deepbox Attention}
|
|
655
655
|
*/
|
|
656
656
|
declare class TransformerEncoderLayer extends Module {
|
|
657
657
|
private readonly dModel;
|
|
@@ -666,7 +666,19 @@ declare class TransformerEncoderLayer extends Module {
|
|
|
666
666
|
private readonly dropout1;
|
|
667
667
|
private readonly dropout2;
|
|
668
668
|
private readonly dropout3;
|
|
669
|
-
constructor(
|
|
669
|
+
constructor(dModelOrOpts: number | {
|
|
670
|
+
readonly dModel: number;
|
|
671
|
+
readonly nHead: number;
|
|
672
|
+
readonly dimFeedforward?: number;
|
|
673
|
+
readonly dFF?: number;
|
|
674
|
+
readonly dropout?: number;
|
|
675
|
+
readonly eps?: number;
|
|
676
|
+
}, nHead?: number, dFFOrOptions?: number | {
|
|
677
|
+
readonly dimFeedforward?: number;
|
|
678
|
+
readonly dFF?: number;
|
|
679
|
+
readonly dropout?: number;
|
|
680
|
+
readonly eps?: number;
|
|
681
|
+
}, options?: {
|
|
670
682
|
readonly dropout?: number;
|
|
671
683
|
readonly eps?: number;
|
|
672
684
|
});
|
|
@@ -692,7 +704,7 @@ declare class TransformerEncoderLayer extends Module {
|
|
|
692
704
|
* const conv = new Conv1d(16, 33, 3); // in_channels=16, out_channels=33, kernel_size=3
|
|
693
705
|
* ```
|
|
694
706
|
*
|
|
695
|
-
* @see {@link https://
|
|
707
|
+
* @see {@link https://deepbox.dev/docs/nn-layers | Deepbox Layers}
|
|
696
708
|
*/
|
|
697
709
|
declare class Conv1d extends Module {
|
|
698
710
|
private readonly inChannels;
|
|
@@ -724,7 +736,7 @@ declare class Conv1d extends Module {
|
|
|
724
736
|
* const conv = new Conv2d(3, 64, 3); // RGB to 64 channels, 3x3 kernel
|
|
725
737
|
* ```
|
|
726
738
|
*
|
|
727
|
-
* @see {@link https://
|
|
739
|
+
* @see {@link https://deepbox.dev/docs/nn-layers | Deepbox Layers}
|
|
728
740
|
*/
|
|
729
741
|
declare class Conv2d extends Module {
|
|
730
742
|
private readonly inChannels;
|
|
@@ -756,7 +768,7 @@ declare class Conv2d extends Module {
|
|
|
756
768
|
* const pool = new MaxPool2d(2); // 2x2 pooling
|
|
757
769
|
* ```
|
|
758
770
|
*
|
|
759
|
-
* @see {@link https://
|
|
771
|
+
* @see {@link https://deepbox.dev/docs/nn-layers | Deepbox Layers}
|
|
760
772
|
*/
|
|
761
773
|
declare class MaxPool2d extends Module {
|
|
762
774
|
private readonly kernelSizeValue;
|
|
@@ -780,7 +792,7 @@ declare class MaxPool2d extends Module {
|
|
|
780
792
|
* const pool = new AvgPool2d(2); // 2x2 pooling
|
|
781
793
|
* ```
|
|
782
794
|
*
|
|
783
|
-
* @see {@link https://
|
|
795
|
+
* @see {@link https://deepbox.dev/docs/nn-layers | Deepbox Layers}
|
|
784
796
|
*/
|
|
785
797
|
declare class AvgPool2d extends Module {
|
|
786
798
|
private readonly kernelSizeValue;
|
|
@@ -836,7 +848,7 @@ declare class AvgPool2d extends Module {
|
|
|
836
848
|
*
|
|
837
849
|
* References:
|
|
838
850
|
* - Dropout paper: https://jmlr.org/papers/v15/srivastava14a.html
|
|
839
|
-
* -
|
|
851
|
+
* - Deepbox Dropout: https://deepbox.dev/docs/nn-normalization
|
|
840
852
|
*
|
|
841
853
|
* @category Neural Network Layers
|
|
842
854
|
*/
|
|
@@ -923,7 +935,7 @@ declare class Dropout extends Module {
|
|
|
923
935
|
* ```
|
|
924
936
|
*
|
|
925
937
|
* References:
|
|
926
|
-
* -
|
|
938
|
+
* - Deepbox Linear: https://deepbox.dev/docs/nn-layers
|
|
927
939
|
* - Xavier/Glorot initialization: http://proceedings.mlr.press/v9/glorot10a.html
|
|
928
940
|
*
|
|
929
941
|
* @category Neural Network Layers
|
|
@@ -1014,7 +1026,7 @@ declare class Linear extends Module {
|
|
|
1014
1026
|
* const y = bn.forward(x);
|
|
1015
1027
|
* ```
|
|
1016
1028
|
*
|
|
1017
|
-
* @see {@link https://
|
|
1029
|
+
* @see {@link https://deepbox.dev/docs/nn-normalization | Deepbox Normalization & Dropout}
|
|
1018
1030
|
*/
|
|
1019
1031
|
declare class BatchNorm1d extends Module {
|
|
1020
1032
|
private readonly numFeatures;
|
|
@@ -1054,7 +1066,7 @@ declare class BatchNorm1d extends Module {
|
|
|
1054
1066
|
* const y = ln.forward(x);
|
|
1055
1067
|
* ```
|
|
1056
1068
|
*
|
|
1057
|
-
* @see {@link https://
|
|
1069
|
+
* @see {@link https://deepbox.dev/docs/nn-normalization | Deepbox Normalization & Dropout}
|
|
1058
1070
|
*/
|
|
1059
1071
|
declare class LayerNorm extends Module {
|
|
1060
1072
|
private readonly normalizedShape;
|
|
@@ -1087,7 +1099,7 @@ declare class LayerNorm extends Module {
|
|
|
1087
1099
|
* const output = rnn.forward(x);
|
|
1088
1100
|
* ```
|
|
1089
1101
|
*
|
|
1090
|
-
* @see {@link https://
|
|
1102
|
+
* @see {@link https://deepbox.dev/docs/nn-recurrent | Deepbox Recurrent Layers}
|
|
1091
1103
|
*/
|
|
1092
1104
|
declare class RNN extends Module {
|
|
1093
1105
|
private readonly inputSize;
|
|
@@ -1129,7 +1141,7 @@ declare class RNN extends Module {
|
|
|
1129
1141
|
* - Cell state: c_t = f_t * c_{t-1} + i_t * g_t
|
|
1130
1142
|
* - Hidden state: h_t = o_t * tanh(c_t)
|
|
1131
1143
|
*
|
|
1132
|
-
* @see {@link https://
|
|
1144
|
+
* @see {@link https://deepbox.dev/docs/nn-recurrent | Deepbox Recurrent Layers}
|
|
1133
1145
|
*/
|
|
1134
1146
|
declare class LSTM extends Module {
|
|
1135
1147
|
private readonly inputSize;
|
|
@@ -1167,7 +1179,7 @@ declare class LSTM extends Module {
|
|
|
1167
1179
|
* - New gate: n_t = tanh(W_in * x_t + b_in + r_t * (W_hn * h_{t-1} + b_hn))
|
|
1168
1180
|
* - Hidden: h_t = (1 - z_t) * n_t + z_t * h_{t-1}
|
|
1169
1181
|
*
|
|
1170
|
-
* @see {@link https://
|
|
1182
|
+
* @see {@link https://deepbox.dev/docs/nn-recurrent | Deepbox Recurrent Layers}
|
|
1171
1183
|
*/
|
|
1172
1184
|
declare class GRU extends Module {
|
|
1173
1185
|
private readonly inputSize;
|
|
@@ -1223,6 +1235,7 @@ declare class GRU extends Module {
|
|
|
1223
1235
|
*/
|
|
1224
1236
|
declare function crossEntropyLoss(input: Tensor, target: Tensor): number;
|
|
1225
1237
|
declare function crossEntropyLoss(input: GradTensor, target: AnyTensor): GradTensor;
|
|
1238
|
+
declare function crossEntropyLoss(input: AnyTensor, target: AnyTensor): number | GradTensor;
|
|
1226
1239
|
/**
|
|
1227
1240
|
* Binary Cross Entropy Loss with logits.
|
|
1228
1241
|
*
|
|
@@ -1271,6 +1284,7 @@ declare function binaryCrossEntropyWithLogitsLoss(input: GradTensor, target: Any
|
|
|
1271
1284
|
* @category Loss Functions
|
|
1272
1285
|
*/
|
|
1273
1286
|
declare function mseLoss(predictions: Tensor, targets: Tensor, reduction?: "mean" | "sum" | "none"): Tensor;
|
|
1287
|
+
declare function mseLoss(predictions: GradTensor, targets: GradTensor, reduction?: "mean" | "sum" | "none"): GradTensor;
|
|
1274
1288
|
/**
|
|
1275
1289
|
* Mean Absolute Error (MAE) loss function, also known as L1 loss.
|
|
1276
1290
|
*
|