sgemm-bi 0.1.1 (tar.gz)
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- sgemm_bi-0.1.1/.gitignore +5 -0
- sgemm_bi-0.1.1/CHANGELOG.md +85 -0
- sgemm_bi-0.1.1/Cargo.lock +114 -0
- sgemm_bi-0.1.1/Cargo.toml +32 -0
- sgemm_bi-0.1.1/LICENSE-APACHE +19 -0
- sgemm_bi-0.1.1/LICENSE-MIT +21 -0
- sgemm_bi-0.1.1/PKG-INFO +84 -0
- sgemm_bi-0.1.1/README.md +147 -0
- sgemm_bi-0.1.1/deny.toml +21 -0
- sgemm_bi-0.1.1/docs/usage-guide.md +172 -0
- sgemm_bi-0.1.1/examples/capi/smoke.c +162 -0
- sgemm_bi-0.1.1/examples/deterministic_training.rs +99 -0
- sgemm_bi-0.1.1/include/sgemm_bi.h +126 -0
- sgemm_bi-0.1.1/kernels/casts.cu +36 -0
- sgemm_bi-0.1.1/kernels/prelude.cuh +24 -0
- sgemm_bi-0.1.1/kernels/sgemm_bi.cu +6354 -0
- sgemm_bi-0.1.1/pyproject.toml +29 -0
- sgemm_bi-0.1.1/python/.gitignore +5 -0
- sgemm_bi-0.1.1/python/Cargo.lock +172 -0
- sgemm_bi-0.1.1/python/Cargo.toml +33 -0
- sgemm_bi-0.1.1/python/README.md +67 -0
- sgemm_bi-0.1.1/python/examples/train_deterministic.py +55 -0
- sgemm_bi-0.1.1/python/sgemm_bi/__init__.py +25 -0
- sgemm_bi-0.1.1/python/sgemm_bi/_sgemm_bi.pyi +96 -0
- sgemm_bi-0.1.1/python/sgemm_bi/py.typed +0 -0
- sgemm_bi-0.1.1/python/sgemm_bi/torch.py +230 -0
- sgemm_bi-0.1.1/python/src/lib.rs +307 -0
- sgemm_bi-0.1.1/python/tests/test_torch.py +142 -0
- sgemm_bi-0.1.1/src/capi.rs +423 -0
- sgemm_bi-0.1.1/src/dispatch.rs +2284 -0
- sgemm_bi-0.1.1/src/dtype.rs +36 -0
- sgemm_bi-0.1.1/src/engine.rs +333 -0
- sgemm_bi-0.1.1/src/error.rs +57 -0
- sgemm_bi-0.1.1/src/kernels.rs +270 -0
- sgemm_bi-0.1.1/src/lib.rs +90 -0
- sgemm_bi-0.1.1/tests/common/mod.rs +115 -0
- sgemm_bi-0.1.1/tests/contracts.rs +203 -0
- sgemm_bi-0.1.1/tests/tensor_cores.rs +322 -0

sgemm_bi-0.1.1/CHANGELOG.md
ADDED
@@ -0,0 +1,85 @@
+# Changelog
+
+All notable changes to this project are documented here. The format is
+based on [Keep a Changelog](https://keepachangelog.com/en/1.1.0/), and
+this project adheres to [Semantic Versioning](https://semver.org/).
+
+## [0.1.1] - 2026-06-12
+
+### Added
+
+- **Tensor-core small-tile family + BK=64 staging**: six
+`sgemm_bi_*_tc64_*` kernels (64×64 tiles, 128 threads) extend the TC
+tier to both output dims >= 64 and to GPU-underfilling grids; both
+tile families stage 64-deep reduction slabs (half the barrier count)
+with the 128-tile family on dynamic shared memory. The two families
+are BIT-IDENTICAL per output element (same ascending mma chain), so
+the shape-only routing never changes output bits and the strict
+all-M forward invariance survives tile switching — asserted through
+the public API by `tc_cross_tile_strict_all_m_invariance`. Measured
+(RTX 6000 Ada, bf16): TC forward 3.5–6.3× over the scalar tier at
+GEMM level, ~116 TFLOPS on M2048 K768 N3072; in a host training
+loop, parity with cuBLAS-PEDANTIC on d128/d256 models and 16–30 %
+faster from d768 up.
+
+- **PyTorch binding** (`python/`, PyPI package `sgemm-bi`, import
+`sgemm_bi`): PyO3 0.29 + maturin, abi3 wheel for Python >= 3.9. No
+libtorch linkage — tensors cross as raw device pointers, so one wheel
+works with any PyTorch build; runtime needs only the NVIDIA driver.
+Ships `sgemm_bi.Linear` (deterministic `nn.Linear` replacement with
+GEMM-natural `[in, out]` weight layout and `from_torch` converter),
+the functional `deterministic_linear` autograd op (dW accumulated in
+f32 inside the kernel, one rounding to the parameter dtype), and the
+low-level `Engine`. Engine work is ordered against torch's current
+stream with a CUDA-event bridge (no host syncs); calls release the
+GIL; forward/backward are safe across torch's autograd thread.
+Desk-reviewed against PyTorch/PyO3/maturin/CUDA driver documentation;
+GPU test suite (`python/tests/`) green on RTX 6000 Ada: parity vs
+float64 references, bit-identity across runs in all three dtypes,
+strict all-M batch invariance of the tensor-core forward, end-to-end
+training.
+- **CI/release for the binding**: `python-binding` job (fmt, clippy,
+wheel build artifact) and a tag-gated `publish-pypi` job using PyPI
+trusted publishing (OIDC, no token secret).
+
+- **C ABI** behind the `capi` feature (`src/capi.rs`, header
+`include/sgemm_bi.h`, `cdylib`/`staticlib` crate types): engine
+create/destroy/synchronize on a device ordinal, one `SgbGemm`
+descriptor for all six GEMM entry points (scalar forward/dW/dX +
+tensor-core triad) over raw `CUdeviceptr`s, per-thread error strings
+(`sgb_last_error`), raw stream access for event-based ordering, and
+upcast-scratch pre-sizing for CUDA Graph capture. Panics convert to
+`SGB_ERR_PANIC` instead of unwinding across the boundary. Smoke test:
+`examples/capi/smoke.c`.
+- **Explicit architecture gate**: `SgemmBi::new` now rejects devices
+below `sm_80` with the new `Error::UnsupportedArch` ("requires Ampere
+or newer") instead of surfacing an opaque NVRTC failure — the kernel
+blob uses `cp.async` and native bf16 in every tier, so pre-Ampere
+devices were never able to run it.
+
+## [0.1.0] - 2026-06-12
+
+### Added
+
+- Initial release: deterministic, batch-invariant CUDA GEMM engine with
+the full training triad — forward `Y = X@W + bias`, weight gradient
+`dW += X^T@dY` (f32 master accumulator), input gradient `dX = dY@W^T`.
+- **f32 tier**: full shape coverage via bucketed dispatch (GEMV,
+ultra-thin, narrow, split-K, gap-fill, Big, Slim, split-M/N); fixed
+reduction order, no atomics, no cuBLAS anywhere.
+- **Typed tier (bf16/f16)**: native buckets keep f32 shared memory and
+accumulation with the f32 tier's exact FMA chain; uncovered shapes
+take "upcast → f32 kernel → RNE downcast". Both routes are
+bit-identical by contract.
+- **Tensor-core tier (bf16/f16)**: `mma.sync.m16n8k16` with f32
+accumulators, 2-stage `cp.async` staging, `ldmatrix` fragment loads.
+Separate numeric contract; bit-identical across runs and strictly
+batch-invariant forward across all M. 3-7x faster than the scalar
+tiers on 128x128-tile shapes.
+- GPU contract tests (`tests/contracts.rs`, `tests/tensor_cores.rs`)
+validated on RTX 6000 Ada / CUDA 13.2; CI with fmt, clippy, docs,
+MSRV 1.94, cargo-deny, cross-platform build matrix, and a manual
+tag-gated release pipeline.
+
+[0.1.1]: https://github.com/silvermpx/sgemm-bi/compare/v0.1.0...v0.1.1
+[0.1.0]: https://github.com/silvermpx/sgemm-bi/releases/tag/v0.1.0
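
The changelog's guarantees (fixed reduction order, one "ascending mma chain" per output element, strict all-M forward invariance) rest on a property of floating point: f32 addition is not associative, so the order of a reduction decides the output bits, while an element produced by a single fixed-order chain over its own row and column cannot depend on M. The Rust sketch below is a naive CPU reference of the triad (`Y = X@W + bias`, `dW += X^T@dY`, `dX = dY@W^T`) written under that rule. It is not the crate's API or kernel code; the function names, the row-major layouts, and adding the bias after the chain are assumptions made for illustration.

```rust
// Naive CPU reference of the training triad (illustrative sketch only, not
// the sgemm-bi API). Row-major layouts assumed: X is [M,K], W is [K,N],
// dY is [M,N]. Every output element is one sequential f32 FMA chain over
// ascending indices, so its bits depend only on its own row/column inputs.

/// Y[M,N] = X[M,K] @ W[K,N] + bias[N]; bias is added after the chain (assumption).
pub fn forward(x: &[f32], w: &[f32], bias: Option<&[f32]>, m: usize, k: usize, n: usize) -> Vec<f32> {
    assert_eq!(x.len(), m * k);
    assert_eq!(w.len(), k * n);
    let mut y = vec![0.0f32; m * n];
    for i in 0..m {
        for j in 0..n {
            let mut acc = 0.0f32;
            for p in 0..k {
                // Fixed ascending k order: never split, never reordered.
                acc = x[i * k + p].mul_add(w[p * n + j], acc);
            }
            y[i * n + j] = acc + bias.map_or(0.0, |b| b[j]);
        }
    }
    y
}

/// dW[K,N] += X^T[K,M] @ dY[M,N]: one f32 chain over ascending m, added once
/// to the f32 master gradient (the accumulation point is an assumption).
pub fn backward_dw(dw: &mut [f32], x: &[f32], dy: &[f32], m: usize, k: usize, n: usize) {
    assert_eq!(dw.len(), k * n);
    for p in 0..k {
        for j in 0..n {
            let mut acc = 0.0f32;
            for i in 0..m {
                acc = x[i * k + p].mul_add(dy[i * n + j], acc);
            }
            dw[p * n + j] += acc;
        }
    }
}

/// dX[M,K] = dY[M,N] @ W^T[N,K]: W is read transposed in place, no copy.
pub fn backward_dx(dy: &[f32], w: &[f32], m: usize, k: usize, n: usize) -> Vec<f32> {
    assert_eq!(dy.len(), m * n);
    assert_eq!(w.len(), k * n);
    let mut dx = vec![0.0f32; m * k];
    for i in 0..m {
        for p in 0..k {
            let mut acc = 0.0f32;
            for j in 0..n {
                acc = dy[i * n + j].mul_add(w[p * n + j], acc);
            }
            dx[i * k + p] = acc;
        }
    }
    dx
}

/// Why the order must be fixed: the two groupings of one sum can differ.
pub fn sum_both_orders(a: f32, b: f32, c: f32) -> (f32, f32) {
    ((a + b) + c, a + (b + c))
}
```

With `a = 1e8, b = -1e8, c = 1.0` the two groupings give `1.0` and `0.0`, because `-1e8 + 1` rounds back to `-1e8` in f32. Computing row 0 alone (M = 1) and inside a larger batch gives the same bits in this sketch because nothing about that row's chain changes with M; a reduction whose split or order depended on M would not give that guarantee.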

sgemm_bi-0.1.1/Cargo.lock
ADDED
@@ -0,0 +1,114 @@
+# This file is automatically @generated by Cargo.
+# It is not intended for manual editing.
+version = 4
+
+[[package]]
+name = "cfg-if"
+version = "1.0.4"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "9330f8b2ff13f34540b44e946ef35111825727b38d33286ef986142615121801"
+
+[[package]]
+name = "crunchy"
+version = "0.2.4"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "460fbee9c2c2f33933d720630a6a0bac33ba7053db5344fac858d4b8952d77d5"
+
+[[package]]
+name = "cudarc"
+version = "0.19.7"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "1cea5f10a99e025c1b44ae2354c2d8326b25ddbd0baf76bde8e55cfd4018a2cc"
+dependencies = [
+ "libloading",
+]
+
+[[package]]
+name = "half"
+version = "2.7.1"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "6ea2d84b969582b4b1864a92dc5d27cd2b77b622a8d79306834f1be5ba20d84b"
+dependencies = [
+ "cfg-if",
+ "crunchy",
+ "zerocopy",
+]
+
+[[package]]
+name = "libloading"
+version = "0.9.0"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "754ca22de805bb5744484a5b151a9e1a8e837d5dc232c2d7d8c2e3492edc8b60"
+dependencies = [
+ "cfg-if",
+ "windows-link",
+]
+
+[[package]]
+name = "proc-macro2"
+version = "1.0.106"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "8fd00f0bb2e90d81d1044c2b32617f68fcb9fa3bb7640c23e9c748e53fb30934"
+dependencies = [
+ "unicode-ident",
+]
+
+[[package]]
+name = "quote"
+version = "1.0.45"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "41f2619966050689382d2b44f664f4bc593e129785a36d6ee376ddf37259b924"
+dependencies = [
+ "proc-macro2",
+]
+
+[[package]]
+name = "sgemm-bi"
+version = "0.1.1"
+dependencies = [
+ "cudarc",
+ "half",
+]
+
+[[package]]
+name = "syn"
+version = "2.0.117"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "e665b8803e7b1d2a727f4023456bbbbe74da67099c585258af0ad9c5013b9b99"
+dependencies = [
+ "proc-macro2",
+ "quote",
+ "unicode-ident",
+]
+
+[[package]]
+name = "unicode-ident"
+version = "1.0.24"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "e6e4313cd5fcd3dad5cafa179702e2b244f760991f45397d14d4ebf38247da75"
+
+[[package]]
+name = "windows-link"
+version = "0.2.1"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "f0805222e57f7521d6a62e36fa9163bc891acd422f971defe97d64e70d0a4fe5"
+
+[[package]]
+name = "zerocopy"
+version = "0.8.52"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "ce1022995ff5ff5d841ad7d994facc23098cd40152f2c1d11cd607c6f530653f"
+dependencies = [
+ "zerocopy-derive",
+]
+
+[[package]]
+name = "zerocopy-derive"
+version = "0.8.52"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "1ae7f38b72ec2a254e2b87ef277cf2cd4fb97cbebf944faa6f33354da0867930"
+dependencies = [
+ "proc-macro2",
+ "quote",
+ "syn",
+]

sgemm_bi-0.1.1/Cargo.toml
ADDED
@@ -0,0 +1,32 @@
+[package]
+name = "sgemm-bi"
+version = "0.1.1"
+edition = "2024"
+rust-version = "1.94"
+authors = ["silvermpx"]
+description = "Deterministic, batch-invariant CUDA GEMM engine with a full training triad (forward, dW, dX) in f32 / bf16 / f16, plus an opt-in tensor-core tier that is faster than cuBLAS PEDANTIC. Bit-identical results across runs; fixed reduction order; no atomics; no cuBLAS dependency."
+homepage = "https://github.com/silvermpx/sgemm-bi"
+documentation = "https://docs.rs/sgemm-bi"
+license = "MIT OR Apache-2.0"
+repository = "https://github.com/silvermpx/sgemm-bi"
+readme = "README.md"
+exclude = ["python/", ".github/"]
+keywords = ["cuda", "gemm", "deterministic", "deep-learning", "reproducibility"]
+categories = ["science", "mathematics"]
+
+[lib]
+# cdylib/staticlib carry the C ABI (`capi` feature); rlib is the normal
+# Rust dependency form.
+crate-type = ["lib", "cdylib", "staticlib"]
+
+[features]
+# Flat extern "C" interface (src/capi.rs + include/sgemm_bi.h).
+capi = []
+
+[dependencies.cudarc]
+version = "0.19.7"
+default-features = false
+features = ["driver", "nvrtc", "dynamic-loading", "cuda-version-from-build-system"]
+
+[dev-dependencies]
+half = "2"
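
The `[lib]` section builds `cdylib`/`staticlib` artifacts for the C ABI behind the `capi` feature, and the changelog states that panics become `SGB_ERR_PANIC` instead of unwinding across the boundary, with errors reported as per-thread strings through `sgb_last_error`. A minimal Rust sketch of that pattern follows: each entry point runs inside `std::panic::catch_unwind`, and the message lands in a thread-local `CString`. The status values and the `sgb_demo` entry point are hypothetical (the real codes and signatures are in `include/sgemm_bi.h`, which this diff does not show); only the name `sgb_last_error` comes from the changelog, and its signature here is an assumption. A real export would also carry `#[unsafe(no_mangle)]`.

```rust
use std::cell::RefCell;
use std::ffi::{c_char, CString};
use std::panic::{self, AssertUnwindSafe};

// Hypothetical status codes for illustration only; the real values live in
// include/sgemm_bi.h, which is not part of this sketch.
pub const SGB_OK: i32 = 0;
pub const SGB_ERR_PANIC: i32 = -1;

thread_local! {
    // Per-thread last-error message, read back through sgb_last_error().
    static LAST_ERROR: RefCell<Option<CString>> = RefCell::new(None);
}

fn set_last_error(msg: &str) {
    // Interior NULs would make CString::new fail, so replace them first.
    let c = CString::new(msg.replace('\0', " ")).expect("no interior NUL");
    LAST_ERROR.with(|slot| *slot.borrow_mut() = Some(c));
}

/// Runs one entry-point body and turns a panic into an error code, so the
/// panic never reaches the extern "C" boundary (where it would abort).
fn ffi_guard<F: FnOnce() -> i32>(body: F) -> i32 {
    match panic::catch_unwind(AssertUnwindSafe(body)) {
        Ok(code) => code,
        Err(payload) => {
            // panic!("literal") carries &str; formatted panics carry String.
            let msg = if let Some(s) = payload.downcast_ref::<&str>() {
                (*s).to_string()
            } else if let Some(s) = payload.downcast_ref::<String>() {
                s.clone()
            } else {
                "unknown panic".to_string()
            };
            set_last_error(&msg);
            SGB_ERR_PANIC
        }
    }
}

/// Returns the calling thread's last error, or NULL if there is none. The
/// pointer stays valid until the next error is recorded on the same thread.
pub extern "C" fn sgb_last_error() -> *const c_char {
    LAST_ERROR.with(|slot| match slot.borrow().as_ref() {
        Some(msg) => msg.as_ptr(),
        None => std::ptr::null(),
    })
}

/// Hypothetical entry point used only to exercise the guard.
pub extern "C" fn sgb_demo(fail: bool) -> i32 {
    ffi_guard(|| {
        if fail {
            panic!("demo failure");
        }
        SGB_OK
    })
}
```

The guard only works when the library is built with the default `panic = "unwind"`; under `panic = "abort"` the process terminates before `catch_unwind` can see the panic.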

sgemm_bi-0.1.1/LICENSE-APACHE
ADDED
@@ -0,0 +1,19 @@
+Apache License
+Version 2.0, January 2004
+http://www.apache.org/licenses/
+
+TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION
+
+Copyright 2026 silvermpx
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.

sgemm_bi-0.1.1/LICENSE-MIT
ADDED
@@ -0,0 +1,21 @@
+MIT License
+
+Copyright (c) 2026 silvermpx
+
+Permission is hereby granted, free of charge, to any person obtaining a copy
+of this software and associated documentation files (the "Software"), to deal
+in the Software without restriction, including without limitation the rights
+to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
+copies of the Software, and to permit persons to whom the Software is
+furnished to do so, subject to the following conditions:
+
+The above copyright notice and this permission notice shall be included in all
+copies or substantial portions of the Software.
+
+THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
+SOFTWARE.

sgemm_bi-0.1.1/PKG-INFO
ADDED
@@ -0,0 +1,84 @@
+Metadata-Version: 2.4
+Name: sgemm-bi
+Version: 0.1.1
+Classifier: Programming Language :: Python :: 3
+Classifier: Programming Language :: Rust
+Classifier: Topic :: Scientific/Engineering :: Artificial Intelligence
+Classifier: Environment :: GPU :: NVIDIA CUDA
+Summary: Deterministic, batch-invariant CUDA GEMM for PyTorch: bit-identical training matmuls (forward / dW / dX) in f32, bf16, f16, with an opt-in tensor-core tier.
+Keywords: cuda,gemm,deterministic,pytorch,reproducibility
+Author: silvermpx
+License-Expression: MIT OR Apache-2.0
+Requires-Python: >=3.9
+Description-Content-Type: text/markdown; charset=UTF-8; variant=GFM
+Project-URL: Changelog, https://github.com/silvermpx/sgemm-bi/blob/main/CHANGELOG.md
+Project-URL: Repository, https://github.com/silvermpx/sgemm-bi
+
+# sgemm-bi for PyTorch
+
+Deterministic, batch-invariant CUDA GEMM for PyTorch training:
+bit-identical matmuls (forward / dW / dX) in float32, bfloat16, and
+float16, with an opt-in tensor-core tier that is faster than cuBLAS
+PEDANTIC on transformer-class shapes.
+
+Built on the [sgemm-bi](https://github.com/silvermpx/sgemm-bi) Rust
+engine. No libtorch linkage — tensors cross as raw device pointers, so
+one wheel works with any PyTorch build. Kernels compile once at engine
+creation via NVRTC; the runtime needs only the NVIDIA driver.
+
+**Requirements**: NVIDIA Ampere or newer (sm_80+), PyTorch with CUDA.
+
+## Install
+
+```sh
+pip install maturin
+cd python && maturin build --release
+pip install target/wheels/sgemm_bi-*.whl
+```
+
+## Use
+
+```python
+import torch, sgemm_bi
+
+# Drop-in layer (weight stored [in, out] — GEMM-natural; convert
+# existing layers with Linear.from_torch):
+layer = sgemm_bi.Linear(768, 3072, device="cuda", dtype=torch.bfloat16,
+                        tensor_cores=True)
+y = layer(x)
+y.sum().backward()  # deterministic dW (f32-accumulated) and dX
+
+# Functional form:
+y = sgemm_bi.deterministic_linear(x, weight, bias, tensor_cores=True)
+
+# Low-level engine (raw pointers, explicit stream):
+eng = sgemm_bi.Engine(0)
+eng.forward(y.data_ptr(), x.data_ptr(), w.data_ptr(), None,
+            m, k, n, "bfloat16",
+            torch.cuda.current_stream().cuda_stream, True)
+```
+
+## Determinism contracts
+
+| tier | guarantee |
+|---|---|
+| f32 / typed scalar | bit-identical across runs; bf16/f16 ≡ "upcast → f32 kernel → RNE downcast"; full shape coverage |
+| tensor cores (`tensor_cores=True`) | own deterministic contract (mma.sync, f32 accumulate); bit-identical across runs; forward strictly batch-invariant across all M; falls back to the scalar tier when an output dim is < 64 (shape-only dispatch — still deterministic; two bit-identical tile families, 128×128 and 64×64, cover everything above) |
+
+Bias gradient is a plain f32 column sum done in torch (deterministic
+run-to-run for a fixed shape; not part of the engine's batch-invariance
+contract).
+
+## Examples
+
+[`examples/train_deterministic.py`](examples/train_deterministic.py)
+trains twice from one seed and asserts bit-identical weights, then
+checks batch invariance. Typed stubs ship in the wheel (`py.typed`) —
+IDEs autocomplete and document the native `Engine` class.
+
+## Tests
+
+```sh
+pip install pytest && pytest tests/ -v  # needs a CUDA GPU
+```
+
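
The typed tiers are defined against "upcast → f32 kernel → RNE downcast". Widening bf16 to f32 is exact (a bf16 value is the top 16 bits of an f32), so the only rounding under that contract is the final round-to-nearest-even store. A common bit-level way to perform that rounding is sketched below in Rust; the helper names are hypothetical and the code is not taken from the crate's kernels (`kernels/casts.cu` is not shown in this diff).

```rust
// f32 -> bf16 with round-to-nearest-even, returned as raw bf16 bits.
// Hypothetical helper: a standard bit trick, not the crate's kernel code.
pub fn f32_to_bf16_rne(x: f32) -> u16 {
    let bits = x.to_bits();
    if x.is_nan() {
        // Keep sign and top payload bits, force the quiet bit so a NaN can
        // never round into the infinity pattern.
        return ((bits >> 16) as u16) | 0x0040;
    }
    // Add 0x7FFF, plus 1 when the kept LSB is odd: exact ties go to the even
    // pattern, anything past halfway carries into the kept bits. Finite
    // values above the bf16 range carry into the infinity pattern.
    let lsb = (bits >> 16) & 1;
    ((bits + 0x7FFF + lsb) >> 16) as u16
}

// bf16 -> f32 is exact: just place the 16 bits in the high half.
pub fn bf16_to_f32(h: u16) -> f32 {
    f32::from_bits(u32::from(h) << 16)
}

// Plain truncation, included only to contrast with RNE.
pub fn f32_to_bf16_truncate(x: f32) -> u16 {
    (x.to_bits() >> 16) as u16
}
```

Under the typed contract, a bf16 output element would equal `f32_to_bf16_rne(y)` where `y` is the f32 tier's result on inputs widened with `bf16_to_f32`. For example, the f32 pattern `0x3F80_FFFF` truncates to `0x3F80` but rounds to `0x3F81`, which is why the contract names RNE explicitly.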
sgemm_bi-0.1.1/README.md
ADDED
@@ -0,0 +1,147 @@
+# sgemm-bi
+
+Deterministic, batch-invariant CUDA GEMM engine with a full **training
+triad** — forward, weight gradient, and input gradient — in **f32, bf16,
+and f16**, plus an opt-in **tensor-core tier**.
+
+Existing batch-invariant kernel collections cover inference only and trade
+10–40% throughput for determinism. `sgemm-bi` covers the backward pass
+too, and on tile-friendly shapes the tensor-core tier makes deterministic
+training *faster* than a CUDA-core cuBLAS baseline.
+
+## Guarantees
+
+- **Run-to-run determinism** — fixed reduction order in every kernel: no
+atomics, no data-dependent splits, no vendor-BLAS fallback. Same inputs
+→ bit-identical outputs, including through CUDA Graph replay.
+- **Batch invariance** — within a dispatch bucket, output row 0 is
+bit-identical regardless of the batch dimension M. The tensor-core
+forward is strictly batch-invariant across *all* M.
+- **Typed bit contract** — bf16/f16 results are bit-identical to "upcast
+the inputs to f32, run the f32 tier, round-to-nearest-even downcast the
+output". Accumulation never happens in reduced precision; exactly one
+rounding is applied, at the output store.
+
+## Operations
+
+| op | math | output |
+|---|---|---|
+| `forward` | `Y[M,N] = X[M,K] @ W[K,N] + bias[N]` | typed / f32 |
+| `backward_dw` | `dW[K,N] += X^T[K,M] @ dY[M,N]` | f32 accumulate |
+| `backward_dx` | `dX[M,K] = dY[M,N] @ W^T[N,K]` | typed / f32 |
+
+Each op exists in three tiers: `*_f32` (the reference chain), typed
+(bf16/f16, bit-equal to the f32 tier on upcast inputs), and `*_tc`
+(tensor cores — a separate deterministic contract; mma.sync with f32
+accumulators cannot bit-match a scalar FMA chain, but it is deterministic
+and strictly batch-invariant).
+
+The f32 and typed tiers cover **every** shape: a bucketed dispatcher
+(Big / Slim / narrow / ultra-thin / GEMV / split-K/M/N with fixed-order
+tree reduction) handles the common cases natively and the typed tier
+falls back to "upcast → f32 kernel → downcast" — same bits by contract —
+for the rest. The tensor-core tier covers both output dims ≥ 64 (two
+bit-identical kernel families, 128×128 and 64×64 tiles, routed by shape)
+and returns `Error::Uncovered` otherwise.
+
+## Performance (RTX 6000 Ada, bf16)
+
+Tensor-core tier vs the scalar deterministic tier, GEMM level
+(forward; measured on this crate's bench suite):
+
+| shape (M, K, N) | scalar | tensor cores | speedup |
+|---|---:|---:|---:|
+| 2048, 768, 3072 | 290.9 µs | 83.4 µs | **3.5×** |
+| 4096, 1536, 3072 | 1123.0 µs | 353.5 µs | 3.2× |
+| 2048, 768, 512 | 123.5 µs | 19.5 µs | **6.3×** |
+
+~116 TFLOPS bf16 at M2048 K768 N3072 (~32 % of Ada dense peak). dW and
+dX see similar gains (4.0–5.6× and 3.5–5.1× on the same shapes).
+
+Against cuBLAS (measured in a host application using this engine for
+every training GEMM, same GPU, per optimizer step):
+
+| dtype × tier | vs cuBLAS | model size |
+|---|---|---|
+| f32 scalar vs TF32 | 1.28–1.53× | full f32 precision vs truncated-mantissa TF32 |
+| bf16/f16 scalar vs PEDANTIC | 1.09–1.37× | bit-contract, CUDA cores |
+| bf16 TC vs PEDANTIC | **1.04× (d128) → 0.70× (d1536)** | parity on small models, 16–30 % FASTER from d768 |
+| f16 TC vs PEDANTIC | 1.19× (d128) → 0.76× (d1536) | |
+
+The cost of determinism is zero-to-negative on transformer-class
+shapes; the deterministic bf16 step at d1536 also beats the f32-TF32
+baseline outright.
+
+## Documentation and examples
+
+- [Usage guide](docs/usage-guide.md) — recipes for all three interfaces,
+tier selection, CUDA Graph capture, determinism self-checks.
+- [`examples/deterministic_training.rs`](examples/deterministic_training.rs) —
+full Rust triad with runtime determinism/invariance asserts
+(`cargo run --release --example deterministic_training`).
+- [`examples/capi/smoke.c`](examples/capi/smoke.c) — the C ABI end to end.
+- [`python/examples/train_deterministic.py`](python/examples/train_deterministic.py) —
+bit-identical PyTorch training, twice from one seed.
+- API reference: [docs.rs/sgemm-bi](https://docs.rs/sgemm-bi); the Python
+package ships typed stubs (`.pyi` + `py.typed`), so IDE hover/completion
+documents the native `Engine` too.
+
+## Usage
+
+```rust,ignore
+use sgemm_bi::{Dtype, SgemmBi, TypedPtr};
+
+let context = cudarc::driver::CudaContext::new(0).unwrap();
+let stream = context.new_stream().unwrap();
+let engine = SgemmBi::new(&context, stream.clone()).unwrap();
+
+// y/x/w are CUdeviceptr device allocations on `stream` (bf16 storage).
+engine
+    .forward(
+        TypedPtr::new(y, Dtype::Bf16),
+        TypedPtr::new(x, Dtype::Bf16),
+        TypedPtr::new(w, Dtype::Bf16),
+        Some(bias_f32_ptr),
+        (m, k, n),
+    )
+    .unwrap();
+```
+
+The engine binds to one stream; all calls enqueue and return. For CUDA
+Graph capture, call `presize_upcast_scratch` before capturing so the
+typed fallback never allocates inside (or after) a captured graph.
+
+## Requirements
+
+- NVIDIA GPU, `sm_80`+ for the bf16/f16 and tensor-core tiers (`cp.async`,
+`ldmatrix`, bf16 `mma.sync`); the f32 tier runs on older architectures.
+- CUDA driver + NVRTC at run time. Kernels compile at engine construction
+for the device's native architecture — no toolkit or `nvcc` needed.
+- No cuBLAS: the library never links or calls a vendor BLAS.
+
+## Testing
+
+Contract tests require a CUDA device:
+
+```sh
+cargo test --release -- --test-threads=1
+```
+
+Covered: f32 run-to-run bit identity; the typed bit contract swept across
+~90 dispatch-gate boundary shapes (forward) plus backward shapes;
+per-bucket batch invariance; tensor-core determinism, strict all-M
+invariance, and accuracy vs the f32 reference. Benchmarks are `#[ignore]`d
+(`bench_tc_vs_scalar`).
+
+## Lineage
+
+The Big-tile kernels descend from [siboehm's SGEMM
+warptiling](https://github.com/siboehm/SGEMM_CUDA) work; smem padding
+follows [salykova's sgemm.cu](https://github.com/salykova/sgemm.cu). The
+engine is extracted from the GEMM layer of
+[mamba-rs](https://github.com/silvermpx/mamba-rs), where it powers
+deterministic SSM training.
+
+## License
+
+Dual-licensed under MIT or Apache-2.0.
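
The README's throughput figure can be rechecked from its first performance row. A GEMM of shape (M, K, N) performs M·K·N multiply-adds, conventionally counted as 2·M·K·N floating-point operations, so dividing by the measured kernel time gives effective TFLOPS. The Rust sketch below (hypothetical helper names) reproduces the ~116 TFLOPS figure from the 83.4 µs tensor-core time and recomputes the speedup column as the ratio of the two times.

```rust
// Effective throughput of one (M, K, N) GEMM, counting 2*M*K*N flops
// (one multiply and one add per multiply-accumulate).
pub fn gemm_tflops(m: u64, k: u64, n: u64, micros: f64) -> f64 {
    let flops = 2.0 * (m * k * n) as f64;
    flops / (micros * 1e-6) / 1e12
}

// Speedup of the tensor-core tier: scalar time divided by tensor-core time.
pub fn speedup(scalar_micros: f64, tc_micros: f64) -> f64 {
    scalar_micros / tc_micros
}
```

By the same count, the (4096, 1536, 3072) row comes out near 109 TFLOPS and the (2048, 768, 512) row near 83 TFLOPS; the listed speedups (3.5×, 3.2×, 6.3×) are the time ratios rounded to one decimal.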

sgemm_bi-0.1.1/deny.toml
ADDED
@@ -0,0 +1,21 @@
+[advisories]
+version = 2
+
+[licenses]
+version = 2
+allow = [
+"MIT",
+"Apache-2.0",
+"BSD-2-Clause",
+"BSD-3-Clause",
+"ISC",
+"Unicode-3.0",
+"Zlib",
+]
+
+[bans]
+multiple-versions = "warn"
+
+[sources]
+unknown-registry = "deny"
+unknown-git = "deny"