sgemm-bi 0.1.1__tar.gz

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (38)
  1. sgemm_bi-0.1.1/.gitignore +5 -0
  2. sgemm_bi-0.1.1/CHANGELOG.md +85 -0
  3. sgemm_bi-0.1.1/Cargo.lock +114 -0
  4. sgemm_bi-0.1.1/Cargo.toml +32 -0
  5. sgemm_bi-0.1.1/LICENSE-APACHE +19 -0
  6. sgemm_bi-0.1.1/LICENSE-MIT +21 -0
  7. sgemm_bi-0.1.1/PKG-INFO +84 -0
  8. sgemm_bi-0.1.1/README.md +147 -0
  9. sgemm_bi-0.1.1/deny.toml +21 -0
  10. sgemm_bi-0.1.1/docs/usage-guide.md +172 -0
  11. sgemm_bi-0.1.1/examples/capi/smoke.c +162 -0
  12. sgemm_bi-0.1.1/examples/deterministic_training.rs +99 -0
  13. sgemm_bi-0.1.1/include/sgemm_bi.h +126 -0
  14. sgemm_bi-0.1.1/kernels/casts.cu +36 -0
  15. sgemm_bi-0.1.1/kernels/prelude.cuh +24 -0
  16. sgemm_bi-0.1.1/kernels/sgemm_bi.cu +6354 -0
  17. sgemm_bi-0.1.1/pyproject.toml +29 -0
  18. sgemm_bi-0.1.1/python/.gitignore +5 -0
  19. sgemm_bi-0.1.1/python/Cargo.lock +172 -0
  20. sgemm_bi-0.1.1/python/Cargo.toml +33 -0
  21. sgemm_bi-0.1.1/python/README.md +67 -0
  22. sgemm_bi-0.1.1/python/examples/train_deterministic.py +55 -0
  23. sgemm_bi-0.1.1/python/sgemm_bi/__init__.py +25 -0
  24. sgemm_bi-0.1.1/python/sgemm_bi/_sgemm_bi.pyi +96 -0
  25. sgemm_bi-0.1.1/python/sgemm_bi/py.typed +0 -0
  26. sgemm_bi-0.1.1/python/sgemm_bi/torch.py +230 -0
  27. sgemm_bi-0.1.1/python/src/lib.rs +307 -0
  28. sgemm_bi-0.1.1/python/tests/test_torch.py +142 -0
  29. sgemm_bi-0.1.1/src/capi.rs +423 -0
  30. sgemm_bi-0.1.1/src/dispatch.rs +2284 -0
  31. sgemm_bi-0.1.1/src/dtype.rs +36 -0
  32. sgemm_bi-0.1.1/src/engine.rs +333 -0
  33. sgemm_bi-0.1.1/src/error.rs +57 -0
  34. sgemm_bi-0.1.1/src/kernels.rs +270 -0
  35. sgemm_bi-0.1.1/src/lib.rs +90 -0
  36. sgemm_bi-0.1.1/tests/common/mod.rs +115 -0
  37. sgemm_bi-0.1.1/tests/contracts.rs +203 -0
  38. sgemm_bi-0.1.1/tests/tensor_cores.rs +322 -0
@@ -0,0 +1,5 @@
1
+ /target
2
+ .idea/
3
+ .DS_Store
4
+ *.iml
5
+ /internal/
@@ -0,0 +1,85 @@
1
+ # Changelog
2
+
3
+ All notable changes to this project are documented here. The format is
4
+ based on [Keep a Changelog](https://keepachangelog.com/en/1.1.0/), and
5
+ this project adheres to [Semantic Versioning](https://semver.org/).
6
+
7
+ ## [0.1.1] - 2026-06-12
8
+
9
+ ### Added
10
+
11
+ - **Tensor-core small-tile family + BK=64 staging**: six
12
+ `sgemm_bi_*_tc64_*` kernels (64×64 tiles, 128 threads) extend the TC
13
+ tier to both output dims >= 64 and to GPU-underfilling grids; both
14
+ tile families stage 64-deep reduction slabs (half the barrier count)
15
+ with the 128-tile family on dynamic shared memory. The two families
16
+ are BIT-IDENTICAL per output element (same ascending mma chain), so
17
+ the shape-only routing never changes output bits and the strict
18
+ all-M forward invariance survives tile switching — asserted through
19
+ the public API by `tc_cross_tile_strict_all_m_invariance`. Measured
20
+ (RTX 6000 Ada, bf16): TC forward 3.5–6.3× over the scalar tier at
21
+ GEMM level, ~116 TFLOPS on M2048 K768 N3072; in a host training
22
+ loop, parity with cuBLAS-PEDANTIC on d128/d256 models and 16–30 %
23
+ faster from d768 up.
24
+
25
+ - **PyTorch binding** (`python/`, PyPI package `sgemm-bi`, import
26
+ `sgemm_bi`): PyO3 0.29 + maturin, abi3 wheel for Python >= 3.9. No
27
+ libtorch linkage — tensors cross as raw device pointers, so one wheel
28
+ works with any PyTorch build; runtime needs only the NVIDIA driver.
29
+ Ships `sgemm_bi.Linear` (deterministic `nn.Linear` replacement with
30
+ GEMM-natural `[in, out]` weight layout and `from_torch` converter),
31
+ the functional `deterministic_linear` autograd op (dW accumulated in
32
+ f32 inside the kernel, one rounding to the parameter dtype), and the
33
+ low-level `Engine`. Engine work is ordered against torch's current
34
+ stream with a CUDA-event bridge (no host syncs); calls release the
35
+ GIL; forward/backward are safe across torch's autograd thread.
36
+ Desk-reviewed against PyTorch/PyO3/maturin/CUDA driver documentation;
37
+ GPU test suite (`python/tests/`) green on RTX 6000 Ada: parity vs
38
+ float64 references, bit-identity across runs in all three dtypes,
39
+ strict all-M batch invariance of the tensor-core forward, end-to-end
40
+ training.
41
+ - **CI/release for the binding**: `python-binding` job (fmt, clippy,
42
+ wheel build artifact) and a tag-gated `publish-pypi` job using PyPI
43
+ trusted publishing (OIDC, no token secret).
44
+
45
+ - **C ABI** behind the `capi` feature (`src/capi.rs`, header
46
+ `include/sgemm_bi.h`, `cdylib`/`staticlib` crate types): engine
47
+ create/destroy/synchronize on a device ordinal, one `SgbGemm`
48
+ descriptor for all six GEMM entry points (scalar forward/dW/dX +
49
+ tensor-core triad) over raw `CUdeviceptr`s, per-thread error strings
50
+ (`sgb_last_error`), raw stream access for event-based ordering, and
51
+ upcast-scratch pre-sizing for CUDA Graph capture. Panics convert to
52
+ `SGB_ERR_PANIC` instead of unwinding across the boundary. Smoke test:
53
+ `examples/capi/smoke.c`.
54
+ - **Explicit architecture gate**: `SgemmBi::new` now rejects devices
55
+ below `sm_80` with the new `Error::UnsupportedArch` ("requires Ampere
56
+ or newer") instead of surfacing an opaque NVRTC failure — the kernel
57
+ blob uses `cp.async` and native bf16 in every tier, so pre-Ampere
58
+ devices were never able to run it.
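Note on the architecture gate in the last entry above: the check comes down to a compute-capability comparison. The Python sketch below is only an illustration of that comparison. `check_arch` and the `UnsupportedArch` exception are hypothetical stand-ins, not the crate's Rust API, and the `(major, minor)` pair is assumed to have been read from the driver beforehand.

```python
class UnsupportedArch(Exception):
    """Illustrative stand-in for the crate's Error::UnsupportedArch."""


def check_arch(major: int, minor: int) -> None:
    """Reject devices below sm_80 (Ampere), mirroring the changelog entry."""
    # Tuple comparison orders by major version first, then minor.
    if (major, minor) < (8, 0):
        raise UnsupportedArch(
            f"sm_{major}{minor}: requires Ampere or newer (sm_80+)"
        )


check_arch(8, 9)  # Ada (sm_89) passes silently
try:
    check_arch(7, 5)  # Turing (sm_75) is rejected before any kernel compiles
except UnsupportedArch as err:
    print(err)
```

Failing at construction gives the caller a named error instead of a compiler diagnostic from NVRTC, which is the behavior the entry describes.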
59
+
60
+ ## [0.1.0] - 2026-06-12
61
+
62
+ ### Added
63
+
64
+ - Initial release: deterministic, batch-invariant CUDA GEMM engine with
65
+ the full training triad — forward `Y = X@W + bias`, weight gradient
66
+ `dW += X^T@dY` (f32 master accumulator), input gradient `dX = dY@W^T`.
67
+ - **f32 tier**: full shape coverage via bucketed dispatch (GEMV,
68
+ ultra-thin, narrow, split-K, gap-fill, Big, Slim, split-M/N); fixed
69
+ reduction order, no atomics, no cuBLAS anywhere.
70
+ - **Typed tier (bf16/f16)**: native buckets keep f32 shared memory and
71
+ accumulation with the f32 tier's exact FMA chain; uncovered shapes
72
+ take "upcast → f32 kernel → RNE downcast". Both routes are
73
+ bit-identical by contract.
74
+ - **Tensor-core tier (bf16/f16)**: `mma.sync.m16n8k16` with f32
75
+ accumulators, 2-stage `cp.async` staging, `ldmatrix` fragment loads.
76
+ Separate numeric contract; bit-identical across runs and strictly
77
+ batch-invariant forward across all M. 3-7x faster than the scalar
78
+ tiers on 128x128-tile shapes.
79
+ - GPU contract tests (`tests/contracts.rs`, `tests/tensor_cores.rs`)
80
+ validated on RTX 6000 Ada / CUDA 13.2; CI with fmt, clippy, docs,
81
+ MSRV 1.94, cargo-deny, cross-platform build matrix, and a manual
82
+ tag-gated release pipeline.
83
+
84
+ [0.1.1]: https://github.com/silvermpx/sgemm-bi/compare/v0.1.0...v0.1.1
85
+ [0.1.0]: https://github.com/silvermpx/sgemm-bi/releases/tag/v0.1.0
@@ -0,0 +1,114 @@
1
+ # This file is automatically @generated by Cargo.
2
+ # It is not intended for manual editing.
3
+ version = 4
4
+
5
+ [[package]]
6
+ name = "cfg-if"
7
+ version = "1.0.4"
8
+ source = "registry+https://github.com/rust-lang/crates.io-index"
9
+ checksum = "9330f8b2ff13f34540b44e946ef35111825727b38d33286ef986142615121801"
10
+
11
+ [[package]]
12
+ name = "crunchy"
13
+ version = "0.2.4"
14
+ source = "registry+https://github.com/rust-lang/crates.io-index"
15
+ checksum = "460fbee9c2c2f33933d720630a6a0bac33ba7053db5344fac858d4b8952d77d5"
16
+
17
+ [[package]]
18
+ name = "cudarc"
19
+ version = "0.19.7"
20
+ source = "registry+https://github.com/rust-lang/crates.io-index"
21
+ checksum = "1cea5f10a99e025c1b44ae2354c2d8326b25ddbd0baf76bde8e55cfd4018a2cc"
22
+ dependencies = [
23
+ "libloading",
24
+ ]
25
+
26
+ [[package]]
27
+ name = "half"
28
+ version = "2.7.1"
29
+ source = "registry+https://github.com/rust-lang/crates.io-index"
30
+ checksum = "6ea2d84b969582b4b1864a92dc5d27cd2b77b622a8d79306834f1be5ba20d84b"
31
+ dependencies = [
32
+ "cfg-if",
33
+ "crunchy",
34
+ "zerocopy",
35
+ ]
36
+
37
+ [[package]]
38
+ name = "libloading"
39
+ version = "0.9.0"
40
+ source = "registry+https://github.com/rust-lang/crates.io-index"
41
+ checksum = "754ca22de805bb5744484a5b151a9e1a8e837d5dc232c2d7d8c2e3492edc8b60"
42
+ dependencies = [
43
+ "cfg-if",
44
+ "windows-link",
45
+ ]
46
+
47
+ [[package]]
48
+ name = "proc-macro2"
49
+ version = "1.0.106"
50
+ source = "registry+https://github.com/rust-lang/crates.io-index"
51
+ checksum = "8fd00f0bb2e90d81d1044c2b32617f68fcb9fa3bb7640c23e9c748e53fb30934"
52
+ dependencies = [
53
+ "unicode-ident",
54
+ ]
55
+
56
+ [[package]]
57
+ name = "quote"
58
+ version = "1.0.45"
59
+ source = "registry+https://github.com/rust-lang/crates.io-index"
60
+ checksum = "41f2619966050689382d2b44f664f4bc593e129785a36d6ee376ddf37259b924"
61
+ dependencies = [
62
+ "proc-macro2",
63
+ ]
64
+
65
+ [[package]]
66
+ name = "sgemm-bi"
67
+ version = "0.1.1"
68
+ dependencies = [
69
+ "cudarc",
70
+ "half",
71
+ ]
72
+
73
+ [[package]]
74
+ name = "syn"
75
+ version = "2.0.117"
76
+ source = "registry+https://github.com/rust-lang/crates.io-index"
77
+ checksum = "e665b8803e7b1d2a727f4023456bbbbe74da67099c585258af0ad9c5013b9b99"
78
+ dependencies = [
79
+ "proc-macro2",
80
+ "quote",
81
+ "unicode-ident",
82
+ ]
83
+
84
+ [[package]]
85
+ name = "unicode-ident"
86
+ version = "1.0.24"
87
+ source = "registry+https://github.com/rust-lang/crates.io-index"
88
+ checksum = "e6e4313cd5fcd3dad5cafa179702e2b244f760991f45397d14d4ebf38247da75"
89
+
90
+ [[package]]
91
+ name = "windows-link"
92
+ version = "0.2.1"
93
+ source = "registry+https://github.com/rust-lang/crates.io-index"
94
+ checksum = "f0805222e57f7521d6a62e36fa9163bc891acd422f971defe97d64e70d0a4fe5"
95
+
96
+ [[package]]
97
+ name = "zerocopy"
98
+ version = "0.8.52"
99
+ source = "registry+https://github.com/rust-lang/crates.io-index"
100
+ checksum = "ce1022995ff5ff5d841ad7d994facc23098cd40152f2c1d11cd607c6f530653f"
101
+ dependencies = [
102
+ "zerocopy-derive",
103
+ ]
104
+
105
+ [[package]]
106
+ name = "zerocopy-derive"
107
+ version = "0.8.52"
108
+ source = "registry+https://github.com/rust-lang/crates.io-index"
109
+ checksum = "1ae7f38b72ec2a254e2b87ef277cf2cd4fb97cbebf944faa6f33354da0867930"
110
+ dependencies = [
111
+ "proc-macro2",
112
+ "quote",
113
+ "syn",
114
+ ]
@@ -0,0 +1,32 @@
1
+ [package]
2
+ name = "sgemm-bi"
3
+ version = "0.1.1"
4
+ edition = "2024"
5
+ rust-version = "1.94"
6
+ authors = ["silvermpx"]
7
+ description = "Deterministic, batch-invariant CUDA GEMM engine with a full training triad (forward, dW, dX) in f32 / bf16 / f16, plus an opt-in tensor-core tier that is faster than cuBLAS PEDANTIC. Bit-identical results across runs; fixed reduction order; no atomics; no cuBLAS dependency."
8
+ homepage = "https://github.com/silvermpx/sgemm-bi"
9
+ documentation = "https://docs.rs/sgemm-bi"
10
+ license = "MIT OR Apache-2.0"
11
+ repository = "https://github.com/silvermpx/sgemm-bi"
12
+ readme = "README.md"
13
+ exclude = ["python/", ".github/"]
14
+ keywords = ["cuda", "gemm", "deterministic", "deep-learning", "reproducibility"]
15
+ categories = ["science", "mathematics"]
16
+
17
+ [lib]
18
+ # cdylib/staticlib carry the C ABI (`capi` feature); rlib is the normal
19
+ # Rust dependency form.
20
+ crate-type = ["lib", "cdylib", "staticlib"]
21
+
22
+ [features]
23
+ # Flat extern "C" interface (src/capi.rs + include/sgemm_bi.h).
24
+ capi = []
25
+
26
+ [dependencies.cudarc]
27
+ version = "0.19.7"
28
+ default-features = false
29
+ features = ["driver", "nvrtc", "dynamic-loading", "cuda-version-from-build-system"]
30
+
31
+ [dev-dependencies]
32
+ half = "2"
@@ -0,0 +1,19 @@
1
+ Apache License
2
+ Version 2.0, January 2004
3
+ http://www.apache.org/licenses/
4
+
5
+ TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION
6
+
7
+ Copyright 2026 silvermpx
8
+
9
+ Licensed under the Apache License, Version 2.0 (the "License");
10
+ you may not use this file except in compliance with the License.
11
+ You may obtain a copy of the License at
12
+
13
+ http://www.apache.org/licenses/LICENSE-2.0
14
+
15
+ Unless required by applicable law or agreed to in writing, software
16
+ distributed under the License is distributed on an "AS IS" BASIS,
17
+ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
18
+ See the License for the specific language governing permissions and
19
+ limitations under the License.
@@ -0,0 +1,21 @@
1
+ MIT License
2
+
3
+ Copyright (c) 2026 silvermpx
4
+
5
+ Permission is hereby granted, free of charge, to any person obtaining a copy
6
+ of this software and associated documentation files (the "Software"), to deal
7
+ in the Software without restriction, including without limitation the rights
8
+ to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9
+ copies of the Software, and to permit persons to whom the Software is
10
+ furnished to do so, subject to the following conditions:
11
+
12
+ The above copyright notice and this permission notice shall be included in all
13
+ copies or substantial portions of the Software.
14
+
15
+ THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16
+ IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17
+ FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18
+ AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19
+ LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20
+ OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21
+ SOFTWARE.
@@ -0,0 +1,84 @@
1
+ Metadata-Version: 2.4
2
+ Name: sgemm-bi
3
+ Version: 0.1.1
4
+ Classifier: Programming Language :: Python :: 3
5
+ Classifier: Programming Language :: Rust
6
+ Classifier: Topic :: Scientific/Engineering :: Artificial Intelligence
7
+ Classifier: Environment :: GPU :: NVIDIA CUDA
8
+ Summary: Deterministic, batch-invariant CUDA GEMM for PyTorch: bit-identical training matmuls (forward / dW / dX) in f32, bf16, f16, with an opt-in tensor-core tier.
9
+ Keywords: cuda,gemm,deterministic,pytorch,reproducibility
10
+ Author: silvermpx
11
+ License-Expression: MIT OR Apache-2.0
12
+ Requires-Python: >=3.9
13
+ Description-Content-Type: text/markdown; charset=UTF-8; variant=GFM
14
+ Project-URL: Changelog, https://github.com/silvermpx/sgemm-bi/blob/main/CHANGELOG.md
15
+ Project-URL: Repository, https://github.com/silvermpx/sgemm-bi
16
+
17
+ # sgemm-bi for PyTorch
18
+
19
+ Deterministic, batch-invariant CUDA GEMM for PyTorch training:
20
+ bit-identical matmuls (forward / dW / dX) in float32, bfloat16, and
21
+ float16, with an opt-in tensor-core tier that is faster than cuBLAS
22
+ PEDANTIC on transformer-class shapes.
23
+
24
+ Built on the [sgemm-bi](https://github.com/silvermpx/sgemm-bi) Rust
25
+ engine. No libtorch linkage — tensors cross as raw device pointers, so
26
+ one wheel works with any PyTorch build. Kernels compile once at engine
27
+ creation via NVRTC; the runtime needs only the NVIDIA driver.
28
+
29
+ **Requirements**: NVIDIA Ampere or newer (sm_80+), PyTorch with CUDA.
30
+
31
+ ## Install
32
+
33
+ ```sh
34
+ pip install maturin
35
+ cd python && maturin build --release
36
+ pip install target/wheels/sgemm_bi-*.whl
37
+ ```
38
+
39
+ ## Use
40
+
41
+ ```python
42
+ import torch, sgemm_bi
43
+
44
+ # Drop-in layer (weight stored [in, out] — GEMM-natural; convert
45
+ # existing layers with Linear.from_torch):
46
+ layer = sgemm_bi.Linear(768, 3072, device="cuda", dtype=torch.bfloat16,
47
+ tensor_cores=True)
48
+ y = layer(x)
49
+ y.sum().backward() # deterministic dW (f32-accumulated) and dX
50
+
51
+ # Functional form:
52
+ y = sgemm_bi.deterministic_linear(x, weight, bias, tensor_cores=True)
53
+
54
+ # Low-level engine (raw pointers, explicit stream):
55
+ eng = sgemm_bi.Engine(0)
56
+ eng.forward(y.data_ptr(), x.data_ptr(), w.data_ptr(), None,
57
+ m, k, n, "bfloat16",
58
+ torch.cuda.current_stream().cuda_stream, True)
59
+ ```
60
+
61
+ ## Determinism contracts
62
+
63
+ | tier | guarantee |
64
+ |---|---|
65
+ | f32 / typed scalar | bit-identical across runs; bf16/f16 ≡ "upcast → f32 kernel → RNE downcast"; full shape coverage |
66
+ | tensor cores (`tensor_cores=True`) | own deterministic contract (mma.sync, f32 accumulate); bit-identical across runs; forward strictly batch-invariant across all M; falls back to the scalar tier when an output dim is < 64 (shape-only dispatch — still deterministic; two bit-identical tile families, 128×128 and 64×64, cover everything above) |
67
+
68
+ Bias gradient is a plain f32 column sum done in torch (deterministic
69
+ run-to-run for a fixed shape; not part of the engine's batch-invariance
70
+ contract).
71
+
72
+ ## Examples
73
+
74
+ [`examples/train_deterministic.py`](examples/train_deterministic.py)
75
+ trains twice from one seed and asserts bit-identical weights, then
76
+ checks batch invariance. Typed stubs ship in the wheel (`py.typed`) —
77
+ IDEs autocomplete and document the native `Engine` class.
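Note on the self-check described above (train twice from one seed, then compare weights bit for bit): the idea can be illustrated without a GPU. The sketch below is a toy CPU stand-in, not the package's `train_deterministic.py`; `train` and `bit_pattern` are hypothetical names, and the tiny linear model exists only to have something to compare.

```python
import random
import struct


def train(seed: int, steps: int = 50) -> list[float]:
    """Toy SGD loop whose every reduction runs in a fixed ascending order."""
    rng = random.Random(seed)
    w = [rng.uniform(-1.0, 1.0) for _ in range(4)]
    for _ in range(steps):
        x = [rng.uniform(-1.0, 1.0) for _ in range(4)]
        pred = 0.0
        for wi, xi in zip(w, x):  # explicit ascending accumulation
            pred += wi * xi
        err = pred - 1.0
        w = [wi - 0.1 * err * xi for wi, xi in zip(w, x)]
    return w


def bit_pattern(values: list[float]) -> bytes:
    """Serialize floats so the comparison is on raw bits, not on ==."""
    return struct.pack(f"<{len(values)}d", *values)


# Two runs from one seed must agree byte for byte.
assert bit_pattern(train(0)) == bit_pattern(train(0))
```

Comparing packed bytes rather than using `==` on floats is the stricter check: it distinguishes `0.0` from `-0.0` and does not silently fail on NaN.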
78
+
79
+ ## Tests
80
+
81
+ ```sh
82
+ pip install pytest && pytest tests/ -v # needs a CUDA GPU
83
+ ```
84
+
@@ -0,0 +1,147 @@
1
+ # sgemm-bi
2
+
3
+ Deterministic, batch-invariant CUDA GEMM engine with a full **training
4
+ triad** — forward, weight gradient, and input gradient — in **f32, bf16,
5
+ and f16**, plus an opt-in **tensor-core tier**.
6
+
7
+ Existing batch-invariant kernel collections cover inference only and trade
8
+ 10–40% throughput for determinism. `sgemm-bi` covers the backward pass
9
+ too, and on tile-friendly shapes the tensor-core tier makes deterministic
10
+ training *faster* than a CUDA-core cuBLAS baseline.
11
+
12
+ ## Guarantees
13
+
14
+ - **Run-to-run determinism** — fixed reduction order in every kernel: no
15
+ atomics, no data-dependent splits, no vendor-BLAS fallback. Same inputs
16
+ → bit-identical outputs, including through CUDA Graph replay.
17
+ - **Batch invariance** — within a dispatch bucket, output row 0 is
18
+ bit-identical regardless of the batch dimension M. The tensor-core
19
+ forward is strictly batch-invariant across *all* M.
20
+ - **Typed bit contract** — bf16/f16 results are bit-identical to "upcast
21
+ the inputs to f32, run the f32 tier, round-to-nearest-even downcast the
22
+ output". Accumulation never happens in reduced precision; exactly one
23
+ rounding is applied, at the output store.
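Note on the typed bit contract above: the difference between "accumulate in f32, round once" and "accumulate in reduced precision" can be shown on the CPU. This is a minimal sketch using only the standard library, with bf16 emulated by truncating f32 bit patterns. The helper names are hypothetical, NaN and infinity are not handled, and the f32 accumulation is emulated through f64, which is exact for the small values used here.

```python
import struct


def to_f32(x: float) -> float:
    """Round a Python float (f64) to the nearest f32 value."""
    return struct.unpack("<f", struct.pack("<f", x))[0]


def bf16_rne(x: float) -> float:
    """Round a finite f32 value to bf16 with round-to-nearest-even."""
    bits = struct.unpack("<I", struct.pack("<f", x))[0]
    # Add half a bf16 ulp, plus the lowest kept bit so exact ties go to even.
    bits = (bits + 0x7FFF + ((bits >> 16) & 1)) & 0xFFFF0000
    return struct.unpack("<f", struct.pack("<I", bits))[0]


def dot_single_rounding(xs, ws):
    """f32 accumulation in ascending order, one RNE rounding at the store."""
    acc = 0.0
    for x, w in zip(xs, ws):
        acc = to_f32(acc + x * w)
    return bf16_rne(acc)


def dot_bf16_accumulate(xs, ws):
    """Counter-example: the accumulator is rounded to bf16 after every step."""
    acc = 0.0
    for x, w in zip(xs, ws):
        acc = bf16_rne(to_f32(acc + x * w))
    return acc


xs = [16.0, 1.0, 1.0, 1.0, 1.0]
ws = [16.0, 1.0, 1.0, 1.0, 1.0]
print(dot_single_rounding(xs, ws))  # 260.0
print(dot_bf16_accumulate(xs, ws))  # 256.0
```

In the counter-example each `256 + 1` lands exactly halfway between the bf16 neighbors 256 and 258 and ties to even, so the four small terms are lost; with f32 accumulation they survive until the single final rounding.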
24
+
25
+ ## Operations
26
+
27
+ | op | math | output |
28
+ |---|---|---|
29
+ | `forward` | `Y[M,N] = X[M,K] @ W[K,N] + bias[N]` | typed / f32 |
30
+ | `backward_dw` | `dW[K,N] += X^T[K,M] @ dY[M,N]` | f32 accumulate |
31
+ | `backward_dx` | `dX[M,K] = dY[M,N] @ W^T[N,K]` | typed / f32 |
32
+
33
+ Each op exists in three tiers: `*_f32` (the reference chain), typed
34
+ (bf16/f16, bit-equal to the f32 tier on upcast inputs), and `*_tc`
35
+ (tensor cores — a separate deterministic contract; mma.sync with f32
36
+ accumulators cannot bit-match a scalar FMA chain, but it is deterministic
37
+ and strictly batch-invariant).
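Note on the three operations above: their shapes and reduction axes are easier to check against a plain reference. The sketch below is a pure-Python reference with nested lists, not the crate's kernels. Every reduction runs in ascending index order, and adding the bias after the reduction is an assumption made here for illustration.

```python
def forward(x, w, bias):
    """Y[M,N] = X[M,K] @ W[K,N] + bias[N], reducing over k in ascending order."""
    k, n = len(w), len(w[0])
    y = []
    for row in x:
        out = []
        for j in range(n):
            acc = 0.0
            for p in range(k):
                acc += row[p] * w[p][j]
            out.append(acc + bias[j])
        y.append(out)
    return y


def backward_dw(dw, x, dy):
    """dW[K,N] += X^T[K,M] @ dY[M,N], reducing over batch rows in ascending order."""
    for p in range(len(dw)):
        for j in range(len(dw[0])):
            acc = 0.0
            for i in range(len(x)):
                acc += x[i][p] * dy[i][j]
            dw[p][j] += acc  # accumulates into the existing gradient


def backward_dx(dy, w):
    """dX[M,K] = dY[M,N] @ W^T[N,K], reducing over n in ascending order."""
    k, n = len(w), len(w[0])
    dx = []
    for dy_row in dy:
        out = []
        for p in range(k):
            acc = 0.0
            for j in range(n):
                acc += dy_row[j] * w[p][j]
            out.append(acc)
        dx.append(out)
    return dx


x = [[1.0, 2.0], [3.0, 4.0], [5.0, 6.0]]   # M=3, K=2
w = [[1.0, 0.0, 2.0], [0.0, 1.0, 3.0]]     # K=2, N=3
bias = [0.5, 0.5, 0.5]
y = forward(x, w, bias)
# Row 0 is computed the same way whether the batch holds 1 row or 3.
assert forward(x[:1], w, bias)[0] == y[0]
```

Because each output element depends only on its own row of `X` and a fixed-order walk over `k`, the batch size cannot influence the bits of any row; that is the property the guarantees section calls batch invariance.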
38
+
39
+ The f32 and typed tiers cover **every** shape: a bucketed dispatcher
40
+ (Big / Slim / narrow / ultra-thin / GEMV / split-K/M/N with fixed-order
41
+ tree reduction) handles the common cases natively and the typed tier
42
+ falls back to "upcast → f32 kernel → downcast" — same bits by contract —
43
+ for the rest. The tensor-core tier covers both output dims ≥ 64 (two
44
+ bit-identical kernel families, 128×128 and 64×64 tiles, routed by shape)
45
+ and returns `Error::Uncovered` otherwise.
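Note on the shape-based routing described above: the text gives the coverage rule (both output dims at least 64) and the two tile sizes, but not the exact switch-over point. The sketch below is therefore hypothetical. `route_tc`, the returned labels, and the grid-fill heuristic against an SM count are assumptions for illustration; only the 64 threshold and the tile sizes come from the text.

```python
import math


def route_tc(m: int, n: int, sm_count: int = 142) -> str:
    """Pick a tensor-core tile family from the output shape alone.

    sm_count=142 is an assumed default (RTX 6000 Ada); the real gate
    may use a different rule entirely.
    """
    if min(m, n) < 64:
        return "uncovered"  # the README reports Error::Uncovered here
    # Number of 128x128 thread blocks the output would occupy.
    blocks_128 = math.ceil(m / 128) * math.ceil(n / 128)
    if min(m, n) < 128 or blocks_128 < sm_count:
        return "tc64"  # small tiles: more blocks, better GPU fill
    return "tc128"


print(route_tc(2048, 3072))  # tc128
print(route_tc(2048, 512))   # tc64
print(route_tc(32, 3072))    # uncovered
```

The inputs to the decision are dimensions only, never tensor contents, so the same shape always takes the same route. Since the two families are stated to be bit-identical per output element, a different route would not change the output bits either.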
46
+
47
+ ## Performance (RTX 6000 Ada, bf16)
48
+
49
+ Tensor-core tier vs the scalar deterministic tier, GEMM level
50
+ (forward; measured on this crate's bench suite):
51
+
52
+ | shape (M, K, N) | scalar | tensor cores | speedup |
53
+ |---|---:|---:|---:|
54
+ | 2048, 768, 3072 | 290.9 µs | 83.4 µs | **3.5×** |
55
+ | 4096, 1536, 3072 | 1123.0 µs | 353.5 µs | 3.2× |
56
+ | 2048, 768, 512 | 123.5 µs | 19.5 µs | **6.3×** |
57
+
58
+ ~116 TFLOPS bf16 at M2048 K768 N3072 (~32 % of Ada dense peak). dW and
59
+ dX see similar gains (4.0–5.6× and 3.5–5.1× on the same shapes).
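Note on the figures above: the throughput and speedup columns can be rechecked from the timings alone. The short sketch below uses the standard count of 2·M·K·N floating-point operations per GEMM; the timings are copied from the table and nothing else is assumed.

```python
rows = [
    # (M, K, N, scalar_us, tensor_core_us), copied from the table above
    (2048, 768, 3072, 290.9, 83.4),
    (4096, 1536, 3072, 1123.0, 353.5),
    (2048, 768, 512, 123.5, 19.5),
]


def tflops(m: int, k: int, n: int, micros: float) -> float:
    """One multiply and one add per (m, k, n) term gives 2*M*K*N operations."""
    return 2 * m * k * n / (micros * 1e-6) / 1e12


for m, k, n, scalar_us, tc_us in rows:
    print(
        f"M{m} K{k} N{n}: {tflops(m, k, n, tc_us):6.1f} TFLOPS, "
        f"{scalar_us / tc_us:.1f}x over scalar"
    )
```

The first row works out to roughly 116 TFLOPS and a 3.5× speedup, matching the quoted figures.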
60
+
61
+ Against cuBLAS (measured in a host application using this engine for
62
+ every training GEMM, same GPU, per optimizer step):
63
+
64
+ | dtype × tier | step time vs cuBLAS (lower is faster) | notes |
65
+ |---|---|---|
66
+ | f32 scalar vs TF32 | 1.28–1.53× | full f32 precision vs truncated-mantissa TF32 |
67
+ | bf16/f16 scalar vs PEDANTIC | 1.09–1.37× | bit-contract, CUDA cores |
68
+ | bf16 TC vs PEDANTIC | **1.04× (d128) → 0.70× (d1536)** | parity on small models, 16–30 % FASTER from d768 |
69
+ | f16 TC vs PEDANTIC | 1.19× (d128) → 0.76× (d1536) | |
70
+
71
+ The cost of determinism is zero-to-negative on transformer-class
72
+ shapes; the deterministic bf16 step at d1536 also beats the f32-TF32
73
+ baseline outright.
74
+
75
+ ## Documentation and examples
76
+
77
+ - [Usage guide](docs/usage-guide.md) — recipes for all three interfaces,
78
+ tier selection, CUDA Graph capture, determinism self-checks.
79
+ - [`examples/deterministic_training.rs`](examples/deterministic_training.rs) —
80
+ full Rust triad with runtime determinism/invariance asserts
81
+ (`cargo run --release --example deterministic_training`).
82
+ - [`examples/capi/smoke.c`](examples/capi/smoke.c) — the C ABI end to end.
83
+ - [`python/examples/train_deterministic.py`](python/examples/train_deterministic.py) —
84
+ bit-identical PyTorch training, twice from one seed.
85
+ - API reference: [docs.rs/sgemm-bi](https://docs.rs/sgemm-bi); the Python
86
+ package ships typed stubs (`.pyi` + `py.typed`), so IDE hover/completion
87
+ documents the native `Engine` too.
88
+
89
+ ## Usage
90
+
91
+ ```rust,ignore
92
+ use sgemm_bi::{Dtype, SgemmBi, TypedPtr};
93
+
94
+ let context = cudarc::driver::CudaContext::new(0).unwrap();
95
+ let stream = context.new_stream().unwrap();
96
+ let engine = SgemmBi::new(&context, stream.clone()).unwrap();
97
+
98
+ // y/x/w are CUdeviceptr device allocations on `stream` (bf16 storage).
99
+ engine
100
+ .forward(
101
+ TypedPtr::new(y, Dtype::Bf16),
102
+ TypedPtr::new(x, Dtype::Bf16),
103
+ TypedPtr::new(w, Dtype::Bf16),
104
+ Some(bias_f32_ptr),
105
+ (m, k, n),
106
+ )
107
+ .unwrap();
108
+ ```
109
+
110
+ The engine binds to one stream; all calls enqueue and return. For CUDA
111
+ Graph capture, call `presize_upcast_scratch` before capturing so the
112
+ typed fallback never allocates inside (or after) a captured graph.
113
+
114
+ ## Requirements
115
+
116
+ - NVIDIA GPU, `sm_80`+ (Ampere or newer) for all tiers (`cp.async`,
116
+ `ldmatrix`, bf16 `mma.sync`); `SgemmBi::new` rejects older devices.
118
+ - CUDA driver + NVRTC at run time. Kernels compile at engine construction
119
+ for the device's native architecture — no toolkit or `nvcc` needed.
120
+ - No cuBLAS: the library never links or calls a vendor BLAS.
121
+
122
+ ## Testing
123
+
124
+ Contract tests require a CUDA device:
125
+
126
+ ```sh
127
+ cargo test --release -- --test-threads=1
128
+ ```
129
+
130
+ Covered: f32 run-to-run bit identity; the typed bit contract swept across
131
+ ~90 dispatch-gate boundary shapes (forward) plus backward shapes;
132
+ per-bucket batch invariance; tensor-core determinism, strict all-M
133
+ invariance, and accuracy vs the f32 reference. Benchmarks are `#[ignore]`d
134
+ (`bench_tc_vs_scalar`).
135
+
136
+ ## Lineage
137
+
138
+ The Big-tile kernels descend from [siboehm's SGEMM
139
+ warptiling](https://github.com/siboehm/SGEMM_CUDA) work; smem padding
140
+ follows [salykova's sgemm.cu](https://github.com/salykova/sgemm.cu). The
141
+ engine is extracted from the GEMM layer of
142
+ [mamba-rs](https://github.com/silvermpx/mamba-rs), where it powers
143
+ deterministic SSM training.
144
+
145
+ ## License
146
+
147
+ Dual-licensed under MIT or Apache-2.0.
@@ -0,0 +1,21 @@
1
+ [advisories]
2
+ version = 2
3
+
4
+ [licenses]
5
+ version = 2
6
+ allow = [
7
+ "MIT",
8
+ "Apache-2.0",
9
+ "BSD-2-Clause",
10
+ "BSD-3-Clause",
11
+ "ISC",
12
+ "Unicode-3.0",
13
+ "Zlib",
14
+ ]
15
+
16
+ [bans]
17
+ multiple-versions = "warn"
18
+
19
+ [sources]
20
+ unknown-registry = "deny"
21
+ unknown-git = "deny"