r-scikit-learn 0.1.0__tar.gz → 0.1.1__tar.gz

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (84)
  1. r_scikit_learn-0.1.1/CHANGELOG.md +23 -0
  2. {r_scikit_learn-0.1.0 → r_scikit_learn-0.1.1}/Cargo.lock +1 -1
  3. {r_scikit_learn-0.1.0 → r_scikit_learn-0.1.1}/Cargo.toml +2 -1
  4. {r_scikit_learn-0.1.0 → r_scikit_learn-0.1.1}/PKG-INFO +20 -10
  5. {r_scikit_learn-0.1.0 → r_scikit_learn-0.1.1}/README.md +17 -9
  6. {r_scikit_learn-0.1.0 → r_scikit_learn-0.1.1}/benches/benchmark_linear_models.py +6 -0
  7. {r_scikit_learn-0.1.0 → r_scikit_learn-0.1.1}/pyproject.toml +3 -1
  8. {r_scikit_learn-0.1.0 → r_scikit_learn-0.1.1}/python/rsklearn/__init__.py +1 -1
  9. {r_scikit_learn-0.1.0 → r_scikit_learn-0.1.1}/python/rsklearn/linear_model/_least_squares.py +23 -1
  10. r_scikit_learn-0.1.1/tests/release_smoke.py +28 -0
  11. {r_scikit_learn-0.1.0 → r_scikit_learn-0.1.1}/tests/test_linear_model_parity.py +82 -0
  12. {r_scikit_learn-0.1.0 → r_scikit_learn-0.1.1}/LICENSE +0 -0
  13. {r_scikit_learn-0.1.0 → r_scikit_learn-0.1.1}/benches/benchmark_metrics.py +0 -0
  14. {r_scikit_learn-0.1.0 → r_scikit_learn-0.1.1}/benches/benchmark_preprocessing.py +0 -0
  15. {r_scikit_learn-0.1.0 → r_scikit_learn-0.1.1}/python/rsklearn/_validation.py +0 -0
  16. {r_scikit_learn-0.1.0 → r_scikit_learn-0.1.1}/python/rsklearn/base.py +0 -0
  17. {r_scikit_learn-0.1.0 → r_scikit_learn-0.1.1}/python/rsklearn/compose/__init__.py +0 -0
  18. {r_scikit_learn-0.1.0 → r_scikit_learn-0.1.1}/python/rsklearn/compose/_column_transformer.py +0 -0
  19. {r_scikit_learn-0.1.0 → r_scikit_learn-0.1.1}/python/rsklearn/impute/__init__.py +0 -0
  20. {r_scikit_learn-0.1.0 → r_scikit_learn-0.1.1}/python/rsklearn/impute/_simple_imputer.py +0 -0
  21. {r_scikit_learn-0.1.0 → r_scikit_learn-0.1.1}/python/rsklearn/linear_model/__init__.py +0 -0
  22. {r_scikit_learn-0.1.0 → r_scikit_learn-0.1.1}/python/rsklearn/linear_model/_base.py +0 -0
  23. {r_scikit_learn-0.1.0 → r_scikit_learn-0.1.1}/python/rsklearn/linear_model/_coordinate_descent.py +0 -0
  24. {r_scikit_learn-0.1.0 → r_scikit_learn-0.1.1}/python/rsklearn/linear_model/_logistic.py +0 -0
  25. {r_scikit_learn-0.1.0 → r_scikit_learn-0.1.1}/python/rsklearn/linear_model/_warnings.py +0 -0
  26. {r_scikit_learn-0.1.0 → r_scikit_learn-0.1.1}/python/rsklearn/metrics/__init__.py +0 -0
  27. {r_scikit_learn-0.1.0 → r_scikit_learn-0.1.1}/python/rsklearn/metrics/_classification.py +0 -0
  28. {r_scikit_learn-0.1.0 → r_scikit_learn-0.1.1}/python/rsklearn/metrics/_regression.py +0 -0
  29. {r_scikit_learn-0.1.0 → r_scikit_learn-0.1.1}/python/rsklearn/metrics/_validation.py +0 -0
  30. {r_scikit_learn-0.1.0 → r_scikit_learn-0.1.1}/python/rsklearn/model_selection/__init__.py +0 -0
  31. {r_scikit_learn-0.1.0 → r_scikit_learn-0.1.1}/python/rsklearn/model_selection/_split.py +0 -0
  32. {r_scikit_learn-0.1.0 → r_scikit_learn-0.1.1}/python/rsklearn/model_selection/_utils.py +0 -0
  33. {r_scikit_learn-0.1.0 → r_scikit_learn-0.1.1}/python/rsklearn/model_selection/_validation.py +0 -0
  34. {r_scikit_learn-0.1.0 → r_scikit_learn-0.1.1}/python/rsklearn/pipeline.py +0 -0
  35. {r_scikit_learn-0.1.0 → r_scikit_learn-0.1.1}/python/rsklearn/preprocessing/__init__.py +0 -0
  36. {r_scikit_learn-0.1.0 → r_scikit_learn-0.1.1}/python/rsklearn/preprocessing/_base.py +0 -0
  37. {r_scikit_learn-0.1.0 → r_scikit_learn-0.1.1}/python/rsklearn/preprocessing/_categorical.py +0 -0
  38. {r_scikit_learn-0.1.0 → r_scikit_learn-0.1.1}/python/rsklearn/preprocessing/_label_encoder.py +0 -0
  39. {r_scikit_learn-0.1.0 → r_scikit_learn-0.1.1}/python/rsklearn/preprocessing/_minmax_scaler.py +0 -0
  40. {r_scikit_learn-0.1.0 → r_scikit_learn-0.1.1}/python/rsklearn/preprocessing/_normalizer.py +0 -0
  41. {r_scikit_learn-0.1.0 → r_scikit_learn-0.1.1}/python/rsklearn/preprocessing/_one_hot_encoder.py +0 -0
  42. {r_scikit_learn-0.1.0 → r_scikit_learn-0.1.1}/python/rsklearn/preprocessing/_ordinal_encoder.py +0 -0
  43. {r_scikit_learn-0.1.0 → r_scikit_learn-0.1.1}/python/rsklearn/preprocessing/_robust_scaler.py +0 -0
  44. {r_scikit_learn-0.1.0 → r_scikit_learn-0.1.1}/python/rsklearn/preprocessing/_standard_scaler.py +0 -0
  45. {r_scikit_learn-0.1.0 → r_scikit_learn-0.1.1}/python/rsklearn/py.typed +0 -0
  46. {r_scikit_learn-0.1.0 → r_scikit_learn-0.1.1}/python/rsklearn/utils/__init__.py +0 -0
  47. {r_scikit_learn-0.1.0 → r_scikit_learn-0.1.1}/python/rsklearn/utils/sparse.py +0 -0
  48. {r_scikit_learn-0.1.0 → r_scikit_learn-0.1.1}/python/rsklearn/utils/validation.py +0 -0
  49. {r_scikit_learn-0.1.0 → r_scikit_learn-0.1.1}/src/categorical.rs +0 -0
  50. {r_scikit_learn-0.1.0 → r_scikit_learn-0.1.1}/src/error.rs +0 -0
  51. {r_scikit_learn-0.1.0 → r_scikit_learn-0.1.1}/src/label_encoder.rs +0 -0
  52. {r_scikit_learn-0.1.0 → r_scikit_learn-0.1.1}/src/lib.rs +0 -0
  53. {r_scikit_learn-0.1.0 → r_scikit_learn-0.1.1}/src/linear_model.rs +0 -0
  54. {r_scikit_learn-0.1.0 → r_scikit_learn-0.1.1}/src/metrics.rs +0 -0
  55. {r_scikit_learn-0.1.0 → r_scikit_learn-0.1.1}/src/minmax_scaler.rs +0 -0
  56. {r_scikit_learn-0.1.0 → r_scikit_learn-0.1.1}/src/normalizer.rs +0 -0
  57. {r_scikit_learn-0.1.0 → r_scikit_learn-0.1.1}/src/robust_scaler.rs +0 -0
  58. {r_scikit_learn-0.1.0 → r_scikit_learn-0.1.1}/src/simple_imputer.rs +0 -0
  59. {r_scikit_learn-0.1.0 → r_scikit_learn-0.1.1}/src/sparse.rs +0 -0
  60. {r_scikit_learn-0.1.0 → r_scikit_learn-0.1.1}/src/standard_scaler.rs +0 -0
  61. {r_scikit_learn-0.1.0 → r_scikit_learn-0.1.1}/tests/test_base.py +0 -0
  62. {r_scikit_learn-0.1.0 → r_scikit_learn-0.1.1}/tests/test_categorical_infrastructure.py +0 -0
  63. {r_scikit_learn-0.1.0 → r_scikit_learn-0.1.1}/tests/test_column_transformer.py +0 -0
  64. {r_scikit_learn-0.1.0 → r_scikit_learn-0.1.1}/tests/test_column_transformer_parity.py +0 -0
  65. {r_scikit_learn-0.1.0 → r_scikit_learn-0.1.1}/tests/test_estimator_compliance.py +0 -0
  66. {r_scikit_learn-0.1.0 → r_scikit_learn-0.1.1}/tests/test_label_encoder.py +0 -0
  67. {r_scikit_learn-0.1.0 → r_scikit_learn-0.1.1}/tests/test_linear_model.py +0 -0
  68. {r_scikit_learn-0.1.0 → r_scikit_learn-0.1.1}/tests/test_metrics.py +0 -0
  69. {r_scikit_learn-0.1.0 → r_scikit_learn-0.1.1}/tests/test_metrics_parity.py +0 -0
  70. {r_scikit_learn-0.1.0 → r_scikit_learn-0.1.1}/tests/test_minmax_scaler.py +0 -0
  71. {r_scikit_learn-0.1.0 → r_scikit_learn-0.1.1}/tests/test_model_selection.py +0 -0
  72. {r_scikit_learn-0.1.0 → r_scikit_learn-0.1.1}/tests/test_model_selection_parity.py +0 -0
  73. {r_scikit_learn-0.1.0 → r_scikit_learn-0.1.1}/tests/test_normalizer.py +0 -0
  74. {r_scikit_learn-0.1.0 → r_scikit_learn-0.1.1}/tests/test_one_hot_encoder.py +0 -0
  75. {r_scikit_learn-0.1.0 → r_scikit_learn-0.1.1}/tests/test_ordinal_encoder.py +0 -0
  76. {r_scikit_learn-0.1.0 → r_scikit_learn-0.1.1}/tests/test_pipeline.py +0 -0
  77. {r_scikit_learn-0.1.0 → r_scikit_learn-0.1.1}/tests/test_pipeline_parity.py +0 -0
  78. {r_scikit_learn-0.1.0 → r_scikit_learn-0.1.1}/tests/test_public_validation.py +0 -0
  79. {r_scikit_learn-0.1.0 → r_scikit_learn-0.1.1}/tests/test_robust_scaler.py +0 -0
  80. {r_scikit_learn-0.1.0 → r_scikit_learn-0.1.1}/tests/test_scikit_learn_parity.py +0 -0
  81. {r_scikit_learn-0.1.0 → r_scikit_learn-0.1.1}/tests/test_simple_imputer.py +0 -0
  82. {r_scikit_learn-0.1.0 → r_scikit_learn-0.1.1}/tests/test_sparse_infrastructure.py +0 -0
  83. {r_scikit_learn-0.1.0 → r_scikit_learn-0.1.1}/tests/test_standard_scaler.py +0 -0
  84. {r_scikit_learn-0.1.0 → r_scikit_learn-0.1.1}/tests/test_validation.py +0 -0
