ejkernel 0.0.21__tar.gz
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- ejkernel-0.0.21/PKG-INFO +711 -0
- ejkernel-0.0.21/README.md +673 -0
- ejkernel-0.0.21/ejkernel/__init__.py +51 -0
- ejkernel-0.0.21/ejkernel/benchmarks.py +772 -0
- ejkernel-0.0.21/ejkernel/callib/__init__.py +43 -0
- ejkernel-0.0.21/ejkernel/callib/_ejit.py +521 -0
- ejkernel-0.0.21/ejkernel/callib/_pallas_call.py +173 -0
- ejkernel-0.0.21/ejkernel/callib/_triton_call.py +891 -0
- ejkernel-0.0.21/ejkernel/callib/_utils.py +243 -0
- ejkernel-0.0.21/ejkernel/kernels/__init__.py +65 -0
- ejkernel-0.0.21/ejkernel/kernels/_cuda/__init__.py +15 -0
- ejkernel-0.0.21/ejkernel/kernels/_pallas/__init__.py +18 -0
- ejkernel-0.0.21/ejkernel/kernels/_pallas/gpu/__init__.py +19 -0
- ejkernel-0.0.21/ejkernel/kernels/_pallas/gpu/ragged_decode_attention/__init__.py +18 -0
- ejkernel-0.0.21/ejkernel/kernels/_pallas/gpu/ragged_decode_attention/_interface.py +81 -0
- ejkernel-0.0.21/ejkernel/kernels/_pallas/gpu/ragged_decode_attention/_pallas_impl_fwd.py +382 -0
- ejkernel-0.0.21/ejkernel/kernels/_pallas/gpu/scaled_dot_product_attention/__init__.py +25 -0
- ejkernel-0.0.21/ejkernel/kernels/_pallas/gpu/scaled_dot_product_attention/_interface.py +131 -0
- ejkernel-0.0.21/ejkernel/kernels/_pallas/tpu/__init__.py +36 -0
- ejkernel-0.0.21/ejkernel/kernels/_pallas/tpu/blocksparse_attention/__init__.py +66 -0
- ejkernel-0.0.21/ejkernel/kernels/_pallas/tpu/blocksparse_attention/_info.py +975 -0
- ejkernel-0.0.21/ejkernel/kernels/_pallas/tpu/blocksparse_attention/_kernel.py +2573 -0
- ejkernel-0.0.21/ejkernel/kernels/_pallas/tpu/blocksparse_attention/_masks.py +546 -0
- ejkernel-0.0.21/ejkernel/kernels/_pallas/tpu/flash_attention/__init__.py +24 -0
- ejkernel-0.0.21/ejkernel/kernels/_pallas/tpu/flash_attention/_interface.py +188 -0
- ejkernel-0.0.21/ejkernel/kernels/_pallas/tpu/flash_attention/_pallas_impl_bwd.py +676 -0
- ejkernel-0.0.21/ejkernel/kernels/_pallas/tpu/flash_attention/_pallas_impl_fwd.py +451 -0
- ejkernel-0.0.21/ejkernel/kernels/_pallas/tpu/flash_attention/_utils.py +373 -0
- ejkernel-0.0.21/ejkernel/kernels/_pallas/tpu/grouped_matmul/__init__.py +18 -0
- ejkernel-0.0.21/ejkernel/kernels/_pallas/tpu/grouped_matmul/_interface.py +250 -0
- ejkernel-0.0.21/ejkernel/kernels/_pallas/tpu/grouped_matmul/_pallas_impl.py +979 -0
- ejkernel-0.0.21/ejkernel/kernels/_pallas/tpu/grouped_matmul/_utils.py +191 -0
- ejkernel-0.0.21/ejkernel/kernels/_pallas/tpu/grouped_matmulv2/__init__.py +18 -0
- ejkernel-0.0.21/ejkernel/kernels/_pallas/tpu/grouped_matmulv2/_interface.py +250 -0
- ejkernel-0.0.21/ejkernel/kernels/_pallas/tpu/grouped_matmulv2/_pallas_impl.py +545 -0
- ejkernel-0.0.21/ejkernel/kernels/_pallas/tpu/page_attention/__init__.py +18 -0
- ejkernel-0.0.21/ejkernel/kernels/_pallas/tpu/page_attention/_interface.py +253 -0
- ejkernel-0.0.21/ejkernel/kernels/_pallas/tpu/page_attention/_pallas_impl_fwd.py +290 -0
- ejkernel-0.0.21/ejkernel/kernels/_pallas/tpu/ragged_decode_attention/__init__.py +18 -0
- ejkernel-0.0.21/ejkernel/kernels/_pallas/tpu/ragged_decode_attention/_interface.py +74 -0
- ejkernel-0.0.21/ejkernel/kernels/_pallas/tpu/ragged_decode_attention/_pallas_impl_fwd.py +247 -0
- ejkernel-0.0.21/ejkernel/kernels/_pallas/tpu/ragged_page_attention_v2/__init__.py +20 -0
- ejkernel-0.0.21/ejkernel/kernels/_pallas/tpu/ragged_page_attention_v2/_interface.py +257 -0
- ejkernel-0.0.21/ejkernel/kernels/_pallas/tpu/ragged_page_attention_v2/_pallas_impl_fwd.py +602 -0
- ejkernel-0.0.21/ejkernel/kernels/_pallas/tpu/ragged_page_attention_v2/_utils.py +437 -0
- ejkernel-0.0.21/ejkernel/kernels/_pallas/tpu/ragged_page_attention_v3/__init__.py +23 -0
- ejkernel-0.0.21/ejkernel/kernels/_pallas/tpu/ragged_page_attention_v3/_interface.py +120 -0
- ejkernel-0.0.21/ejkernel/kernels/_pallas/tpu/ragged_page_attention_v3/_pallas_impl_fwd.py +1300 -0
- ejkernel-0.0.21/ejkernel/kernels/_pallas/tpu/ragged_page_attention_v3/_pallas_impl_fwd_h64.py +1250 -0
- ejkernel-0.0.21/ejkernel/kernels/_pallas/tpu/ragged_page_attention_v3/_utils.py +4545 -0
- ejkernel-0.0.21/ejkernel/kernels/_pallas/tpu/ring_attention/__init__.py +24 -0
- ejkernel-0.0.21/ejkernel/kernels/_pallas/tpu/ring_attention/_interface.py +225 -0
- ejkernel-0.0.21/ejkernel/kernels/_pallas/tpu/ring_attention/_pallas_impl_bwd.py +1003 -0
- ejkernel-0.0.21/ejkernel/kernels/_pallas/tpu/ring_attention/_pallas_impl_fwd.py +769 -0
- ejkernel-0.0.21/ejkernel/kernels/_pallas/tpu/ring_attention/_utils.py +192 -0
- ejkernel-0.0.21/ejkernel/kernels/_registry.py +377 -0
- ejkernel-0.0.21/ejkernel/kernels/_triton/__init__.py +59 -0
- ejkernel-0.0.21/ejkernel/kernels/_triton/blocksparse_attention/__init__.py +33 -0
- ejkernel-0.0.21/ejkernel/kernels/_triton/blocksparse_attention/_interface.py +462 -0
- ejkernel-0.0.21/ejkernel/kernels/_triton/blocksparse_attention/_mask.py +530 -0
- ejkernel-0.0.21/ejkernel/kernels/_triton/blocksparse_attention/_triton_impl_bwd.py +1380 -0
- ejkernel-0.0.21/ejkernel/kernels/_triton/blocksparse_attention/_triton_impl_fwd.py +695 -0
- ejkernel-0.0.21/ejkernel/kernels/_triton/blocksparse_attention/_utilities.py +437 -0
- ejkernel-0.0.21/ejkernel/kernels/_triton/flash_attention/__init__.py +24 -0
- ejkernel-0.0.21/ejkernel/kernels/_triton/flash_attention/_interface.py +380 -0
- ejkernel-0.0.21/ejkernel/kernels/_triton/flash_attention/_triton_impl_bwd.py +1494 -0
- ejkernel-0.0.21/ejkernel/kernels/_triton/flash_attention/_triton_impl_fwd.py +827 -0
- ejkernel-0.0.21/ejkernel/kernels/_triton/flash_attention/_utilities.py +358 -0
- ejkernel-0.0.21/ejkernel/kernels/_triton/flash_mla/__init__.py +38 -0
- ejkernel-0.0.21/ejkernel/kernels/_triton/flash_mla/_interface.py +100 -0
- ejkernel-0.0.21/ejkernel/kernels/_triton/flash_mla/_triton_impl_bwd.py +15 -0
- ejkernel-0.0.21/ejkernel/kernels/_triton/flash_mla/_triton_impl_fwd.py +180 -0
- ejkernel-0.0.21/ejkernel/kernels/_triton/flash_mla/_utilities.py +15 -0
- ejkernel-0.0.21/ejkernel/kernels/_triton/gla/__init__.py +18 -0
- ejkernel-0.0.21/ejkernel/kernels/_triton/gla/_interface.py +136 -0
- ejkernel-0.0.21/ejkernel/kernels/_triton/lightning_attn/__init__.py +18 -0
- ejkernel-0.0.21/ejkernel/kernels/_triton/lightning_attn/_interface.py +143 -0
- ejkernel-0.0.21/ejkernel/kernels/_triton/mean_pooling/__init__.py +30 -0
- ejkernel-0.0.21/ejkernel/kernels/_triton/mean_pooling/_interface.py +163 -0
- ejkernel-0.0.21/ejkernel/kernels/_triton/mean_pooling/_triton_impl_bwd.py +130 -0
- ejkernel-0.0.21/ejkernel/kernels/_triton/mean_pooling/_triton_impl_fwd.py +126 -0
- ejkernel-0.0.21/ejkernel/kernels/_triton/native_sparse_attention/__init__.py +32 -0
- ejkernel-0.0.21/ejkernel/kernels/_triton/native_sparse_attention/_compression.py +753 -0
- ejkernel-0.0.21/ejkernel/kernels/_triton/native_sparse_attention/_interface.py +403 -0
- ejkernel-0.0.21/ejkernel/kernels/_triton/native_sparse_attention/_triton_impl_bwd.py +394 -0
- ejkernel-0.0.21/ejkernel/kernels/_triton/native_sparse_attention/_triton_impl_fwd.py +428 -0
- ejkernel-0.0.21/ejkernel/kernels/_triton/native_sparse_attention/_utilities.py +87 -0
- ejkernel-0.0.21/ejkernel/kernels/_triton/page_attention/__init__.py +18 -0
- ejkernel-0.0.21/ejkernel/kernels/_triton/page_attention/_interface.py +369 -0
- ejkernel-0.0.21/ejkernel/kernels/_triton/page_attention/_triton_impl_fwd.py +257 -0
- ejkernel-0.0.21/ejkernel/kernels/_triton/ragged_decode_attention/__init__.py +18 -0
- ejkernel-0.0.21/ejkernel/kernels/_triton/ragged_decode_attention/_interface.py +75 -0
- ejkernel-0.0.21/ejkernel/kernels/_triton/ragged_decode_attention/_triton_impl_fwd.py +302 -0
- ejkernel-0.0.21/ejkernel/kernels/_triton/ragged_page_attention_v2/__init__.py +18 -0
- ejkernel-0.0.21/ejkernel/kernels/_triton/ragged_page_attention_v2/_interface.py +226 -0
- ejkernel-0.0.21/ejkernel/kernels/_triton/ragged_page_attention_v2/_triton_impl_fwd.py +613 -0
- ejkernel-0.0.21/ejkernel/kernels/_triton/ragged_page_attention_v3/_triton_impl_fwd.py +222 -0
- ejkernel-0.0.21/ejkernel/kernels/_triton/recurrent/__init__.py +30 -0
- ejkernel-0.0.21/ejkernel/kernels/_triton/recurrent/_interface.py +291 -0
- ejkernel-0.0.21/ejkernel/kernels/_triton/recurrent/_triton_impl_bwd.py +354 -0
- ejkernel-0.0.21/ejkernel/kernels/_triton/recurrent/_triton_impl_fwd.py +216 -0
- ejkernel-0.0.21/ejkernel/kernels/_xla/__init__.py +46 -0
- ejkernel-0.0.21/ejkernel/kernels/_xla/attention/__init__.py +18 -0
- ejkernel-0.0.21/ejkernel/kernels/_xla/attention/_interface.py +193 -0
- ejkernel-0.0.21/ejkernel/kernels/_xla/flash_attention/__init__.py +24 -0
- ejkernel-0.0.21/ejkernel/kernels/_xla/flash_attention/_interface.py +376 -0
- ejkernel-0.0.21/ejkernel/kernels/_xla/flash_attention/_xla_impl_bwd.py +85 -0
- ejkernel-0.0.21/ejkernel/kernels/_xla/flash_attention/_xla_impl_fwd.py +474 -0
- ejkernel-0.0.21/ejkernel/kernels/_xla/gla/__init__.py +18 -0
- ejkernel-0.0.21/ejkernel/kernels/_xla/gla/_interface.py +123 -0
- ejkernel-0.0.21/ejkernel/kernels/_xla/grouped_matmul/__init__.py +18 -0
- ejkernel-0.0.21/ejkernel/kernels/_xla/grouped_matmul/_interface.py +140 -0
- ejkernel-0.0.21/ejkernel/kernels/_xla/lightning_attn/__init__.py +18 -0
- ejkernel-0.0.21/ejkernel/kernels/_xla/lightning_attn/_interface.py +129 -0
- ejkernel-0.0.21/ejkernel/kernels/_xla/mean_pooling/__init__.py +18 -0
- ejkernel-0.0.21/ejkernel/kernels/_xla/mean_pooling/_interface.py +171 -0
- ejkernel-0.0.21/ejkernel/kernels/_xla/native_sparse_attention/__init__.py +32 -0
- ejkernel-0.0.21/ejkernel/kernels/_xla/native_sparse_attention/_interface.py +435 -0
- ejkernel-0.0.21/ejkernel/kernels/_xla/native_sparse_attention/_xla_impl_bwd.py +138 -0
- ejkernel-0.0.21/ejkernel/kernels/_xla/native_sparse_attention/_xla_impl_fwd.py +86 -0
- ejkernel-0.0.21/ejkernel/kernels/_xla/page_attention/__init__.py +18 -0
- ejkernel-0.0.21/ejkernel/kernels/_xla/page_attention/_interface.py +123 -0
- ejkernel-0.0.21/ejkernel/kernels/_xla/page_attention/_xla_impl_fwd.py +97 -0
- ejkernel-0.0.21/ejkernel/kernels/_xla/ragged_decode_attention/__init__.py +18 -0
- ejkernel-0.0.21/ejkernel/kernels/_xla/ragged_decode_attention/_interface.py +119 -0
- ejkernel-0.0.21/ejkernel/kernels/_xla/ragged_decode_attention/_xla_impl_fwd.py +468 -0
- ejkernel-0.0.21/ejkernel/kernels/_xla/ragged_page_attention_v2/__init__.py +18 -0
- ejkernel-0.0.21/ejkernel/kernels/_xla/ragged_page_attention_v2/_interface.py +109 -0
- ejkernel-0.0.21/ejkernel/kernels/_xla/ragged_page_attention_v2/_xla_impl_fwd.py +191 -0
- ejkernel-0.0.21/ejkernel/kernels/_xla/ragged_page_attention_v3/__init__.py +18 -0
- ejkernel-0.0.21/ejkernel/kernels/_xla/ragged_page_attention_v3/_interface.py +98 -0
- ejkernel-0.0.21/ejkernel/kernels/_xla/ragged_page_attention_v3/_kernel.py +427 -0
- ejkernel-0.0.21/ejkernel/kernels/_xla/recurrent/__init__.py +30 -0
- ejkernel-0.0.21/ejkernel/kernels/_xla/recurrent/_interface.py +201 -0
- ejkernel-0.0.21/ejkernel/kernels/_xla/recurrent/_xla_impl_bwd.py +163 -0
- ejkernel-0.0.21/ejkernel/kernels/_xla/recurrent/_xla_impl_fwd.py +242 -0
- ejkernel-0.0.21/ejkernel/kernels/_xla/ring_attention/__init__.py +24 -0
- ejkernel-0.0.21/ejkernel/kernels/_xla/ring_attention/_interface.py +212 -0
- ejkernel-0.0.21/ejkernel/kernels/_xla/ring_attention/_utils.py +163 -0
- ejkernel-0.0.21/ejkernel/kernels/_xla/ring_attention/_xla_impl_bwd.py +354 -0
- ejkernel-0.0.21/ejkernel/kernels/_xla/ring_attention/_xla_impl_fwd.py +388 -0
- ejkernel-0.0.21/ejkernel/kernels/_xla/scaled_dot_product_attention/__init__.py +24 -0
- ejkernel-0.0.21/ejkernel/kernels/_xla/scaled_dot_product_attention/_interface.py +106 -0
- ejkernel-0.0.21/ejkernel/loggings.py +482 -0
- ejkernel-0.0.21/ejkernel/modules/__init__.py +144 -0
- ejkernel-0.0.21/ejkernel/modules/base.py +213 -0
- ejkernel-0.0.21/ejkernel/modules/operations/__init__.py +148 -0
- ejkernel-0.0.21/ejkernel/modules/operations/attention.py +286 -0
- ejkernel-0.0.21/ejkernel/modules/operations/blocksparse_attention.py +937 -0
- ejkernel-0.0.21/ejkernel/modules/operations/configs.py +354 -0
- ejkernel-0.0.21/ejkernel/modules/operations/flash_attention.py +886 -0
- ejkernel-0.0.21/ejkernel/modules/operations/gated_linear_attention.py +301 -0
- ejkernel-0.0.21/ejkernel/modules/operations/grouped_matmul.py +426 -0
- ejkernel-0.0.21/ejkernel/modules/operations/lightning_attention.py +306 -0
- ejkernel-0.0.21/ejkernel/modules/operations/multi_head_latent_attention.py +290 -0
- ejkernel-0.0.21/ejkernel/modules/operations/native_sparse_attention.py +374 -0
- ejkernel-0.0.21/ejkernel/modules/operations/page_attention.py +444 -0
- ejkernel-0.0.21/ejkernel/modules/operations/pooling.py +244 -0
- ejkernel-0.0.21/ejkernel/modules/operations/ragged_decode_attention.py +588 -0
- ejkernel-0.0.21/ejkernel/modules/operations/ragged_page_attention_v2.py +702 -0
- ejkernel-0.0.21/ejkernel/modules/operations/ragged_page_attention_v3.py +1091 -0
- ejkernel-0.0.21/ejkernel/modules/operations/recurrent.py +328 -0
- ejkernel-0.0.21/ejkernel/modules/operations/ring_attention.py +617 -0
- ejkernel-0.0.21/ejkernel/modules/operations/scaled_dot_product_attention.py +417 -0
- ejkernel-0.0.21/ejkernel/ops/__init__.py +150 -0
- ejkernel-0.0.21/ejkernel/ops/config/__init__.py +43 -0
- ejkernel-0.0.21/ejkernel/ops/config/cache.py +187 -0
- ejkernel-0.0.21/ejkernel/ops/config/persistent.py +188 -0
- ejkernel-0.0.21/ejkernel/ops/config/selection.py +585 -0
- ejkernel-0.0.21/ejkernel/ops/core/__init__.py +36 -0
- ejkernel-0.0.21/ejkernel/ops/core/kernel.py +671 -0
- ejkernel-0.0.21/ejkernel/ops/core/types.py +50 -0
- ejkernel-0.0.21/ejkernel/ops/execution/__init__.py +67 -0
- ejkernel-0.0.21/ejkernel/ops/execution/batch.py +191 -0
- ejkernel-0.0.21/ejkernel/ops/execution/executor.py +514 -0
- ejkernel-0.0.21/ejkernel/ops/execution/offline.py +80 -0
- ejkernel-0.0.21/ejkernel/ops/execution/profiler.py +475 -0
- ejkernel-0.0.21/ejkernel/ops/execution/tuning.py +1194 -0
- ejkernel-0.0.21/ejkernel/ops/registry.py +93 -0
- ejkernel-0.0.21/ejkernel/ops/utils/__init__.py +72 -0
- ejkernel-0.0.21/ejkernel/ops/utils/datacarrier.py +126 -0
- ejkernel-0.0.21/ejkernel/ops/utils/fingerprint.py +342 -0
- ejkernel-0.0.21/ejkernel/ops/utils/meta.py +156 -0
- ejkernel-0.0.21/ejkernel/ops/utils/serialize.py +97 -0
- ejkernel-0.0.21/ejkernel/types/__init__.py +18 -0
- ejkernel-0.0.21/ejkernel/types/mask.py +2248 -0
- ejkernel-0.0.21/ejkernel/utils.py +837 -0
- ejkernel-0.0.21/ejkernel/xla_utils/__init__.py +45 -0
- ejkernel-0.0.21/ejkernel/xla_utils/cumsum.py +297 -0
- ejkernel-0.0.21/ejkernel/xla_utils/shardings.py +102 -0
- ejkernel-0.0.21/ejkernel/xla_utils/utils.py +184 -0
- ejkernel-0.0.21/pyproject.toml +126 -0
ejkernel-0.0.21/PKG-INFO
ADDED
|
@@ -0,0 +1,711 @@
|
|
|
1
|
+
Metadata-Version: 2.3
|
|
2
|
+
Name: ejkernel
|
|
3
|
+
Version: 0.0.21
|
|
4
|
+
Summary: Accelerate and optimize performance with streamlined training and serving options in JAX.
|
|
5
|
+
Keywords: Deep Learning,Machine Learning,JAX,CUDA,XLA,Triton,Pallas
|
|
6
|
+
Author: Erfan Zare Chavoshi
|
|
7
|
+
Author-email: Erfan Zare Chavoshi <Erfanzare810@gmail.com>
|
|
8
|
+
License: Apache-2.0
|
|
9
|
+
Classifier: Development Status :: 3 - Alpha
|
|
10
|
+
Classifier: Intended Audience :: Developers
|
|
11
|
+
Classifier: Topic :: Scientific/Engineering :: Artificial Intelligence
|
|
12
|
+
Classifier: License :: OSI Approved :: Apache Software License
|
|
13
|
+
Classifier: Programming Language :: Python :: 3.10
|
|
14
|
+
Classifier: Programming Language :: Python :: 3.11
|
|
15
|
+
Classifier: Programming Language :: Python :: 3.12
|
|
16
|
+
Classifier: Programming Language :: Python :: 3.13
|
|
17
|
+
Requires-Dist: beartype>=0.22.2
|
|
18
|
+
Requires-Dist: chex>=0.1.91
|
|
19
|
+
Requires-Dist: einops>=0.8.1
|
|
20
|
+
Requires-Dist: jax>=0.8.0
|
|
21
|
+
Requires-Dist: jaxlib>=0.8.0
|
|
22
|
+
Requires-Dist: jaxtyping>=0.3.2
|
|
23
|
+
Requires-Dist: pydantic>=2.11.10
|
|
24
|
+
Requires-Dist: triton==3.4.0
|
|
25
|
+
Requires-Dist: jax[cuda12]>=0.8.0 ; extra == 'gpu'
|
|
26
|
+
Requires-Dist: xprof>=2.20.6 ; extra == 'profile'
|
|
27
|
+
Requires-Dist: tb-nightly>=2.21.0a20250820 ; extra == 'profile'
|
|
28
|
+
Requires-Dist: xprof-nightly>=2.21.6a20250820 ; extra == 'profile'
|
|
29
|
+
Requires-Dist: jax[tpu]>=0.8.0 ; extra == 'tpu'
|
|
30
|
+
Requires-Python: >=3.11, <3.14
|
|
31
|
+
Project-URL: Documentation, https://ejkernel.readthedocs.io/en/latest/
|
|
32
|
+
Project-URL: Homepage, https://github.com/erfanzar/ejkernel
|
|
33
|
+
Project-URL: Repository, https://github.com/erfanzar/ejkernel
|
|
34
|
+
Provides-Extra: gpu
|
|
35
|
+
Provides-Extra: profile
|
|
36
|
+
Provides-Extra: tpu
|
|
37
|
+
Description-Content-Type: text/markdown
|
|
38
|
+
|
|
39
|
+
# ejKernel: High-Performance JAX Kernels for Deep Learning
|
|
40
|
+
|
|
41
|
+
> *"The best optimization is the one you don't have to think about."*
|
|
42
|
+
|
|
43
|
+
[License: Apache 2.0](https://opensource.org/licenses/Apache-2.0)
|
|
44
|
+
[Python 3.11–3.13](https://www.python.org/downloads/)
|
|
45
|
+
[JAX >= 0.8.0](https://github.com/google/jax)
|
|
46
|
+
[Documentation](https://ejkernel.readthedocs.io/en/latest/)
|
|
47
|
+
|
|
48
|
+
ejKernel is a production-grade kernel library for JAX that provides highly optimized implementations of deep learning operations with automatic multi-backend support. The library features a sophisticated configuration management system with autotuning, comprehensive type safety, and seamless execution across GPUs, TPUs, and CPUs.
|
|
49
|
+
|
|
50
|
+
## Table of Contents
|
|
51
|
+
|
|
52
|
+
- [Key Features](#key-features)
|
|
53
|
+
- [Installation](#installation)
|
|
54
|
+
- [Quick Start](#quick-start)
|
|
55
|
+
- [Architecture Overview](#architecture-overview)
|
|
56
|
+
- [Supported Operations](#supported-operations)
|
|
57
|
+
- [Advanced Usage](#advanced-usage)
|
|
58
|
+
- [Development](#development)
|
|
59
|
+
- [Testing](#testing)
|
|
60
|
+
- [Contributing](#contributing)
|
|
61
|
+
- [Citation](#citation)
|
|
62
|
+
- [License](#license)
|
|
63
|
+
|
|
64
|
+
## Key Features
|
|
65
|
+
|
|
66
|
+
### Intelligent Kernel Management
|
|
67
|
+
|
|
68
|
+
- **7-Tier Configuration System**: Override → Overlay → Memory Cache → Persistent Cache → Autotune → Heuristics → Error
|
|
69
|
+
- **Automatic Platform Detection**: Seamlessly selects optimal implementation based on hardware
|
|
70
|
+
- **Priority-Based Registry**: Multi-backend support with intelligent fallback mechanisms
|
|
71
|
+
- **Device Fingerprinting**: Hardware-specific configuration caching for optimal performance
|
|
72
|
+
|
|
73
|
+
### State-of-the-Art Operations
|
|
74
|
+
|
|
75
|
+
- **15+ Attention Mechanisms**: Flash Attention v2, Ring Attention, Page Attention, Block Sparse, GLA, Lightning, Ragged Page Attention, and more
|
|
76
|
+
- **Memory Efficiency**: Custom VJP implementations with O(N) memory complexity for attention
|
|
77
|
+
- **Distributed Support**: Full shard_map integration for model and data parallelism
|
|
78
|
+
- **Mixed Precision**: Comprehensive dtype support with automatic gradient conversion
|
|
79
|
+
|
|
80
|
+
### Production-Ready Infrastructure
|
|
81
|
+
|
|
82
|
+
- **Type Safety**: Full jaxtyping annotations with runtime validation via beartype
|
|
83
|
+
- **Comprehensive Testing**: Cross-backend validation, performance benchmarks, integration tests
|
|
84
|
+
- **Atomic Persistence**: Thread-safe configuration storage with automatic optimization
|
|
85
|
+
- **Profiling Integration**: Built-in support for JAX profiling and performance monitoring
|
|
86
|
+
|
|
87
|
+
## Installation
|
|
88
|
+
|
|
89
|
+
### Basic Installation
|
|
90
|
+
|
|
91
|
+
```bash
|
|
92
|
+
pip install ejkernel
|
|
93
|
+
```
|
|
94
|
+
|
|
95
|
+
### Platform-Specific Installation
|
|
96
|
+
|
|
97
|
+
```bash
|
|
98
|
+
# GPU Support (CUDA 12)
|
|
99
|
+
pip install "ejkernel[gpu]"
|
|
100
|
+
|
|
101
|
+
# TPU Support
|
|
102
|
+
pip install "ejkernel[tpu]"
|
|
103
|
+
|
|
104
|
+
# Development Installation
|
|
105
|
+
git clone https://github.com/erfanzar/ejkernel.git
|
|
106
|
+
cd ejkernel
|
|
107
|
+
pip install -e ".[dev]"
|
|
108
|
+
```
|
|
109
|
+
|
|
110
|
+
### Dependencies
|
|
111
|
+
|
|
112
|
+
- Python 3.11-3.13
|
|
113
|
+
- JAX >= 0.8.0
|
|
114
|
+
- Triton == 3.4.0 (always installed; used by the GPU kernels)
|
|
115
|
+
- jaxtyping >= 0.3.2
|
|
116
|
+
- beartype >= 0.22.2
|
|
117
|
+
|
|
118
|
+
## Quick Start
|
|
119
|
+
|
|
120
|
+
### Simple API with Automatic Optimization
|
|
121
|
+
|
|
122
|
+
```python
|
|
123
|
+
import jax.numpy as jnp
|
|
124
|
+
from ejkernel.modules import flash_attention
|
|
125
|
+
|
|
126
|
+
# Basic usage - automatic configuration selection
|
|
127
|
+
output = flash_attention(
|
|
128
|
+
query, key, value,
|
|
129
|
+
causal=True,
|
|
130
|
+
dropout_prob=0.1
|
|
131
|
+
)
|
|
132
|
+
|
|
133
|
+
# With advanced features
|
|
134
|
+
output = flash_attention(
|
|
135
|
+
query, key, value,
|
|
136
|
+
causal=True,
|
|
137
|
+
sliding_window=128, # Local attention window
|
|
138
|
+
logits_soft_cap=30.0, # Gemma-2 style soft capping
|
|
139
|
+
attention_mask=mask, # Custom attention pattern
|
|
140
|
+
)
|
|
141
|
+
```
|
|
142
|
+
|
|
143
|
+
### Custom Configuration
|
|
144
|
+
|
|
145
|
+
```python
|
|
146
|
+
from ejkernel.modules import FlashAttentionConfig
|
|
147
|
+
from ejkernel.ops.utils.datacarrier import FwdParams, BwdParams
|
|
148
|
+
|
|
149
|
+
# Create optimized configuration
|
|
150
|
+
config = FlashAttentionConfig(
|
|
151
|
+
fwd_params=FwdParams(
|
|
152
|
+
q_blocksize=256,
|
|
153
|
+
kv_blocksize=256,
|
|
154
|
+
num_warps=8,
|
|
155
|
+
num_stages=2
|
|
156
|
+
),
|
|
157
|
+
bwd_params=BwdParams(
|
|
158
|
+
q_blocksize=128,
|
|
159
|
+
kv_blocksize=128,
|
|
160
|
+
num_warps=4
|
|
161
|
+
),
|
|
162
|
+
platform="triton", # Force specific backend
|
|
163
|
+
backend="gpu"
|
|
164
|
+
)
|
|
165
|
+
|
|
166
|
+
output = flash_attention(query, key, value, cfg=config)
|
|
167
|
+
```
|
|
168
|
+
|
|
169
|
+
### Direct Kernel Registry Access
|
|
170
|
+
|
|
171
|
+
```python
|
|
172
|
+
from ejkernel import kernel_registry, Platform, Backend
|
|
173
|
+
|
|
174
|
+
# Get specific implementation
|
|
175
|
+
kernel = kernel_registry.get(
|
|
176
|
+
algorithm="flash_attention",
|
|
177
|
+
platform=Platform.TRITON,
|
|
178
|
+
backend=Backend.GPU
|
|
179
|
+
)
|
|
180
|
+
|
|
181
|
+
# Direct execution
|
|
182
|
+
output = kernel(query, key, value, causal=True)
|
|
183
|
+
```
|
|
184
|
+
|
|
185
|
+
### Distributed Execution
|
|
186
|
+
|
|
187
|
+
```python
|
|
188
|
+
import jax
|
|
189
|
+
from jax.sharding import Mesh, PartitionSpec as P
|
|
190
|
+
from ejkernel.modules import flash_attention
|
|
191
|
+
|
|
192
|
+
# Setup mesh for distributed execution
|
|
193
|
+
devices = jax.devices()
|
|
194
|
+
mesh = Mesh(devices, axis_names=("data",))  # 1-D mesh: jax.devices() is a flat list
|
|
195
|
+
|
|
196
|
+
# Run distributed attention
|
|
197
|
+
output = flash_attention(
|
|
198
|
+
query, key, value,
|
|
199
|
+
causal=True,
|
|
200
|
+
mesh=mesh,
|
|
201
|
+
in_specs=(P("data", None), P("data", None), P("data", None)),
|
|
202
|
+
out_specs=P("data", None)
|
|
203
|
+
)
|
|
204
|
+
```
|
|
205
|
+
|
|
206
|
+
## Architecture Overview
|
|
207
|
+
|
|
208
|
+
### System Design
|
|
209
|
+
|
|
210
|
+
ejKernel employs a sophisticated layered architecture that separates concerns while maintaining high performance:
|
|
211
|
+
|
|
212
|
+
```md
|
|
213
|
+
┌─────────────────────────────────────────────────────┐
|
|
214
|
+
│ Public API (modules/) │
|
|
215
|
+
│ Simple functions with sensible defaults │
|
|
216
|
+
├─────────────────────────────────────────────────────┤
|
|
217
|
+
│ Operations Layer (ops/) │
|
|
218
|
+
│ Configuration management, autotuning, caching │
|
|
219
|
+
├─────────────────────────────────────────────────────┤
|
|
220
|
+
│ Kernel Registry (kernels/) │
|
|
221
|
+
│ Platform routing, signature validation │
|
|
222
|
+
├─────────────────────────────────────────────────────┤
|
|
223
|
+
│ Backend Implementations (kernels/_*) │
|
|
224
|
+
│ Triton, Pallas, XLA, CUDA kernels │
|
|
225
|
+
└─────────────────────────────────────────────────────┘
|
|
226
|
+
```
|
|
227
|
+
|
|
228
|
+
### Project Structure
|
|
229
|
+
|
|
230
|
+
```md
|
|
231
|
+
ejkernel/
|
|
232
|
+
├── kernels/ # Low-level kernel implementations
|
|
233
|
+
│ ├── _triton/ # Triton kernels (GPU)
|
|
234
|
+
│ │ ├── flash_attention/
|
|
235
|
+
│ │ ├── page_attention/
|
|
236
|
+
│ │ ├── ragged_page_attention_v2/
|
|
237
|
+
│ │ ├── gla/
|
|
238
|
+
│ │ ├── lightning_attn/
|
|
239
|
+
│ │ ├── mean_pooling/
|
|
240
|
+
│ │ ├── native_sparse_attention/
|
|
241
|
+
│ │ ├── recurrent/
|
|
242
|
+
│ │ └── blocksparse_attention/
|
|
243
|
+
│ ├── _pallas/
|
|
244
|
+
│ │ ├── tpu/ # TPU-specific implementations
|
|
245
|
+
│ │ │ ├── flash_attention/
|
|
246
|
+
│ │ │ ├── ring_attention/
|
|
247
|
+
│ │ │ ├── page_attention/
|
|
248
|
+
│ │ │ ├── ragged_page_attention_v2/
|
|
249
|
+
│ │ │ ├── ragged_page_attention_v3/
|
|
250
|
+
│ │ │ ├── blocksparse_attention/
|
|
251
|
+
│ │ │ ├── grouped_matmul/
|
|
252
|
+
│ │ │ └── ragged_decode_attention/
|
|
253
|
+
│ │ └── gpu/ # GPU Pallas implementations
|
|
254
|
+
│ ├── _xla/ # XLA implementations (universal)
|
|
255
|
+
│ │ ├── attention/
|
|
256
|
+
│ │ ├── flash_attention/
|
|
257
|
+
│ │ ├── gla/
|
|
258
|
+
│ │ ├── grouped_matmul/
|
|
259
|
+
│ │ ├── lightning_attn/
|
|
260
|
+
│ │ ├── mean_pooling/
|
|
261
|
+
│ │ ├── native_sparse_attention/
|
|
262
|
+
│ │ ├── page_attention/
|
|
263
|
+
│ │ ├── ragged_decode_attention/
|
|
264
|
+
│ │ ├── ragged_page_attention_v2/
|
|
265
|
+
│ │ ├── ragged_page_attention_v3/
|
|
266
|
+
│ │ ├── recurrent/
|
|
267
|
+
│ │ ├── ring_attention/
|
|
268
|
+
│ │ └── scaled_dot_product_attention/
|
|
269
|
+
│ ├── _cuda/ # CUDA implementations (dev)
|
|
270
|
+
│ └── _registry.py # Kernel registry system
|
|
271
|
+
│
|
|
272
|
+
├── modules/ # High-level API
|
|
273
|
+
│ └── operations/
|
|
274
|
+
│ ├── flash_attention.py
|
|
275
|
+
│ ├── ring_attention.py
|
|
276
|
+
│ ├── page_attention.py
|
|
277
|
+
│ ├── ragged_page_attention_v2.py
|
|
278
|
+
│ ├── ragged_page_attention_v3.py
|
|
279
|
+
│ ├── blocksparse_attention.py
|
|
280
|
+
│ ├── gated_linear_attention.py
|
|
281
|
+
│ ├── lightning_attention.py
|
|
282
|
+
│ ├── native_sparse_attention.py
|
|
283
|
+
│ ├── recurrent.py
|
|
284
|
+
│ ├── grouped_matmul.py
|
|
285
|
+
│ ├── pooling.py
|
|
286
|
+
│ ├── attention.py
|
|
287
|
+
│ ├── multi_head_latent_attention.py
|
|
288
|
+
│ ├── ragged_decode_attention.py
|
|
289
|
+
│ ├── scaled_dot_product_attention.py
|
|
290
|
+
│ └── configs.py
|
|
291
|
+
│
|
|
292
|
+
├── ops/ # Configuration & execution framework
|
|
293
|
+
│ ├── config/ # Configuration management
|
|
294
|
+
│ │ ├── cache.py # In-memory config cache
|
|
295
|
+
│ │ ├── persistent.py # Disk-based persistence
|
|
296
|
+
│ │ └── selection.py # Config selection chain
|
|
297
|
+
│ ├── core/ # Base kernel class
|
|
298
|
+
│ ├── execution/ # Execution orchestration
|
|
299
|
+
│ │ └── tuning.py # Autotuning framework
|
|
300
|
+
│ ├── registry.py # Operation invocation tracking
|
|
301
|
+
│ └── utils/ # Utilities (fingerprinting, etc)
|
|
302
|
+
│
|
|
303
|
+
├── xla_utils/ # XLA-specific utilities
|
|
304
|
+
│ ├── cumsum.py # Cumulative sum operations
|
|
305
|
+
│ ├── shardings.py # Sharding utilities
|
|
306
|
+
│ └── utils.py # Sequence length utilities
|
|
307
|
+
│
|
|
308
|
+
├── types/ # Type definitions
|
|
309
|
+
│ └── mask.py # MaskInfo for attention masking
|
|
310
|
+
│
|
|
311
|
+
├── callib/ # Calling library
|
|
312
|
+
│ ├── _ejit.py # Enhanced JIT
|
|
313
|
+
│ ├── _triton_call.py # Triton kernel calling
|
|
314
|
+
│ └── _pallas_call.py # Pallas kernel calling
|
|
315
|
+
│
|
|
316
|
+
└── utils.py # General utilities
|
|
317
|
+
```
|
|
318
|
+
|
|
319
|
+
### Core Components
|
|
320
|
+
|
|
321
|
+
#### Kernel Registry
|
|
322
|
+
|
|
323
|
+
The registry provides automatic platform-specific kernel selection:
|
|
324
|
+
|
|
325
|
+
```python
|
|
326
|
+
@kernel_registry.register("my_operation", Platform.TRITON, Backend.GPU, priority=100)
|
|
327
|
+
def my_operation_gpu(x, y):
|
|
328
|
+
# GPU-optimized implementation
|
|
329
|
+
pass
|
|
330
|
+
|
|
331
|
+
@kernel_registry.register("my_operation", Platform.XLA, Backend.ANY, priority=50)
|
|
332
|
+
def my_operation_fallback(x, y):
|
|
333
|
+
# Universal fallback
|
|
334
|
+
pass
|
|
335
|
+
|
|
336
|
+
# Automatic selection based on available hardware
|
|
337
|
+
impl = kernel_registry.get("my_operation")
|
|
338
|
+
```
|
|
339
|
+
|
|
340
|
+
#### Configuration Management
|
|
341
|
+
|
|
342
|
+
Multi-tier configuration system with intelligent fallback:
|
|
343
|
+
|
|
344
|
+
```python
|
|
345
|
+
class ConfigSelectorChain:
|
|
346
|
+
"""
|
|
347
|
+
Selection hierarchy:
|
|
348
|
+
1. Override - Explicit user configuration
|
|
349
|
+
2. Overlay - Temporary context overrides
|
|
350
|
+
3. Memory Cache - In-memory lookup
|
|
351
|
+
4. Persistent Cache - Disk-based storage
|
|
352
|
+
5. Autotune - Performance benchmarking
|
|
353
|
+
6. Heuristics - Intelligent defaults
|
|
354
|
+
7. Error - Clear failure message
|
|
355
|
+
"""
|
|
356
|
+
```
|
|
357
|
+
|
|
358
|
+
#### Custom VJP System
|
|
359
|
+
|
|
360
|
+
All performance-critical kernels implement memory-efficient gradients:
|
|
361
|
+
|
|
362
|
+
```python
|
|
363
|
+
@jax.custom_vjp
|
|
364
|
+
def kernel_with_custom_grad(inputs):
|
|
365
|
+
return forward(inputs)
|
|
366
|
+
|
|
367
|
+
def kernel_fwd(inputs):
|
|
368
|
+
output, residuals = forward_with_residuals(inputs)
|
|
369
|
+
return output, residuals
|
|
370
|
+
|
|
371
|
+
def kernel_bwd(residuals, grad_output):
|
|
372
|
+
return efficient_backward(residuals, grad_output)
|
|
373
|
+
|
|
374
|
+
kernel_with_custom_grad.defvjp(kernel_fwd, kernel_bwd)
|
|
375
|
+
```
|
|
376
|
+
|
|
377
|
+
## Supported Operations
|
|
378
|
+
|
|
379
|
+
### Attention Mechanisms
|
|
380
|
+
|
|
381
|
+
| Algorithm | Description | Memory | Key Features |
|
|
382
|
+
|-----------|-------------|--------|--------------|
|
|
383
|
+
| **Flash Attention v2** | Memory-efficient exact attention | O(N) | Causal masking, dropout, sliding windows, soft capping |
|
|
384
|
+
| **Ring Attention** | Distributed sequence parallelism | O(N/P) | Ultra-long sequences, communication overlap |
|
|
385
|
+
| **Page Attention** | KV-cache optimized inference | O(N) | Block-wise memory, continuous batching |
|
|
386
|
+
| **Block Sparse Attention** | Configurable sparse patterns | O(N√N) | Local+global, custom patterns |
|
|
387
|
+
| **GLA** | Gated Linear Attention | O(N) | Linear complexity, gated updates |
|
|
388
|
+
| **Lightning Attention** | Layer-dependent decay | O(N) | Exponential moving average |
|
|
389
|
+
| **MLA** | Multi-head Latent Attention | O(N) | Compressed KV representation |
|
|
390
|
+
| **Ragged Page Attention v2** | Variable-length paged attention | O(N) | Ragged sequences with page caching |
|
|
391
|
+
| **Ragged Page Attention v3** | Enhanced ragged page attention | O(N) | Attention sinks support, improved handling |
|
|
392
|
+
| **Ragged Decode Attention** | Variable-length decoding | O(N) | Efficient batched inference |
|
|
393
|
+
| **Scaled Dot-Product Attention** | Standard attention | O(N²) | Basic reference implementation |
|
|
394
|
+
|
|
395
|
+
### Other Operations
|
|
396
|
+
|
|
397
|
+
| Operation | Description | Use Case |
|
|
398
|
+
|-----------|-------------|----------|
|
|
399
|
+
| **Grouped MatMul** | Efficient batched matrix operations | Expert models, MoE |
|
|
400
|
+
| **Grouped MatMul v2** | Enhanced with shard_map support | Distributed expert models |
|
|
401
|
+
| **Mean Pooling** | Variable-length sequence aggregation | Sentence embeddings |
|
|
402
|
+
| **Recurrent** | Optimized RNN/LSTM/GRU operations | Sequential modeling |
|
|
403
|
+
| **Native Sparse Attention** | Block-sparse matrix computations | Sparse attention patterns |
|
|
404
|
+
|
|
405
|
+
### Platform Support Matrix
|
|
406
|
+
|
|
407
|
+
| Operation | Triton (GPU) | Pallas (TPU) | XLA (Universal) |
|
|
408
|
+
|-----------|:------------:|:------------:|:---------------:|
|
|
409
|
+
| Flash Attention v2 | ✅ | ✅ | ✅ |
|
|
410
|
+
| Ring Attention | ✅ | ✅ | ✅ |
|
|
411
|
+
| Page Attention | ✅ | ✅ | ✅ |
|
|
412
|
+
| Block Sparse Attention | ✅ | ✅ | ✅ |
|
|
413
|
+
| Ragged Page Attention v2 | ✅ | ✅ | ✅ |
|
|
414
|
+
| Ragged Page Attention v3 | - | ✅ | ✅ |
|
|
415
|
+
| Ragged Decode Attention | ✅ | ✅ | ✅ |
|
|
416
|
+
| GLA | ✅ | - | ✅ |
|
|
417
|
+
| Lightning Attention | ✅ | - | ✅ |
|
|
418
|
+
| MLA | ✅ | 🚧 | - |
|
|
419
|
+
| Recurrent | ✅ | - | ✅ |
|
|
420
|
+
| Mean Pooling | ✅ | - | ✅ |
|
|
421
|
+
| Grouped MatMul | - | ✅ | ✅ |
|
|
422
|
+
| Grouped MatMul v2 | - | ✅ | - |
|
|
423
|
+
| Native Sparse Attention | ✅ | - | ✅ |
|
|
424
|
+
|
|
425
|
+
✅ = Production ready | 🚧 = Under development | - = Not available
|
|
426
|
+
|
|
427
|
+
## Advanced Usage
|
|
428
|
+
|
|
429
|
+
### Page Attention for KV-Cache Inference
|
|
430
|
+
|
|
431
|
+
```python
|
|
432
|
+
from ejkernel.modules import page_attention, PageAttentionConfig
|
|
433
|
+
|
|
434
|
+
# Configure paged attention for inference
|
|
435
|
+
config = PageAttentionConfig(
|
|
436
|
+
platform="auto",
|
|
437
|
+
backend="gpu"
|
|
438
|
+
)
|
|
439
|
+
|
|
440
|
+
output = page_attention(
|
|
441
|
+
query=q,
|
|
442
|
+
key_cache=k_cache,
|
|
443
|
+
value_cache=v_cache,
|
|
444
|
+
block_table=block_table,
|
|
445
|
+
cache_seqlens=cache_seqlens,
|
|
446
|
+
cfg=config
|
|
447
|
+
)
|
|
448
|
+
```
|
|
449
|
+
|
|
450
|
+
### Ragged Page Attention for Variable-Length Batches
|
|
451
|
+
|
|
452
|
+
```python
|
|
453
|
+
from ejkernel.modules import ragged_page_attention_v3, RaggedPageAttentionv3Config
|
|
454
|
+
|
|
455
|
+
# For variable-length sequences with attention sinks
|
|
456
|
+
config = RaggedPageAttentionv3Config(
|
|
457
|
+
platform="pallas",
|
|
458
|
+
backend="tpu"
|
|
459
|
+
)
|
|
460
|
+
|
|
461
|
+
output = ragged_page_attention_v3(
|
|
462
|
+
query=q,
|
|
463
|
+
key_pages=k_pages,
|
|
464
|
+
value_pages=v_pages,
|
|
465
|
+
lengths=seq_lengths,
|
|
466
|
+
page_indices=page_indices,
|
|
467
|
+
cfg=config
|
|
468
|
+
)
|
|
469
|
+
```
|
|
470
|
+
|
|
471
|
+
### Performance Optimization
|
|
472
|
+
|
|
473
|
+
```python
|
|
474
|
+
# Force autotuning for optimal configuration
|
|
475
|
+
import os
|
|
476
|
+
os.environ["EJKERNEL_AUTOTUNE_POLICY"] = "autotune"
|
|
477
|
+
os.environ["EJKERNEL_LOG_AUTOTUNE"] = "1"
|
|
478
|
+
|
|
479
|
+
# Enable profiling
|
|
480
|
+
os.environ["EJKERNEL_OPS_STAMP"] = "json" # Detailed metadata
|
|
481
|
+
os.environ["EJKERNEL_OPS_RECORD"] = "1" # Record invocations
|
|
482
|
+
```
|
|
483
|
+
|
|
484
|
+
### Custom Kernel Development
|
|
485
|
+
|
|
486
|
+
```python
|
|
487
|
+
from ejkernel.ops.core import Kernel
|
|
488
|
+
from ejkernel.modules.operations.configs import BaseOperationConfig
|
|
489
|
+
from dataclasses import dataclass
|
|
490
|
+
|
|
491
|
+
@dataclass
|
|
492
|
+
class MyConfig(BaseOperationConfig):
|
|
493
|
+
param1: int = 128
|
|
494
|
+
param2: float = 0.1
|
|
495
|
+
|
|
496
|
+
class MyKernel(Kernel[MyConfig, Array]):
|
|
497
|
+
def __init__(self):
|
|
498
|
+
super().__init__(op_id="my_kernel")
|
|
499
|
+
|
|
500
|
+
def run(self, x, cfg: MyConfig):
|
|
501
|
+
impl = kernel_registry.get("my_kernel", cfg.platform)
|
|
502
|
+
return impl(x, param1=cfg.param1, param2=cfg.param2)
|
|
503
|
+
|
|
504
|
+
def heuristic_cfg(self, inv):
|
|
505
|
+
# Return default configuration
|
|
506
|
+
return MyConfig(param1=256)
|
|
507
|
+
|
|
508
|
+
def candidate_cfgs(self, inv):
|
|
509
|
+
# Return autotuning candidates
|
|
510
|
+
return [MyConfig(param1=p) for p in [64, 128, 256]]
|
|
511
|
+
```
|
|
512
|
+
|
|
513
|
+
### Integration with Flax Models
|
|
514
|
+
|
|
515
|
+
```python
|
|
516
|
+
import flax.linen as nn
|
|
517
|
+
from ejkernel.modules import flash_attention
|
|
518
|
+
|
|
519
|
+
class TransformerBlock(nn.Module):
|
|
520
|
+
num_heads: int = 8
|
|
521
|
+
head_dim: int = 64
|
|
522
|
+
|
|
523
|
+
@nn.compact
|
|
524
|
+
def __call__(self, x, mask=None):
|
|
525
|
+
# Project to Q, K, V
|
|
526
|
+
q = nn.Dense(self.num_heads * self.head_dim)(x)
|
|
527
|
+
k = nn.Dense(self.num_heads * self.head_dim)(x)
|
|
528
|
+
v = nn.Dense(self.num_heads * self.head_dim)(x)
|
|
529
|
+
|
|
530
|
+
# Reshape for attention
|
|
531
|
+
shape = (x.shape[0], x.shape[1], self.num_heads, self.head_dim)
|
|
532
|
+
q, k, v = map(lambda t: t.reshape(shape), (q, k, v))
|
|
533
|
+
|
|
534
|
+
# Apply ejKernel Flash Attention
|
|
535
|
+
attn_output = flash_attention(
|
|
536
|
+
q, k, v,
|
|
537
|
+
causal=True,
|
|
538
|
+
attention_mask=mask
|
|
539
|
+
)
|
|
540
|
+
|
|
541
|
+
# Project output
|
|
542
|
+
return nn.Dense(x.shape[-1])(attn_output.reshape(x.shape))
|
|
543
|
+
```
|
|
544
|
+
|
|
545
|
+
## Development
|
|
546
|
+
|
|
547
|
+
### Setting Up Development Environment
|
|
548
|
+
|
|
549
|
+
```bash
|
|
550
|
+
# Clone repository
|
|
551
|
+
git clone https://github.com/erfanzar/ejkernel.git
|
|
552
|
+
cd ejkernel
|
|
553
|
+
|
|
554
|
+
# Create virtual environment
|
|
555
|
+
python -m venv .venv
|
|
556
|
+
source .venv/bin/activate # On Windows: .venv\Scripts\activate
|
|
557
|
+
|
|
558
|
+
# Install in development mode
|
|
559
|
+
pip install -e ".[dev]"
|
|
560
|
+
|
|
561
|
+
# Install pre-commit hooks
|
|
562
|
+
pre-commit install
|
|
563
|
+
```
|
|
564
|
+
|
|
565
|
+
### Code Style
|
|
566
|
+
|
|
567
|
+
The project uses:
|
|
568
|
+
|
|
569
|
+
- **black** for code formatting (line length: 121)
|
|
570
|
+
- **ruff** for linting
|
|
571
|
+
- **mypy/pyright** for type checking
|
|
572
|
+
- **pre-commit** for automated checks
|
|
573
|
+
|
|
574
|
+
### Adding New Kernels
|
|
575
|
+
|
|
576
|
+
1. **Implement the kernel** in the appropriate backend directory:
|
|
577
|
+
|
|
578
|
+
```python
|
|
579
|
+
# ejkernel/kernels/_triton/my_kernel/_interface.py
|
|
580
|
+
@kernel_registry.register("my_kernel", Platform.TRITON, Backend.GPU)
|
|
581
|
+
def my_kernel_triton(x, config):
|
|
582
|
+
# Implementation
|
|
583
|
+
pass
|
|
584
|
+
```
|
|
585
|
+
|
|
586
|
+
1. **Create module wrapper**:
|
|
587
|
+
|
|
588
|
+
```python
|
|
589
|
+
# ejkernel/modules/operations/my_kernel.py
|
|
590
|
+
class MyKernel(Kernel[MyKernelConfig, Array]):
|
|
591
|
+
# Module implementation
|
|
592
|
+
pass
|
|
593
|
+
```
|
|
594
|
+
|
|
595
|
+
1. **Add tests**:
|
|
596
|
+
|
|
597
|
+
```python
|
|
598
|
+
# test/kernels/_triton/test_my_kernel.py
|
|
599
|
+
class TestMyKernel(unittest.TestCase):
|
|
600
|
+
# Test implementation
|
|
601
|
+
pass
|
|
602
|
+
```
|
|
603
|
+
|
|
604
|
+
1. **Update documentation**
|
|
605
|
+
|
|
606
|
+
## Testing
|
|
607
|
+
|
|
608
|
+
### Running Tests
|
|
609
|
+
|
|
610
|
+
```bash
|
|
611
|
+
# Run all tests
|
|
612
|
+
pytest test/
|
|
613
|
+
|
|
614
|
+
# Platform-specific tests
|
|
615
|
+
pytest test/kernels/_xla/ # XLA implementations
|
|
616
|
+
pytest test/kernels/_triton/ # Triton implementations
|
|
617
|
+
pytest test/kernels/_pallas/ # Pallas implementations
|
|
618
|
+
|
|
619
|
+
# Specific test patterns
|
|
620
|
+
pytest -k "flash_attention"
|
|
621
|
+
pytest --verbose --failfast
|
|
622
|
+
|
|
623
|
+
# Module operations tests
|
|
624
|
+
pytest test/test_module_operations.py
|
|
625
|
+
```
|
|
626
|
+
|
|
627
|
+
### Test Categories
|
|
628
|
+
|
|
629
|
+
- **Unit Tests**: Individual component testing
|
|
630
|
+
- **Integration Tests**: End-to-end workflows
|
|
631
|
+
- **Comparison Tests**: Cross-backend consistency
|
|
632
|
+
- **Performance Tests**: Regression detection
|
|
633
|
+
|
|
634
|
+
## Benchmarking
|
|
635
|
+
|
|
636
|
+
Run benchmarks to compare performance across backends:
|
|
637
|
+
|
|
638
|
+
```bash
|
|
639
|
+
# General attention benchmarks
|
|
640
|
+
python benchmarks/benchmark_attention.py
|
|
641
|
+
|
|
642
|
+
# Ragged page attention benchmarks
|
|
643
|
+
python benchmarks/benchmark_ragged_page_attn.py
|
|
644
|
+
```
|
|
645
|
+
|
|
646
|
+
## Contributing
|
|
647
|
+
|
|
648
|
+
We welcome contributions! See [CONTRIBUTING.md](CONTRIBUTING.md) for guidelines.
|
|
649
|
+
|
|
650
|
+
### Priority Areas
|
|
651
|
+
|
|
652
|
+
- TPU/Pallas implementations for existing algorithms
|
|
653
|
+
- CUDA native kernels for maximum performance
|
|
654
|
+
- New attention mechanisms from recent papers
|
|
655
|
+
- Performance optimizations and kernel fusion
|
|
656
|
+
- Documentation and examples
|
|
657
|
+
|
|
658
|
+
### Contribution Process
|
|
659
|
+
|
|
660
|
+
1. Fork the repository
|
|
661
|
+
1. Create a feature branch
|
|
662
|
+
1. Implement your changes with tests
|
|
663
|
+
1. Ensure all tests pass
|
|
664
|
+
1. Submit a pull request
|
|
665
|
+
|
|
666
|
+
## Documentation
|
|
667
|
+
|
|
668
|
+
Comprehensive documentation is available at [ejkernel.readthedocs.io](https://ejkernel.readthedocs.io/en/latest/)
|
|
669
|
+
|
|
670
|
+
- **[API Reference](https://ejkernel.readthedocs.io/en/latest/api/)**: Complete API documentation
|
|
671
|
+
- **[Tutorials](https://ejkernel.readthedocs.io/en/latest/tutorials/)**: Step-by-step guides
|
|
672
|
+
- **[Architecture](https://ejkernel.readthedocs.io/en/latest/architecture/)**: Design documentation
|
|
673
|
+
- **[Benchmarks](https://ejkernel.readthedocs.io/en/latest/benchmarks/)**: Performance analysis
|
|
674
|
+
|
|
675
|
+
## Citation
|
|
676
|
+
|
|
677
|
+
If you use ejKernel in your research, please cite:
|
|
678
|
+
|
|
679
|
+
```bibtex
|
|
680
|
+
@software{ejkernel2024,
|
|
681
|
+
author = {Erfan Zare Chavoshi},
|
|
682
|
+
title = {ejKernel: High-Performance JAX Kernels for Deep Learning},
|
|
683
|
+
year = {2024},
|
|
684
|
+
url = {https://github.com/erfanzar/ejkernel},
|
|
685
|
+
note = {Production-grade kernel library with multi-backend support}
|
|
686
|
+
}
|
|
687
|
+
```
|
|
688
|
+
|
|
689
|
+
## License
|
|
690
|
+
|
|
691
|
+
ejKernel is licensed under the Apache License 2.0. See [LICENSE](LICENSE) for details.
|
|
692
|
+
|
|
693
|
+
## Acknowledgments
|
|
694
|
+
|
|
695
|
+
ejKernel builds upon excellent work from:
|
|
696
|
+
|
|
697
|
+
- [JAX](https://github.com/google/jax) - Composable transformations of Python+NumPy programs
|
|
698
|
+
- [Triton](https://github.com/openai/triton) - GPU kernel programming language
|
|
699
|
+
- [Pallas](https://github.com/google/jax/tree/main/jax/experimental/pallas) - JAX kernel language
|
|
700
|
+
- [Flash Attention](https://github.com/Dao-AILab/flash-attention) - Memory-efficient attention
|
|
701
|
+
- [EasyDeL](https://github.com/erfanzar/EasyDeL) - Parent framework for JAX deep learning
|
|
702
|
+
|
|
703
|
+
## Community
|
|
704
|
+
|
|
705
|
+
- **GitHub Issues**: [Bug reports and feature requests](https://github.com/erfanzar/ejkernel/issues)
|
|
706
|
+
- **Discussions**: [Community forum](https://github.com/erfanzar/ejkernel/discussions)
|
|
707
|
+
- **Email**: <Erfanzare810@gmail.com>
|
|
708
|
+
|
|
709
|
+
---
|
|
710
|
+
|
|
711
|
+
**ejKernel** - Production-grade kernels for JAX deep learning
|