deepbox 0.1.0 → 0.2.0

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (165) hide show
  1. package/LICENSE +1 -1
  2. package/README.md +39 -37
  3. package/dist/{CSRMatrix-KzNt6QpS.d.ts → CSRMatrix-B7XtUAZO.d.cts} +3 -3
  4. package/dist/{CSRMatrix-CwGwQRea.d.cts → CSRMatrix-CtD23fRM.d.ts} +3 -3
  5. package/dist/{Tensor-BQLk1ltW.d.cts → Tensor-BORFp_zt.d.ts} +27 -7
  6. package/dist/{Tensor-g8mUClel.d.ts → Tensor-fxBg-TFZ.d.cts} +27 -7
  7. package/dist/{chunk-FJYLIGJX.js → chunk-3AX37GPK.js} +33 -7
  8. package/dist/chunk-3AX37GPK.js.map +1 -0
  9. package/dist/{chunk-PR647I7R.js → chunk-3YFEYDHN.js} +21 -4
  10. package/dist/chunk-3YFEYDHN.js.map +1 -0
  11. package/dist/{chunk-XMWVME2W.js → chunk-6SX26MAJ.js} +4 -4
  12. package/dist/{chunk-XMWVME2W.js.map → chunk-6SX26MAJ.js.map} +1 -1
  13. package/dist/{chunk-C4PKXY74.cjs → chunk-6X7XFNDO.cjs} +94 -77
  14. package/dist/chunk-6X7XFNDO.cjs.map +1 -0
  15. package/dist/{chunk-6AE5FKKQ.cjs → chunk-724CXHFH.cjs} +1211 -919
  16. package/dist/chunk-724CXHFH.cjs.map +1 -0
  17. package/dist/{chunk-AU7XHGKJ.js → chunk-AJTKVBY5.js} +4 -4
  18. package/dist/{chunk-AU7XHGKJ.js.map → chunk-AJTKVBY5.js.map} +1 -1
  19. package/dist/{chunk-ZB75FESB.cjs → chunk-AV6WGSYX.cjs} +130 -104
  20. package/dist/chunk-AV6WGSYX.cjs.map +1 -0
  21. package/dist/{chunk-ZLW62TJG.cjs → chunk-BWOSU234.cjs} +142 -141
  22. package/dist/chunk-BWOSU234.cjs.map +1 -0
  23. package/dist/{chunk-4S73VUBD.js → chunk-CZOMBUI7.js} +3 -3
  24. package/dist/chunk-CZOMBUI7.js.map +1 -0
  25. package/dist/{chunk-QERHVCHC.cjs → chunk-EUZHJDZ6.cjs} +419 -364
  26. package/dist/chunk-EUZHJDZ6.cjs.map +1 -0
  27. package/dist/{chunk-AD436M45.js → chunk-G2G55ATL.js} +120 -58
  28. package/dist/chunk-G2G55ATL.js.map +1 -0
  29. package/dist/{chunk-5R4S63PF.js → chunk-G3WNLNYS.js} +119 -64
  30. package/dist/chunk-G3WNLNYS.js.map +1 -0
  31. package/dist/{chunk-XEG44RF6.cjs → chunk-G7KXZHG6.cjs} +105 -95
  32. package/dist/chunk-G7KXZHG6.cjs.map +1 -0
  33. package/dist/{chunk-MLBMYKCG.js → chunk-H3JR7SV2.js} +255 -113
  34. package/dist/chunk-H3JR7SV2.js.map +1 -0
  35. package/dist/{chunk-PHV2DKRS.cjs → chunk-HDKMIG6E.cjs} +107 -107
  36. package/dist/{chunk-PHV2DKRS.cjs.map → chunk-HDKMIG6E.cjs.map} +1 -1
  37. package/dist/{chunk-ALS7ETWZ.cjs → chunk-HI2EZHCJ.cjs} +111 -102
  38. package/dist/chunk-HI2EZHCJ.cjs.map +1 -0
  39. package/dist/{chunk-OX6QXFMV.cjs → chunk-IT4BZUYE.cjs} +490 -428
  40. package/dist/chunk-IT4BZUYE.cjs.map +1 -0
  41. package/dist/{chunk-E3EU5FZO.cjs → chunk-JTZPRV6E.cjs} +123 -123
  42. package/dist/{chunk-E3EU5FZO.cjs.map → chunk-JTZPRV6E.cjs.map} +1 -1
  43. package/dist/{chunk-PL7TAYKI.js → chunk-K2L5C5YH.js} +8 -7
  44. package/dist/chunk-K2L5C5YH.js.map +1 -0
  45. package/dist/{chunk-BCR7G3A6.js → chunk-KCF6P34A.js} +356 -64
  46. package/dist/chunk-KCF6P34A.js.map +1 -0
  47. package/dist/{chunk-ZXKBDFP3.js → chunk-LZHVHD62.js} +15 -6
  48. package/dist/chunk-LZHVHD62.js.map +1 -0
  49. package/dist/{chunk-LWECRCW2.cjs → chunk-MTJF52AJ.cjs} +141 -141
  50. package/dist/{chunk-LWECRCW2.cjs.map → chunk-MTJF52AJ.cjs.map} +1 -1
  51. package/dist/{chunk-B5TNKUEY.js → chunk-NDDTUFKK.js} +16 -6
  52. package/dist/chunk-NDDTUFKK.js.map +1 -0
  53. package/dist/{chunk-DWZY6PIP.cjs → chunk-NOQI6OFL.cjs} +615 -473
  54. package/dist/chunk-NOQI6OFL.cjs.map +1 -0
  55. package/dist/{chunk-F3JWBINJ.js → chunk-OEXDJFHA.js} +4 -4
  56. package/dist/{chunk-F3JWBINJ.js.map → chunk-OEXDJFHA.js.map} +1 -1
  57. package/dist/{chunk-JSCDE774.cjs → chunk-Z6BGACIH.cjs} +3 -3
  58. package/dist/chunk-Z6BGACIH.cjs.map +1 -0
  59. package/dist/core/index.cjs +50 -50
  60. package/dist/core/index.d.cts +2 -2
  61. package/dist/core/index.d.ts +2 -2
  62. package/dist/core/index.js +1 -1
  63. package/dist/dataframe/index.cjs +6 -6
  64. package/dist/dataframe/index.d.cts +3 -3
  65. package/dist/dataframe/index.d.ts +3 -3
  66. package/dist/dataframe/index.js +3 -3
  67. package/dist/datasets/index.cjs +34 -34
  68. package/dist/datasets/index.d.cts +3 -3
  69. package/dist/datasets/index.d.ts +3 -3
  70. package/dist/datasets/index.js +3 -3
  71. package/dist/{index-C1mfVYoo.d.ts → index-B18dHc8q.d.ts} +81 -46
  72. package/dist/{index-GFAVyOWO.d.ts → index-BHHX0qTY.d.cts} +14 -12
  73. package/dist/{index-tk4lSYod.d.ts → index-BI6QOUvV.d.ts} +106 -80
  74. package/dist/{index-DIp_RrRt.d.ts → index-BKvK21lf.d.ts} +13 -35
  75. package/dist/{index-BJY2SI4i.d.ts → index-BL8jLf3K.d.cts} +12 -11
  76. package/dist/{index-Cn3SdB0O.d.ts → index-BNbX167d.d.cts} +16 -10
  77. package/dist/{index-BWGhrDlr.d.ts → index-BT2ofL7Z.d.cts} +35 -35
  78. package/dist/{index-BbA2Gxfl.d.ts → index-BqcfIcL4.d.ts} +15 -15
  79. package/dist/{index-ZtI1Iy4L.d.ts → index-BrgrECM2.d.ts} +41 -38
  80. package/dist/{index-CDw5CnOU.d.ts → index-BtYKI9yJ.d.ts} +10 -8
  81. package/dist/{index-DIT_OO9C.d.cts → index-C7nLsAOC.d.cts} +10 -8
  82. package/dist/{index-D9Loo1_A.d.cts → index-CNj2Mxwf.d.cts} +81 -46
  83. package/dist/{index-DmEg_LCm.d.cts → index-CYlxeNW1.d.cts} +5 -3
  84. package/dist/{index-D61yaSMY.d.cts → index-CiTd61a5.d.ts} +12 -11
  85. package/dist/{index-BndMbqsM.d.ts → index-Cjnn0KeN.d.cts} +35 -21
  86. package/dist/{index-9oQx1HgV.d.cts → index-CkGGAn69.d.cts} +41 -38
  87. package/dist/{index-74AB8Cyh.d.cts → index-D4URSgqA.d.ts} +16 -10
  88. package/dist/{index-DoPWVxPo.d.cts → index-D4pn5zLT.d.ts} +35 -21
  89. package/dist/{index-DuCxd-8d.d.ts → index-D9ztTlDr.d.ts} +60 -42
  90. package/dist/{index-BgHYAoSS.d.cts → index-DF28ZPB5.d.cts} +60 -42
  91. package/dist/{index-eJgeni9c.d.cts → index-DLdiQzf0.d.cts} +106 -80
  92. package/dist/{index-WHQLn0e8.d.cts → index-DN4omPQw.d.ts} +35 -35
  93. package/dist/{index-CrqLlS-a.d.ts → index-DUnFq1WV.d.ts} +5 -3
  94. package/dist/{index-DbultU6X.d.cts → index-DgaYshkF.d.ts} +14 -12
  95. package/dist/{index-B_DK4FKY.d.cts → index-GUHYEhxs.d.cts} +13 -35
  96. package/dist/{index-CCvlwAmL.d.cts → index-TP--4irE.d.cts} +16 -14
  97. package/dist/{index-Dx42TZaY.d.ts → index-x0z_sanT.d.ts} +16 -14
  98. package/dist/{index-DyZ4QQf5.d.cts → index-xWH7ujWa.d.cts} +15 -15
  99. package/dist/index.cjs +26 -26
  100. package/dist/index.d.cts +17 -17
  101. package/dist/index.d.ts +17 -17
  102. package/dist/index.js +13 -13
  103. package/dist/linalg/index.cjs +22 -22
  104. package/dist/linalg/index.d.cts +3 -3
  105. package/dist/linalg/index.d.ts +3 -3
  106. package/dist/linalg/index.js +3 -3
  107. package/dist/metrics/index.cjs +40 -40
  108. package/dist/metrics/index.d.cts +3 -3
  109. package/dist/metrics/index.d.ts +3 -3
  110. package/dist/metrics/index.js +3 -3
  111. package/dist/ml/index.cjs +23 -23
  112. package/dist/ml/index.d.cts +3 -3
  113. package/dist/ml/index.d.ts +3 -3
  114. package/dist/ml/index.js +4 -4
  115. package/dist/ndarray/index.cjs +125 -125
  116. package/dist/ndarray/index.d.cts +5 -5
  117. package/dist/ndarray/index.d.ts +5 -5
  118. package/dist/ndarray/index.js +2 -2
  119. package/dist/nn/index.cjs +36 -36
  120. package/dist/nn/index.d.cts +6 -6
  121. package/dist/nn/index.d.ts +6 -6
  122. package/dist/nn/index.js +3 -3
  123. package/dist/optim/index.cjs +19 -19
  124. package/dist/optim/index.d.cts +4 -4
  125. package/dist/optim/index.d.ts +4 -4
  126. package/dist/optim/index.js +2 -2
  127. package/dist/plot/index.cjs +29 -29
  128. package/dist/plot/index.d.cts +6 -6
  129. package/dist/plot/index.d.ts +6 -6
  130. package/dist/plot/index.js +3 -3
  131. package/dist/preprocess/index.cjs +21 -21
  132. package/dist/preprocess/index.d.cts +4 -4
  133. package/dist/preprocess/index.d.ts +4 -4
  134. package/dist/preprocess/index.js +3 -3
  135. package/dist/random/index.cjs +19 -19
  136. package/dist/random/index.d.cts +3 -3
  137. package/dist/random/index.d.ts +3 -3
  138. package/dist/random/index.js +3 -3
  139. package/dist/stats/index.cjs +36 -36
  140. package/dist/stats/index.d.cts +3 -3
  141. package/dist/stats/index.d.ts +3 -3
  142. package/dist/stats/index.js +3 -3
  143. package/dist/{tensor-B96jjJLQ.d.cts → tensor-IlVTF0bz.d.cts} +16 -3
  144. package/dist/{tensor-B96jjJLQ.d.ts → tensor-IlVTF0bz.d.ts} +16 -3
  145. package/package.json +3 -2
  146. package/dist/chunk-4S73VUBD.js.map +0 -1
  147. package/dist/chunk-5R4S63PF.js.map +0 -1
  148. package/dist/chunk-6AE5FKKQ.cjs.map +0 -1
  149. package/dist/chunk-AD436M45.js.map +0 -1
  150. package/dist/chunk-ALS7ETWZ.cjs.map +0 -1
  151. package/dist/chunk-B5TNKUEY.js.map +0 -1
  152. package/dist/chunk-BCR7G3A6.js.map +0 -1
  153. package/dist/chunk-C4PKXY74.cjs.map +0 -1
  154. package/dist/chunk-DWZY6PIP.cjs.map +0 -1
  155. package/dist/chunk-FJYLIGJX.js.map +0 -1
  156. package/dist/chunk-JSCDE774.cjs.map +0 -1
  157. package/dist/chunk-MLBMYKCG.js.map +0 -1
  158. package/dist/chunk-OX6QXFMV.cjs.map +0 -1
  159. package/dist/chunk-PL7TAYKI.js.map +0 -1
  160. package/dist/chunk-PR647I7R.js.map +0 -1
  161. package/dist/chunk-QERHVCHC.cjs.map +0 -1
  162. package/dist/chunk-XEG44RF6.cjs.map +0 -1
  163. package/dist/chunk-ZB75FESB.cjs.map +0 -1
  164. package/dist/chunk-ZLW62TJG.cjs.map +0 -1
  165. package/dist/chunk-ZXKBDFP3.js.map +0 -1