@@ -0,0 +1,23 @@
1
+ # Changelog
2
+
3
+ All notable changes to r-scikit-learn are documented here. Release tags and
4
+ published package versions are immutable.
5
+
6
+ ## Unreleased
7
+
8
+ ## 0.1.1 - 2026-06-15
9
+
10
+ - Added wheel and source-distribution installation testing across supported
11
+ operating systems and Python versions.
12
+ - Added a numerical-safety fallback for ill-conditioned tall least-squares
13
+ problems.
14
+ - Added TestPyPI, cross-platform benchmark, and immutable manual release
15
+ workflows.
16
+
17
+ ## 0.1.0
18
+
19
+ - Added Rust-powered preprocessing, categorical encoding, sparse
20
+ infrastructure, composition, metrics, model selection, and linear models.
21
+ - Added Linux, macOS, and Windows wheel builds for Python 3.10 through 3.13.
22
+ - Added Rust-native tall-matrix least squares and multinomial logistic
23
+ optimization.
@@ -998,7 +998,7 @@ dependencies = [
998
998
 
999
999
  [[package]]
1000
1000
  name = "r-scikit-learn-core"
1001
- version = "0.1.0"
1001
+ version = "0.1.1"
1002
1002
  dependencies = [
1003
1003
  "faer",
1004
1004
  "nalgebra",
@@ -1,6 +1,6 @@
1
1
  [package]
2
2
  name = "r-scikit-learn-core"
3
- version = "0.1.0"
3
+ version = "0.1.1"
4
4
  edition = "2021"
5
5
  license = "MIT"
6
6
  description = "Rust computational core for r-scikit-learn"
@@ -9,6 +9,7 @@ repository = "https://github.com/rishib42/r-scikit-learn"
9
9
  include = [
10
10
  "/Cargo.lock",
11
11
  "/Cargo.toml",
12
+ "/CHANGELOG.md",
12
13
  "/LICENSE",
13
14
  "/README.md",
14
15
  "/benches/*.py",
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.4
2
2
  Name: r-scikit-learn
3
- Version: 0.1.0
3
+ Version: 0.1.1
4
4
  Classifier: Development Status :: 3 - Alpha
5
5
  Classifier: License :: OSI Approved :: MIT License
6
6
  Classifier: Programming Language :: Python :: 3
@@ -12,6 +12,7 @@ Classifier: Programming Language :: Rust
12
12
  Classifier: Typing :: Typed
13
13
  Requires-Dist: numpy>=1.23
14
14
  Requires-Dist: scipy>=1.10
15
+ Requires-Dist: hypothesis>=6.100,<7 ; extra == 'dev'
15
16
  Requires-Dist: maturin>=1.9,<2.0 ; extra == 'dev'
16
17
  Requires-Dist: pytest>=8 ; extra == 'dev'
17
18
  Requires-Dist: ruff>=0.11 ; extra == 'dev'
@@ -25,6 +26,7 @@ Author: r-scikit-learn contributors
25
26
  License-Expression: MIT
26
27
  Requires-Python: >=3.10
27
28
  Description-Content-Type: text/markdown; charset=UTF-8; variant=GFM
29
+ Project-URL: Changelog, https://github.com/rishib42/r-scikit-learn/blob/main/CHANGELOG.md
28
30
  Project-URL: Homepage, https://github.com/rishib42/r-scikit-learn
29
31
  Project-URL: Issues, https://github.com/rishib42/r-scikit-learn/issues
30
32
  Project-URL: Repository, https://github.com/rishib42/r-scikit-learn
@@ -34,7 +36,7 @@ Project-URL: Repository, https://github.com/rishib42/r-scikit-learn
34
36
  Fast, familiar machine-learning building blocks powered by safe Rust. 🦀
35
37
 
36
38
  `r-scikit-learn` combines a Rust computational core with lightweight,
37
- scikit-learn-style Python estimators. Version 0.1.0 includes:
39
+ scikit-learn-style Python estimators. Version 0.1.1 includes:
38
40
 
39
41
  - Preprocessing, categorical encoding, and missing-value imputation
40
42
  - Pipelines and column transformers
@@ -327,14 +329,22 @@ Substantial numerical loops release the Python GIL.
327
329
 
328
330
  ## Release
329
331
 
330
- 1. Run all development checks and build a release wheel.
331
- 2. Install the wheel into a clean virtual environment and run the import smoke
332
- test.
333
- 3. Verify the distribution name on PyPI.
334
- 4. Tag the release as `v0.1.0` and push the tag.
335
- 5. Approve the GitHub Actions Trusted Publishing environment.
336
-
337
- The release workflow uses PyPI Trusted Publishing and contains no API token.
332
+ 1. Update the matching versions in `pyproject.toml`, `Cargo.toml`, and
333
+ `python/rsklearn/__init__.py`, then update `CHANGELOG.md`.
334
+ 2. Push the release commit and wait for CI, including manylinux and sdist
335
+ installation checks, to pass.
336
+ 3. Run the manual TestPyPI workflow and verify its distributions.
337
+ 4. Run the manual Release workflow with the version number without a `v`
338
+ prefix.
339
+ 5. Approve the PyPI environment if required.
340
+
341
+ The release workflow refuses existing versions, installs every wheel on
342
+ Python 3.10-3.13 across Linux, macOS, and Windows, verifies sdist installation,
343
+ publishes through PyPI Trusted Publishing, creates the immutable GitHub tag and
344
+ release, attaches artifacts, and verifies installation from PyPI. No API token
345
+ is stored in the repository. Configure separate `pypi` and `testpypi` GitHub
346
+ environments and matching Trusted Publishers for `release.yml` and
347
+ `test-pypi.yml`, respectively.
338
348
 
339
349
  ## Roadmap
340
350
 
@@ -3,7 +3,7 @@
3
3
  Fast, familiar machine-learning building blocks powered by safe Rust. 🦀
4
4
 
5
5
  `r-scikit-learn` combines a Rust computational core with lightweight,
6
- scikit-learn-style Python estimators. Version 0.1.0 includes:
6
+ scikit-learn-style Python estimators. Version 0.1.1 includes:
7
7
 
8
8
  - Preprocessing, categorical encoding, and missing-value imputation
9
9
  - Pipelines and column transformers
@@ -296,14 +296,22 @@ Substantial numerical loops release the Python GIL.
296
296
 
297
297
  ## Release
298
298
 
299
- 1. Run all development checks and build a release wheel.
300
- 2. Install the wheel into a clean virtual environment and run the import smoke
301
- test.
302
- 3. Verify the distribution name on PyPI.
303
- 4. Tag the release as `v0.1.0` and push the tag.
304
- 5. Approve the GitHub Actions Trusted Publishing environment.
305
-
306
- The release workflow uses PyPI Trusted Publishing and contains no API token.
299
+ 1. Update the matching versions in `pyproject.toml`, `Cargo.toml`, and
300
+ `python/rsklearn/__init__.py`, then update `CHANGELOG.md`.
301
+ 2. Push the release commit and wait for CI, including manylinux and sdist
302
+ installation checks, to pass.
303
+ 3. Run the manual TestPyPI workflow and verify its distributions.
304
+ 4. Run the manual Release workflow with the version number without a `v`
305
+ prefix.
306
+ 5. Approve the PyPI environment if required.
307
+
308
+ The release workflow refuses existing versions, installs every wheel on
309
+ Python 3.10-3.13 across Linux, macOS, and Windows, verifies sdist installation,
310
+ publishes through PyPI Trusted Publishing, creates the immutable GitHub tag and
311
+ release, attaches artifacts, and verifies installation from PyPI. No API token
312
+ is stored in the repository. Configure separate `pypi` and `testpypi` GitHub
313
+ environments and matching Trusted Publishers for `release.yml` and
314
+ `test-pypi.yml`, respectively.
307
315
 
308
316
  ## Roadmap
309
317
 
@@ -10,6 +10,8 @@ from collections.abc import Callable
10
10
 
11
11
  import numpy as np
12
12
  import rsklearn.linear_model as rlinear
13
+ import scipy
14
+ import sklearn
13
15
  import sklearn.linear_model as slinear
14
16
  from rsklearn import _core
15
17
 
@@ -65,6 +67,10 @@ def main() -> None:
65
67
  )
66
68
  print(f"Python: {sys.executable}")
67
69
  print(f"Rust extension: {_core.__file__} ({profile})")
70
+ print(
71
+ f"Dependencies: numpy {np.__version__}, scipy {scipy.__version__}, "
72
+ f"scikit-learn {sklearn.__version__}"
73
+ )
68
74
  rng = np.random.default_rng(20260614)
69
75
  X = rng.normal(size=(args.samples, args.features))
70
76
  coefficients = rng.normal(size=args.features)
@@ -4,7 +4,7 @@ build-backend = "maturin"
4
4
 
5
5
  [project]
6
6
  name = "r-scikit-learn"
7
- version = "0.1.0"
7
+ version = "0.1.1"
8
8
  description = "High-performance scikit-learn-style machine learning powered by safe Rust"
9
9
  readme = "README.md"
10
10
  requires-python = ">=3.10"
@@ -26,6 +26,7 @@ dependencies = ["numpy>=1.23", "scipy>=1.10"]
26
26
 
27
27
  [project.optional-dependencies]
28
28
  dev = [
29
+ "hypothesis>=6.100,<7",
29
30
  "maturin>=1.9,<2.0",
30
31
  "pytest>=8",
31
32
  "ruff>=0.11",
@@ -36,6 +37,7 @@ dev = [
36
37
  Homepage = "https://github.com/rishib42/r-scikit-learn"
37
38
  Repository = "https://github.com/rishib42/r-scikit-learn"
38
39
  Issues = "https://github.com/rishib42/r-scikit-learn/issues"
40
+ Changelog = "https://github.com/rishib42/r-scikit-learn/blob/main/CHANGELOG.md"
39
41
 
40
42
  [tool.maturin]
41
43
  python-source = "python"
@@ -45,4 +45,4 @@ __all__ = [
45
45
  "make_column_transformer",
46
46
  "make_pipeline",
47
47
  ]
48
- __version__ = "0.1.0"
48
+ __version__ = "0.1.1"
@@ -12,6 +12,26 @@ from rsklearn.base import BaseEstimator, RegressorMixin
12
12
 
13
13
  from ._base import LinearModel, validate_regression_fit
14
14
 
15
+ # Normal equations square the condition number. This cutoff limits the
16
+ # resulting float64 error amplification before selecting the fast Gram path.
17
+ _GRAM_MIN_SINGULAR_RATIO = np.finfo(np.float64).eps ** 0.25
18
+ _GRAM_RANK_RESOLUTION = np.sqrt(np.finfo(np.float64).eps)
19
+
20
+
21
+ def _tall_solution_is_stable(singular: np.ndarray, rank: int, tolerance: float) -> bool:
22
+ """Return whether normal-equation accuracy is reliable for this spectrum."""
23
+ if rank == 0 or singular.size == 0 or not np.isfinite(singular).all():
24
+ return False
25
+ if rank < singular.size and tolerance < _GRAM_RANK_RESOLUTION:
26
+ return False
27
+ largest = singular[0]
28
+ smallest_retained = singular[rank - 1]
29
+ return (
30
+ largest > 0
31
+ and smallest_retained > 0
32
+ and smallest_retained / largest >= _GRAM_MIN_SINGULAR_RATIO
33
+ )
34
+
15
35
 
16
36
  def _fit_lstsq(
17
37
  X: np.ndarray,
@@ -22,7 +42,9 @@ def _fit_lstsq(
22
42
  ) -> tuple[np.ndarray, np.ndarray, int, np.ndarray]:
23
43
  """Solve unregularized least squares through a shape-aware dense backend."""
24
44
  if X.shape[0] >= 4 * X.shape[1]:
25
- return _core.linear_fit_tall(X, y, weights, fit_intercept, tolerance)
45
+ tall_fit = _core.linear_fit_tall(X, y, weights, fit_intercept, tolerance)
46
+ if _tall_solution_is_stable(tall_fit[3], tall_fit[2], tolerance):
47
+ return tall_fit
26
48
  uniform_weights = np.all(weights == weights[0])
27
49
  if fit_intercept:
28
50
  if uniform_weights:
@@ -0,0 +1,28 @@
1
+ """Minimal installed-distribution smoke test used by release workflows."""
2
+
3
+ from __future__ import annotations
4
+
5
+ import numpy as np
6
+ import rsklearn
7
+ from rsklearn.linear_model import LinearRegression, LogisticRegression
8
+ from rsklearn.preprocessing import OneHotEncoder, StandardScaler
9
+
10
+
11
+ def main() -> None:
12
+ X = np.asarray([[1.0, 2.0], [2.0, 1.0], [3.0, 4.0], [4.0, 3.0]])
13
+ regression = LinearRegression().fit(X, [3.0, 3.0, 7.0, 7.0])
14
+ np.testing.assert_allclose(regression.predict(X), [3.0, 3.0, 7.0, 7.0])
15
+
16
+ classification = LogisticRegression(max_iter=500).fit(X, [0, 0, 1, 1])
17
+ np.testing.assert_array_equal(classification.predict(X), [0, 0, 1, 1])
18
+
19
+ scaled = StandardScaler().fit_transform(X)
20
+ np.testing.assert_allclose(scaled.mean(axis=0), 0.0, atol=1e-12)
21
+
22
+ encoded = OneHotEncoder().fit_transform([["a"], ["b"], ["a"]])
23
+ assert encoded.shape == (3, 2)
24
+ assert rsklearn.__version__
25
+
26
+
27
+ if __name__ == "__main__":
28
+ main()
@@ -1,12 +1,16 @@
1
1
  import numpy as np
2
2
  import pytest
3
+ from hypothesis import given, settings
4
+ from hypothesis import strategies as st
3
5
  from rsklearn.linear_model import (
4
6
  ElasticNet,
5
7
  Lasso,
6
8
  LinearRegression,
7
9
  LogisticRegression,
8
10
  Ridge,
11
+ _least_squares,
9
12
  )
13
+ from rsklearn.linear_model._least_squares import _tall_solution_is_stable
10
14
  from scipy import linalg
11
15
 
12
16
  sklearn_linear = pytest.importorskip("sklearn.linear_model")
@@ -60,6 +64,84 @@ def test_tall_linear_regression_matches_svd_near_rank_deficiency(perturbation):
60
64
  np.testing.assert_allclose(ours.predict(X), expected, rtol=1e-7, atol=5e-9)
61
65
 
62
66
 
67
+ @given(
68
+ rows=st.integers(min_value=80, max_value=240),
69
+ columns=st.integers(min_value=2, max_value=12),
70
+ log_condition=st.floats(
71
+ min_value=0.0, max_value=12.0, allow_nan=False, allow_infinity=False
72
+ ),
73
+ weighted=st.booleans(),
74
+ fit_intercept=st.booleans(),
75
+ )
76
+ @settings(max_examples=40, deadline=None)
77
+ def test_linear_regression_matches_svd_across_condition_numbers(
78
+ rows, columns, log_condition, weighted, fit_intercept
79
+ ):
80
+ rows = max(rows, 4 * columns)
81
+ rng = np.random.default_rng(
82
+ rows * 10_000 + columns * 100 + int(log_condition * 10) + int(weighted)
83
+ )
84
+ left, _ = np.linalg.qr(rng.normal(size=(rows, columns)))
85
+ right, _ = np.linalg.qr(rng.normal(size=(columns, columns)))
86
+ singular = np.geomspace(1.0, 10.0**-log_condition, columns)
87
+ X = np.ascontiguousarray((left * singular) @ right.T)
88
+ y = rng.normal(size=rows)
89
+ weights = rng.uniform(0.1, 2.0, size=rows) if weighted else None
90
+ ours = LinearRegression(tol=1e-10, fit_intercept=fit_intercept).fit(
91
+ X, y, sample_weight=weights
92
+ )
93
+ reference_weights = np.ones(rows) if weights is None else weights
94
+ if fit_intercept:
95
+ x_mean = np.average(X, axis=0, weights=reference_weights)
96
+ y_mean = np.average(y, weights=reference_weights)
97
+ else:
98
+ x_mean = np.zeros(columns)
99
+ y_mean = 0.0
100
+ root_weights = np.sqrt(reference_weights)
101
+ coefficients, _, _, _ = linalg.lstsq(
102
+ (X - x_mean) * root_weights[:, None],
103
+ (y - y_mean) * root_weights,
104
+ cond=1e-10,
105
+ check_finite=False,
106
+ lapack_driver="gelsd",
107
+ )
108
+ expected = X @ coefficients + y_mean - coefficients @ x_mean
109
+ np.testing.assert_allclose(ours.predict(X), expected, rtol=1e-7, atol=1e-9)
110
+
111
+
112
+ def test_tall_solution_stability_gate_rejects_unsafe_spectra():
113
+ assert _tall_solution_is_stable(np.asarray([10.0, 1.0]), 2, 1e-6)
114
+ assert not _tall_solution_is_stable(np.asarray([10.0, 1e-10]), 2, 1e-6)
115
+ assert _tall_solution_is_stable(np.asarray([10.0, 1.0, 0.0]), 2, 1e-6)
116
+ assert not _tall_solution_is_stable(np.asarray([10.0, 1.0, 0.0]), 2, 1e-10)
117
+ assert not _tall_solution_is_stable(np.asarray([0.0, 0.0]), 0, 1e-6)
118
+
119
+
120
+ @pytest.mark.parametrize(
121
+ "log_condition, expects_fallback", [(2.0, False), (10.0, True)]
122
+ )
123
+ def test_tall_linear_regression_falls_back_only_when_numerically_unsafe(
124
+ monkeypatch, log_condition, expects_fallback
125
+ ):
126
+ rng = np.random.default_rng(1234)
127
+ left, _ = np.linalg.qr(rng.normal(size=(1_000, 5)))
128
+ right, _ = np.linalg.qr(rng.normal(size=(5, 5)))
129
+ X = np.ascontiguousarray(
130
+ (left * np.geomspace(1.0, 10.0**-log_condition, 5)) @ right.T
131
+ )
132
+ original = _least_squares.linalg.lstsq
133
+ calls = 0
134
+
135
+ def tracked_lstsq(*args, **kwargs):
136
+ nonlocal calls
137
+ calls += 1
138
+ return original(*args, **kwargs)
139
+
140
+ monkeypatch.setattr(_least_squares.linalg, "lstsq", tracked_lstsq)
141
+ LinearRegression(tol=1e-10).fit(X, rng.normal(size=X.shape[0]))
142
+ assert bool(calls) is expects_fallback
143
+
144
+
63
145
  @pytest.mark.parametrize("alpha", [0.0, 0.1, 10.0])
64
146
  @pytest.mark.parametrize("fit_intercept", [True, False])
65
147
  def test_ridge_matches_scikit_learn_svd(alpha, fit_intercept):
File without changes