@@ -1,14 +1,15 @@
1
- import { T as Tensor } from './Tensor-BQLk1ltW.cjs';
1
+ import { T as Tensor } from './Tensor-fxBg-TFZ.cjs';
2
+ import { S as Shape, c as ScalarDType } from './tensor-IlVTF0bz.cjs';
2
3
 
3
4
  /**
4
5
  * Base type for all estimators (models) in Deepbox.
5
6
  *
6
- * This follows the scikit-learn estimator type design.
7
+ * Base estimator type for all ML models.
7
8
  *
8
9
  * @template FitParams - Type of parameters passed to fit method
9
10
  *
10
11
  * References:
11
- * - scikit-learn API: https://scikit-learn.org/stable/developers/develop.html
12
+ * - Deepbox ML: https://deepbox.dev/docs/ml-linear
12
13
  */
13
14
  type Estimator<FitParams = void> = {
14
15
  /**
@@ -238,7 +239,7 @@ type OutlierDetector = Estimator<void> & {
238
239
  * // labels: [0, 0, 0, 1, 1, -1] (-1 = noise)
239
240
  * ```
240
241
  *
241
- * @see {@link https://scikit-learn.org/stable/modules/generated/sklearn.cluster.DBSCAN.html | scikit-learn DBSCAN}
242
+ * @see {@link https://deepbox.dev/docs/ml-clustering | Deepbox Clustering}
242
243
  */
243
244
  declare class DBSCAN implements Clusterer {
244
245
  private eps;
@@ -295,6 +296,13 @@ declare class DBSCAN implements Clusterer {
295
296
  * @throws {NotFittedError} If the model has not been fitted
296
297
  */
297
298
  get labels(): Tensor;
299
+ /**
300
+ * Number of clusters found (excluding noise).
301
+ *
302
+ * @returns Number of distinct clusters (labels >= 0)
303
+ * @throws {NotFittedError} If the model has not been fitted
304
+ */
305
+ get nClusters(): number;
298
306
  /**
299
307
  * Get indices of core samples discovered during fitting.
300
308
  *
@@ -352,8 +360,8 @@ declare class DBSCAN implements Clusterer {
352
360
  * console.log('Centroids:', kmeans.clusterCenters);
353
361
  * ```
354
362
  *
355
- * @see {@link https://en.wikipedia.org/wiki/K-means_clustering | Wikipedia: K-means}
356
- * @see {@link https://scikit-learn.org/stable/modules/generated/sklearn.cluster.KMeans.html | scikit-learn KMeans}
363
+ * @see {@link https://deepbox.dev/docs/ml-clustering | Deepbox Clustering}
364
+ * @see {@link https://deepbox.dev/docs/ml-clustering | Deepbox Clustering}
357
365
  */
358
366
  declare class KMeans implements Clusterer {
359
367
  private nClusters;
@@ -482,8 +490,8 @@ declare class KMeans implements Clusterer {
482
490
  * console.log('Explained variance ratio:', pca.explainedVarianceRatio);
483
491
  * ```
484
492
  *
485
- * @see {@link https://en.wikipedia.org/wiki/Principal_component_analysis | Wikipedia: PCA}
486
- * @see {@link https://scikit-learn.org/stable/modules/generated/sklearn.decomposition.PCA.html | scikit-learn PCA}
493
+ * @see {@link https://deepbox.dev/docs/ml-decomposition | Deepbox Dimensionality Reduction}
494
+ * @see {@link https://deepbox.dev/docs/ml-decomposition | Deepbox Dimensionality Reduction}
487
495
  */
488
496
  declare class PCA implements Transformer {
489
497
  private readonly nComponents?;
@@ -590,7 +598,7 @@ declare class PCA implements Transformer {
590
598
  * const predictions = gbr.predict(X);
591
599
  * ```
592
600
  *
593
- * @see {@link https://scikit-learn.org/stable/modules/generated/sklearn.ensemble.GradientBoostingRegressor.html | scikit-learn GradientBoostingRegressor}
601
+ * @see {@link https://deepbox.dev/docs/ml-ensemble | Deepbox Ensemble Methods}
594
602
  */
595
603
  declare class GradientBoostingRegressor implements Regressor {
596
604
  /** Number of boosting stages (trees) */
@@ -669,8 +677,10 @@ declare class GradientBoostingRegressor implements Regressor {
669
677
  /**
670
678
  * Gradient Boosting Classifier.
671
679
  *
672
- * Uses gradient boosting with shallow regression trees for binary classification.
673
- * Optimizes log loss (cross-entropy) using sigmoid function.
680
+ * Uses gradient boosting with shallow regression trees for classification.
681
+ * Supports both binary and multiclass classification.
682
+ * - Binary: optimizes log loss using sigmoid function.
683
+ * - Multiclass: uses One-vs-Rest (OvR) strategy, training one binary model per class.
674
684
  *
675
685
  * @example
676
686
  * ```ts
@@ -685,7 +695,7 @@ declare class GradientBoostingRegressor implements Regressor {
685
695
  * const predictions = gbc.predict(X);
686
696
  * ```
687
697
  *
688
- * @see {@link https://scikit-learn.org/stable/modules/generated/sklearn.ensemble.GradientBoostingClassifier.html | scikit-learn GradientBoostingClassifier}
698
+ * @see {@link https://deepbox.dev/docs/ml-ensemble | Deepbox Ensemble Methods}
689
699
  */
690
700
  declare class GradientBoostingClassifier implements Classifier {
691
701
  /** Number of boosting stages */
@@ -696,10 +706,10 @@ declare class GradientBoostingClassifier implements Classifier {
696
706
  private maxDepth;
697
707
  /** Minimum samples to split */
698
708
  private minSamplesSplit;
699
- /** Array of weak learners */
700
- private estimators;
701
- /** Initial log-odds prediction */
702
- private initPrediction;
709
+ /** Per-class arrays of weak learners (OvR for multiclass, single for binary) */
710
+ private estimatorsPerClass;
711
+ /** Per-class initial log-odds predictions */
712
+ private initPredictions;
703
713
  /** Number of features */
704
714
  private nFeatures;
705
715
  /** Unique class labels */
@@ -712,19 +722,29 @@ declare class GradientBoostingClassifier implements Classifier {
712
722
  readonly maxDepth?: number;
713
723
  readonly minSamplesSplit?: number;
714
724
  });
725
+ /**
726
+ * Fit a single binary boosting ensemble.
727
+ * Trains nEstimators regression trees to optimize log loss for a binary target.
728
+ */
729
+ private fitBinary;
730
+ /**
731
+ * Compute raw scores for a single binary ensemble.
732
+ */
733
+ private predictRawBinary;
715
734
  /**
716
735
  * Fit the gradient boosting classifier on training data.
717
736
  *
718
737
  * Builds an additive model by sequentially fitting regression trees
719
738
  * to the pseudo-residuals (gradient of log loss).
739
+ * Supports binary (2 classes) and multiclass (>2 classes via OvR).
720
740
  *
721
741
  * @param X - Training data of shape (n_samples, n_features)
722
- * @param y - Target class labels of shape (n_samples,). Must contain exactly 2 classes.
742
+ * @param y - Target class labels of shape (n_samples,). Must contain at least 2 classes.
723
743
  * @returns this - The fitted estimator
724
744
  * @throws {ShapeError} If X is not 2D or y is not 1D
725
745
  * @throws {ShapeError} If X and y have different number of samples
726
746
  * @throws {DataValidationError} If X or y contain NaN/Inf values
727
- * @throws {InvalidParameterError} If y does not contain exactly 2 classes
747
+ * @throws {InvalidParameterError} If y does not contain at least 2 classes
728
748
  */
729
749
  fit(X: Tensor, y: Tensor): this;
730
750
  /**
@@ -740,11 +760,10 @@ declare class GradientBoostingClassifier implements Classifier {
740
760
  /**
741
761
  * Predict class probabilities for samples in X.
742
762
  *
743
- * Returns a matrix of shape (n_samples, 2) where columns are
744
- * [P(class_0), P(class_1)].
763
+ * Returns a matrix of shape (n_samples, n_classes).
745
764
  *
746
765
  * @param X - Samples of shape (n_samples, n_features)
747
- * @returns Class probability matrix of shape (n_samples, 2)
766
+ * @returns Class probability matrix of shape (n_samples, n_classes)
748
767
  * @throws {NotFittedError} If the model has not been fitted
749
768
  * @throws {ShapeError} If X has wrong dimensions or feature count
750
769
  * @throws {DataValidationError} If X contains NaN/Inf values
@@ -1041,7 +1060,7 @@ declare class LinearRegression implements Regressor {
1041
1060
  * @throws {ShapeError} If X has wrong dimensions or feature count
1042
1061
  * @throws {DataValidationError} If X contains NaN/Inf values
1043
1062
  */
1044
- predict(X: Tensor): Tensor;
1063
+ predict(X: Tensor): Tensor<Shape, ScalarDType>;
1045
1064
  /**
1046
1065
  * Return the coefficient of determination R^2 of the prediction.
1047
1066
  *
@@ -1434,7 +1453,7 @@ declare class Ridge implements Regressor {
1434
1453
  * const embedding = tsne.fitTransform(X);
1435
1454
  * ```
1436
1455
  *
1437
- * @see {@link https://scikit-learn.org/stable/modules/generated/sklearn.manifold.TSNE.html | scikit-learn TSNE}
1456
+ * @see {@link https://deepbox.dev/docs/ml-manifold | Deepbox Manifold Learning}
1438
1457
  * @see van der Maaten, L.J.P.; Hinton, G.E. (2008). "Visualizing High-Dimensional Data Using t-SNE"
1439
1458
  */
1440
1459
  declare class TSNE {
@@ -1543,6 +1562,14 @@ declare class TSNE {
1543
1562
  * Fit the model (same as fitTransform for t-SNE).
1544
1563
  */
1545
1564
  fit(X: Tensor): this;
1565
+ /**
1566
+ * Return the fitted embedding. For t-SNE, transform is equivalent to
1567
+ * returning the already-computed embedding (t-SNE is non-parametric).
1568
+ *
1569
+ * @param _X - Ignored, present for API compatibility
1570
+ * @returns Low-dimensional embedding of shape (n_samples, n_components)
1571
+ */
1572
+ transform(_X?: Tensor): Tensor;
1546
1573
  /**
1547
1574
  * Get the embedding.
1548
1575
  */
@@ -1585,8 +1612,8 @@ declare class TSNE {
1585
1612
  * const predictions = nb.predict(tensor([[2.5, 3.5]]));
1586
1613
  * ```
1587
1614
  *
1588
- * @see {@link https://en.wikipedia.org/wiki/Naive_Bayes_classifier | Wikipedia: Naive Bayes}
1589
- * @see {@link https://scikit-learn.org/stable/modules/generated/sklearn.naive_bayes.GaussianNB.html | scikit-learn GaussianNB}
1615
+ * @see {@link https://deepbox.dev/docs/ml-naive-bayes | Deepbox Naive Bayes}
1616
+ * @see {@link https://deepbox.dev/docs/ml-naive-bayes | Deepbox Naive Bayes}
1590
1617
  */
1591
1618
  declare class GaussianNB implements Classifier {
1592
1619
  private readonly varSmoothing;
@@ -1737,8 +1764,8 @@ declare abstract class KNeighborsBase {
1737
1764
  * const predictions = knn.predict(tensor([[1.5, 1.5]]));
1738
1765
  * ```
1739
1766
  *
1740
- * @see {@link https://en.wikipedia.org/wiki/K-nearest_neighbors_algorithm | Wikipedia: KNN}
1741
- * @see {@link https://scikit-learn.org/stable/modules/generated/sklearn.neighbors.KNeighborsClassifier.html | scikit-learn KNeighborsClassifier}
1767
+ * @see {@link https://deepbox.dev/docs/ml-neighbors | Deepbox Nearest Neighbors}
1768
+ * @see {@link https://deepbox.dev/docs/ml-neighbors | Deepbox Nearest Neighbors}
1742
1769
  */
1743
1770
  declare class KNeighborsClassifier extends KNeighborsBase implements Classifier {
1744
1771
  /**
@@ -1805,7 +1832,7 @@ declare class KNeighborsClassifier extends KNeighborsBase implements Classifier
1805
1832
  * const predictions = knn.predict(tensor([[1.5]]));
1806
1833
  * ```
1807
1834
  *
1808
- * @see {@link https://scikit-learn.org/stable/modules/generated/sklearn.neighbors.KNeighborsRegressor.html | scikit-learn KNeighborsRegressor}
1835
+ * @see {@link https://deepbox.dev/docs/ml-neighbors | Deepbox Nearest Neighbors}
1809
1836
  */
1810
1837
  declare class KNeighborsRegressor extends KNeighborsBase implements Regressor {
1811
1838
  /**
@@ -1868,7 +1895,7 @@ declare class KNeighborsRegressor extends KNeighborsBase implements Regressor {
1868
1895
  * const predictions = svm.predict(X);
1869
1896
  * ```
1870
1897
  *
1871
- * @see {@link https://scikit-learn.org/stable/modules/generated/sklearn.svm.LinearSVC.html | scikit-learn LinearSVC}
1898
+ * @see {@link https://deepbox.dev/docs/ml-svm | Deepbox SVM}
1872
1899
  */
1873
1900
  declare class LinearSVC implements Classifier {
1874
1901
  /** Regularization parameter (inverse of regularization strength) */
@@ -1877,13 +1904,13 @@ declare class LinearSVC implements Classifier {
1877
1904
  private readonly maxIter;
1878
1905
  /** Tolerance for stopping criterion */
1879
1906
  private readonly tol;
1880
- /** Weight vector of shape (n_features,) */
1881
- private weights;
1882
- /** Bias term */
1883
- private bias;
1907
+ /** Per-class weight vectors (OvR for multiclass, single for binary) */
1908
+ private weightsPerClass;
1909
+ /** Per-class bias terms */
1910
+ private biasPerClass;
1884
1911
  /** Number of features seen during fit */
1885
1912
  private nFeatures;
1886
- /** Unique class labels [0, 1] mapped from original labels */
1913
+ /** Unique class labels */
1887
1914
  private classLabels;
1888
1915
  /** Whether the model has been fitted */
1889
1916
  private fitted;
@@ -1900,19 +1927,27 @@ declare class LinearSVC implements Classifier {
1900
1927
  readonly maxIter?: number;
1901
1928
  readonly tol?: number;
1902
1929
  });
1930
+ /**
1931
+ * Fit a single binary SVM using sub-gradient descent on hinge loss.
1932
+ * Maps labels to {-1, +1} and returns learned weights + bias.
1933
+ */
1934
+ private fitBinary;
1935
+ /**
1936
+ * Compute decision value for a single binary classifier.
1937
+ */
1938
+ private decisionBinary;
1903
1939
  /**
1904
1940
  * Fit the SVM classifier using sub-gradient descent.
1905
1941
  *
1906
- * Uses a simplified hinge loss optimization with L2 regularization.
1907
- * Objective: minimize (1/2)||w||² + C * Σmax(0, 1 - y_i(w · x_i + b))
1942
+ * Supports both binary and multiclass classification (via OvR).
1908
1943
  *
1909
1944
  * @param X - Training data of shape (n_samples, n_features)
1910
- * @param y - Target labels of shape (n_samples,). Must contain exactly 2 classes.
1945
+ * @param y - Target labels of shape (n_samples,). Must contain at least 2 classes.
1911
1946
  * @returns this - The fitted estimator
1912
1947
  * @throws {ShapeError} If X is not 2D or y is not 1D
1913
1948
  * @throws {ShapeError} If X and y have different number of samples
1914
1949
  * @throws {DataValidationError} If X or y contain NaN/Inf values
1915
- * @throws {InvalidParameterError} If y does not contain exactly 2 classes
1950
+ * @throws {InvalidParameterError} If y does not contain at least 2 classes
1916
1951
  */
1917
1952
  fit(X: Tensor, y: Tensor): this;
1918
1953
  /**
@@ -1954,12 +1989,12 @@ declare class LinearSVC implements Classifier {
1954
1989
  */
1955
1990
  get coef(): Tensor;
1956
1991
  /**
1957
- * Get the bias term.
1992
+ * Get the bias terms.
1958
1993
  *
1959
- * @returns Bias value
1994
+ * @returns Bias values as tensor
1960
1995
  * @throws {NotFittedError} If the model has not been fitted
1961
1996
  */
1962
- get intercept(): number;
1997
+ get intercept(): Tensor;
1963
1998
  /**
1964
1999
  * Get hyperparameters for this estimator.
1965
2000
  *
@@ -1992,7 +2027,7 @@ declare class LinearSVC implements Classifier {
1992
2027
  * const predictions = svr.predict(X);
1993
2028
  * ```
1994
2029
  *
1995
- * @see {@link https://scikit-learn.org/stable/modules/generated/sklearn.svm.LinearSVR.html | scikit-learn LinearSVR}
2030
+ * @see {@link https://deepbox.dev/docs/ml-svm | Deepbox SVM}
1996
2031
  */
1997
2032
  declare class LinearSVR implements Regressor {
1998
2033
  /** Regularization parameter */
@@ -2088,7 +2123,7 @@ declare class LinearSVR implements Regressor {
2088
2123
  * const predictions = clf.predict(X);
2089
2124
  * ```
2090
2125
  *
2091
- * @see {@link https://scikit-learn.org/stable/modules/generated/sklearn.tree.DecisionTreeClassifier.html | scikit-learn DecisionTreeClassifier}
2126
+ * @see {@link https://deepbox.dev/docs/ml-tree | Deepbox Decision Trees}
2092
2127
  */
2093
2128
  declare class DecisionTreeClassifier implements Classifier {
2094
2129
  private maxDepth;
@@ -2182,7 +2217,7 @@ declare class DecisionTreeClassifier implements Classifier {
2182
2217
  *
2183
2218
  * Uses MSE reduction to find optimal splits for regression tasks.
2184
2219
  *
2185
- * @see {@link https://scikit-learn.org/stable/modules/generated/sklearn.tree.DecisionTreeRegressor.html | scikit-learn DecisionTreeRegressor}
2220
+ * @see {@link https://deepbox.dev/docs/ml-tree | Deepbox Decision Trees}
2186
2221
  */
2187
2222
  declare class DecisionTreeRegressor implements Regressor {
2188
2223
  private maxDepth;
@@ -2275,7 +2310,7 @@ declare class DecisionTreeRegressor implements Regressor {
2275
2310
  * const predictions = clf.predict(X_test);
2276
2311
  * ```
2277
2312
  *
2278
- * @see {@link https://scikit-learn.org/stable/modules/generated/sklearn.ensemble.RandomForestClassifier.html | scikit-learn RandomForestClassifier}
2313
+ * @see {@link https://deepbox.dev/docs/ml-tree | Deepbox Decision Trees}
2279
2314
  */
2280
2315
  declare class RandomForestClassifier implements Classifier {
2281
2316
  private readonly nEstimators;
@@ -2389,7 +2424,7 @@ declare class RandomForestClassifier implements Classifier {
2389
2424
  * const predictions = reg.predict(X_test);
2390
2425
  * ```
2391
2426
  *
2392
- * @see {@link https://scikit-learn.org/stable/modules/generated/sklearn.ensemble.RandomForestRegressor.html | scikit-learn RandomForestRegressor}
2427
+ * @see {@link https://deepbox.dev/docs/ml-tree | Deepbox Decision Trees}
2393
2428
  */
2394
2429
  declare class RandomForestRegressor implements Regressor {
2395
2430
  private readonly nEstimators;
@@ -1,5 +1,5 @@
1
- import { A as AnyTensor } from './index-eJgeni9c.cjs';
2
- import { T as Tensor } from './Tensor-BQLk1ltW.cjs';
1
+ import { A as AnyTensor } from './index-DLdiQzf0.cjs';
2
+ import { T as Tensor } from './Tensor-fxBg-TFZ.cjs';
3
3
 
4
4
  /**
5
5
  * @internal
@@ -676,7 +676,9 @@ declare function barh(y: Tensor, width: Tensor, options?: PlotOptions): void;
676
676
  /**
677
677
  * Plot a histogram on the current axes.
678
678
  */
679
- declare function hist(x: Tensor, bins?: number, options?: PlotOptions): void;
679
+ declare function hist(x: Tensor, bins?: number | (PlotOptions & {
680
+ bins?: number;
681
+ }), options?: PlotOptions): void;
680
682
  /**
681
683
  * Plot a box-and-whisker summary on the current axes.
682
684
  */
@@ -1,4 +1,4 @@
1
- import { T as Tensor } from './Tensor-BQLk1ltW.cjs';
1
+ import { T as Tensor } from './Tensor-BORFp_zt.js';
2
2
 
3
3
  type DataLoaderOptions = {
4
4
  batchSize?: number;
@@ -9,7 +9,7 @@ type DataLoaderOptions = {
9
9
  /**
10
10
  * Data loader for batching and shuffling datasets.
11
11
  *
12
- * Similar to PyTorch's DataLoader. Provides efficient iteration over datasets with support for
12
+ * Provides efficient iteration over datasets with support for
13
13
  * batching, shuffling, and deterministic reproducibility.
14
14
  *
15
15
  * @remarks
@@ -58,7 +58,7 @@ type DataLoaderOptions = {
58
58
  * }
59
59
  * ```
60
60
  *
61
- * @see {@link https://pytorch.org/docs/stable/data.html#torch.utils.data.DataLoader | PyTorch DataLoader}
61
+ * @see {@link https://deepbox.dev/docs/datasets-dataloader | Deepbox DataLoader}
62
62
  */
63
63
  declare class DataLoader<TTarget extends Tensor | undefined = undefined> {
64
64
  private X;
@@ -75,8 +75,7 @@ declare class DataLoader<TTarget extends Tensor | undefined = undefined> {
75
75
  * Number of batches in the data loader.
76
76
  */
77
77
  get length(): number;
78
- [Symbol.iterator](this: DataLoader<Tensor>): IterableIterator<[Tensor, Tensor]>;
79
- [Symbol.iterator](this: DataLoader<undefined>): IterableIterator<[Tensor]>;
78
+ [Symbol.iterator](): IterableIterator<TTarget extends Tensor ? [Tensor, Tensor] : [Tensor]>;
80
79
  private prepareIteration;
81
80
  private iterateX;
82
81
  private iterateXY;
@@ -98,7 +97,7 @@ declare class DataLoader<TTarget extends Tensor | undefined = undefined> {
98
97
  * @param options.randomState - Seed for reproducibility.
99
98
  * @returns A tuple `[X, y]` where X has shape `[nSamples, nFeatures]` and y has shape `[nSamples]` with dtype `int32`.
100
99
  *
101
- * @see {@link https://scikit-learn.org/stable/modules/generated/sklearn.datasets.make_classification.html | sklearn.datasets.make_classification}
100
+ * @see {@link https://deepbox.dev/docs/datasets-synthetic | Deepbox Synthetic Datasets}
102
101
  */
103
102
  declare function makeClassification(options?: {
104
103
  nSamples?: number;
@@ -106,6 +105,7 @@ declare function makeClassification(options?: {
106
105
  nInformative?: number;
107
106
  nRedundant?: number;
108
107
  nClasses?: number;
108
+ flipY?: number;
109
109
  randomState?: number;
110
110
  }): [Tensor, Tensor];
111
111
  /**
@@ -121,7 +121,7 @@ declare function makeClassification(options?: {
121
121
  * @param options.randomState - Seed for reproducibility.
122
122
  * @returns A tuple `[X, y]` where X has shape `[nSamples, nFeatures]` and y has shape `[nSamples]`.
123
123
  *
124
- * @see {@link https://scikit-learn.org/stable/modules/generated/sklearn.datasets.make_regression.html | sklearn.datasets.make_regression}
124
+ * @see {@link https://deepbox.dev/docs/datasets-synthetic | Deepbox Synthetic Datasets}
125
125
  */
126
126
  declare function makeRegression(options?: {
127
127
  nSamples?: number;
@@ -144,7 +144,7 @@ declare function makeRegression(options?: {
144
144
  * @param options.randomState - Seed for reproducibility.
145
145
  * @returns A tuple `[X, y]` where X has shape `[nSamples, nFeatures]` and y has shape `[nSamples]` with dtype `int32`.
146
146
  *
147
- * @see {@link https://scikit-learn.org/stable/modules/generated/sklearn.datasets.make_blobs.html | sklearn.datasets.make_blobs}
147
+ * @see {@link https://deepbox.dev/docs/datasets-synthetic | Deepbox Synthetic Datasets}
148
148
  */
149
149
  declare function makeBlobs(options?: {
150
150
  nSamples?: number;
@@ -166,7 +166,7 @@ declare function makeBlobs(options?: {
166
166
  * @param options.randomState - Seed for reproducibility.
167
167
  * @returns A tuple `[X, y]` where X has shape `[nSamples, 2]` and y has shape `[nSamples]` with dtype `int32`.
168
168
  *
169
- * @see {@link https://scikit-learn.org/stable/modules/generated/sklearn.datasets.make_moons.html | sklearn.datasets.make_moons}
169
+ * @see {@link https://deepbox.dev/docs/datasets-synthetic | Deepbox Synthetic Datasets}
170
170
  */
171
171
  declare function makeMoons(options?: {
172
172
  nSamples?: number;
@@ -187,7 +187,7 @@ declare function makeMoons(options?: {
187
187
  * @param options.randomState - Seed for reproducibility.
188
188
  * @returns A tuple `[X, y]` where X has shape `[nSamples, 2]` and y has shape `[nSamples]` with dtype `int32`.
189
189
  *
190
- * @see {@link https://scikit-learn.org/stable/modules/generated/sklearn.datasets.make_circles.html | sklearn.datasets.make_circles}
190
+ * @see {@link https://deepbox.dev/docs/datasets-synthetic | Deepbox Synthetic Datasets}
191
191
  */
192
192
  declare function makeCircles(options?: {
193
193
  nSamples?: number;
@@ -209,7 +209,7 @@ declare function makeCircles(options?: {
209
209
  * @param options.randomState - Seed for reproducibility.
210
210
  * @returns A tuple `[X, y]` where X has shape `[nSamples, nFeatures]` and y has shape `[nSamples]` with dtype `int32`.
211
211
  *
212
- * @see {@link https://scikit-learn.org/stable/modules/generated/sklearn.datasets.make_gaussian_quantiles.html | sklearn.datasets.make_gaussian_quantiles}
212
+ * @see {@link https://deepbox.dev/docs/datasets-synthetic | Deepbox Synthetic Datasets}
213
213
  */
214
214
  declare function makeGaussianQuantiles(options?: {
215
215
  nSamples?: number;
@@ -224,6 +224,7 @@ type Dataset = {
224
224
  featureNames: string[];
225
225
  targetNames?: string[];
226
226
  description: string;
227
+ images?: Tensor;
227
228
  };
228
229
  /**
229
230
  * Load the synthetic Iris dataset.
@@ -1,7 +1,7 @@
1
- import { A as AnyTensor } from './index-tk4lSYod.js';
2
- import { D as DType, a as Device, A as Axis } from './tensor-B96jjJLQ.js';
3
- import { G as GradTensor } from './index-DIp_RrRt.js';
4
- import { T as Tensor } from './Tensor-g8mUClel.js';
1
+ import { A as AnyTensor } from './index-DLdiQzf0.cjs';
2
+ import { D as DType, a as Device, A as Axis } from './tensor-IlVTF0bz.cjs';
3
+ import { G as GradTensor } from './index-GUHYEhxs.cjs';
4
+ import { T as Tensor } from './Tensor-fxBg-TFZ.cjs';
5
5
 
6
6
  type StateEntry = {
7
7
  data: Array<number | string | bigint>;
@@ -31,7 +31,7 @@ type ForwardHook = (module: Module, inputs: AnyTensor[], output: AnyTensor) => A
31
31
  * All models should subclass this class. Modules can contain other modules,
32
32
  * allowing them to be nested in a tree structure.
33
33
  *
34
- * This is analogous to PyTorch's nn.Module.
34
+ * See {@link https://deepbox.dev/docs/nn-module | Deepbox Module & Sequential}.
35
35
  *
36
36
  * @example
37
37
  * ```ts
@@ -63,7 +63,7 @@ type ForwardHook = (module: Module, inputs: AnyTensor[], output: AnyTensor) => A
63
63
  * ```
64
64
  *
65
65
  * References:
66
- * - PyTorch nn.Module: https://pytorch.org/docs/stable/generated/torch.nn.Module.html
66
+ * - Deepbox Module: https://deepbox.dev/docs/nn-module
67
67
  *
68
68
  * @category Neural Networks
69
69
  */
@@ -362,7 +362,7 @@ declare abstract class Module {
362
362
  * ```
363
363
  *
364
364
  * References:
365
- * - PyTorch Sequential: https://pytorch.org/docs/stable/generated/torch.nn.Sequential.html
365
+ * - Deepbox Sequential: https://deepbox.dev/docs/nn-module
366
366
  * - Keras Sequential: https://keras.io/guides/sequential_model/
367
367
  *
368
368
  * @category Neural Network Containers
@@ -580,7 +580,7 @@ declare class Mish extends Module {
580
580
  * const output = mha.forward(x, x, x);
581
581
  * ```
582
582
  *
583
- * @see {@link https://pytorch.org/docs/stable/generated/torch.nn.MultiheadAttention.html | PyTorch MultiheadAttention}
583
+ * @see {@link https://deepbox.dev/docs/nn-attention | Deepbox Attention}
584
584
  * @see Vaswani et al. (2017) "Attention Is All You Need"
585
585
  */
586
586
  declare class MultiheadAttention extends Module {
@@ -651,7 +651,7 @@ declare class MultiheadAttention extends Module {
651
651
  * const output = layer.forward(x);
652
652
  * ```
653
653
  *
654
- * @see {@link https://pytorch.org/docs/stable/generated/torch.nn.TransformerEncoderLayer.html | PyTorch TransformerEncoderLayer}
654
+ * @see {@link https://deepbox.dev/docs/nn-attention | Deepbox Attention}
655
655
  */
656
656
  declare class TransformerEncoderLayer extends Module {
657
657
  private readonly dModel;
@@ -666,7 +666,19 @@ declare class TransformerEncoderLayer extends Module {
666
666
  private readonly dropout1;
667
667
  private readonly dropout2;
668
668
  private readonly dropout3;
669
- constructor(dModel: number, nHead: number, dFF: number, options?: {
669
+ constructor(dModelOrOpts: number | {
670
+ readonly dModel: number;
671
+ readonly nHead: number;
672
+ readonly dimFeedforward?: number;
673
+ readonly dFF?: number;
674
+ readonly dropout?: number;
675
+ readonly eps?: number;
676
+ }, nHead?: number, dFFOrOptions?: number | {
677
+ readonly dimFeedforward?: number;
678
+ readonly dFF?: number;
679
+ readonly dropout?: number;
680
+ readonly eps?: number;
681
+ }, options?: {
670
682
  readonly dropout?: number;
671
683
  readonly eps?: number;
672
684
  });
@@ -692,7 +704,7 @@ declare class TransformerEncoderLayer extends Module {
692
704
  * const conv = new Conv1d(16, 33, 3); // in_channels=16, out_channels=33, kernel_size=3
693
705
  * ```
694
706
  *
695
- * @see {@link https://pytorch.org/docs/stable/generated/torch.nn.Conv1d.html | PyTorch Conv1d}
707
+ * @see {@link https://deepbox.dev/docs/nn-layers | Deepbox Layers}
696
708
  */
697
709
  declare class Conv1d extends Module {
698
710
  private readonly inChannels;
@@ -724,7 +736,7 @@ declare class Conv1d extends Module {
724
736
  * const conv = new Conv2d(3, 64, 3); // RGB to 64 channels, 3x3 kernel
725
737
  * ```
726
738
  *
727
- * @see {@link https://pytorch.org/docs/stable/generated/torch.nn.Conv2d.html | PyTorch Conv2d}
739
+ * @see {@link https://deepbox.dev/docs/nn-layers | Deepbox Layers}
728
740
  */
729
741
  declare class Conv2d extends Module {
730
742
  private readonly inChannels;
@@ -756,7 +768,7 @@ declare class Conv2d extends Module {
756
768
  * const pool = new MaxPool2d(2); // 2x2 pooling
757
769
  * ```
758
770
  *
759
- * @see {@link https://pytorch.org/docs/stable/generated/torch.nn.MaxPool2d.html | PyTorch MaxPool2d}
771
+ * @see {@link https://deepbox.dev/docs/nn-layers | Deepbox Layers}
760
772
  */
761
773
  declare class MaxPool2d extends Module {
762
774
  private readonly kernelSizeValue;
@@ -780,7 +792,7 @@ declare class MaxPool2d extends Module {
780
792
  * const pool = new AvgPool2d(2); // 2x2 pooling
781
793
  * ```
782
794
  *
783
- * @see {@link https://pytorch.org/docs/stable/generated/torch.nn.AvgPool2d.html | PyTorch AvgPool2d}
795
+ * @see {@link https://deepbox.dev/docs/nn-layers | Deepbox Layers}
784
796
  */
785
797
  declare class AvgPool2d extends Module {
786
798
  private readonly kernelSizeValue;
@@ -836,7 +848,7 @@ declare class AvgPool2d extends Module {
836
848
  *
837
849
  * References:
838
850
  * - Dropout paper: https://jmlr.org/papers/v15/srivastava14a.html
839
- * - PyTorch Dropout: https://pytorch.org/docs/stable/generated/torch.nn.Dropout.html
851
+ * - Deepbox Dropout: https://deepbox.dev/docs/nn-normalization
840
852
  *
841
853
  * @category Neural Network Layers
842
854
  */
@@ -923,7 +935,7 @@ declare class Dropout extends Module {
923
935
  * ```
924
936
  *
925
937
  * References:
926
- * - PyTorch Linear: https://pytorch.org/docs/stable/generated/torch.nn.Linear.html
938
+ * - Deepbox Linear: https://deepbox.dev/docs/nn-layers
927
939
  * - Xavier/Glorot initialization: http://proceedings.mlr.press/v9/glorot10a.html
928
940
  *
929
941
  * @category Neural Network Layers
@@ -1014,7 +1026,7 @@ declare class Linear extends Module {
1014
1026
  * const y = bn.forward(x);
1015
1027
  * ```
1016
1028
  *
1017
- * @see {@link https://pytorch.org/docs/stable/generated/torch.nn.BatchNorm1d.html | PyTorch BatchNorm1d}
1029
+ * @see {@link https://deepbox.dev/docs/nn-normalization | Deepbox Normalization & Dropout}
1018
1030
  */
1019
1031
  declare class BatchNorm1d extends Module {
1020
1032
  private readonly numFeatures;
@@ -1054,7 +1066,7 @@ declare class BatchNorm1d extends Module {
1054
1066
  * const y = ln.forward(x);
1055
1067
  * ```
1056
1068
  *
1057
- * @see {@link https://pytorch.org/docs/stable/generated/torch.nn.LayerNorm.html | PyTorch LayerNorm}
1069
+ * @see {@link https://deepbox.dev/docs/nn-normalization | Deepbox Normalization & Dropout}
1058
1070
  */
1059
1071
  declare class LayerNorm extends Module {
1060
1072
  private readonly normalizedShape;
@@ -1087,7 +1099,7 @@ declare class LayerNorm extends Module {
1087
1099
  * const output = rnn.forward(x);
1088
1100
  * ```
1089
1101
  *
1090
- * @see {@link https://pytorch.org/docs/stable/generated/torch.nn.RNN.html | PyTorch RNN}
1102
+ * @see {@link https://deepbox.dev/docs/nn-recurrent | Deepbox Recurrent Layers}
1091
1103
  */
1092
1104
  declare class RNN extends Module {
1093
1105
  private readonly inputSize;
@@ -1129,7 +1141,7 @@ declare class RNN extends Module {
1129
1141
  * - Cell state: c_t = f_t * c_{t-1} + i_t * g_t
1130
1142
  * - Hidden state: h_t = o_t * tanh(c_t)
1131
1143
  *
1132
- * @see {@link https://pytorch.org/docs/stable/generated/torch.nn.LSTM.html | PyTorch LSTM}
1144
+ * @see {@link https://deepbox.dev/docs/nn-recurrent | Deepbox Recurrent Layers}
1133
1145
  */
1134
1146
  declare class LSTM extends Module {
1135
1147
  private readonly inputSize;
@@ -1167,7 +1179,7 @@ declare class LSTM extends Module {
1167
1179
  * - New gate: n_t = tanh(W_in * x_t + b_in + r_t * (W_hn * h_{t-1} + b_hn))
1168
1180
  * - Hidden: h_t = (1 - z_t) * n_t + z_t * h_{t-1}
1169
1181
  *
1170
- * @see {@link https://pytorch.org/docs/stable/generated/torch.nn.GRU.html | PyTorch GRU}
1182
+ * @see {@link https://deepbox.dev/docs/nn-recurrent | Deepbox Recurrent Layers}
1171
1183
  */
1172
1184
  declare class GRU extends Module {
1173
1185
  private readonly inputSize;
@@ -1223,6 +1235,7 @@ declare class GRU extends Module {
1223
1235
  */
1224
1236
  declare function crossEntropyLoss(input: Tensor, target: Tensor): number;
1225
1237
  declare function crossEntropyLoss(input: GradTensor, target: AnyTensor): GradTensor;
1238
+ declare function crossEntropyLoss(input: AnyTensor, target: AnyTensor): number | GradTensor;
1226
1239
  /**
1227
1240
  * Binary Cross Entropy Loss with logits.
1228
1241
  *
@@ -1271,6 +1284,7 @@ declare function binaryCrossEntropyWithLogitsLoss(input: GradTensor, target: Any
1271
1284
  * @category Loss Functions
1272
1285
  */
1273
1286
  declare function mseLoss(predictions: Tensor, targets: Tensor, reduction?: "mean" | "sum" | "none"): Tensor;
1287
+ declare function mseLoss(predictions: GradTensor, targets: GradTensor, reduction?: "mean" | "sum" | "none"): GradTensor;
1274
1288
  /**
1275
1289
  * Mean Absolute Error (MAE) loss function, also known as L1 loss.
1276
1290
  *