ejkernel 0.0.1__tar.gz

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (183)
  1. ejkernel-0.0.1/PKG-INFO +607 -0
  2. ejkernel-0.0.1/README.md +569 -0
  3. ejkernel-0.0.1/ejkernel/__init__.py +51 -0
  4. ejkernel-0.0.1/ejkernel/benchmarks.py +772 -0
  5. ejkernel-0.0.1/ejkernel/callib/__init__.py +34 -0
  6. ejkernel-0.0.1/ejkernel/callib/_ejit.py +513 -0
  7. ejkernel-0.0.1/ejkernel/callib/_triton_call.py +891 -0
  8. ejkernel-0.0.1/ejkernel/callib/_utils.py +243 -0
  9. ejkernel-0.0.1/ejkernel/kernels/__init__.py +65 -0
  10. ejkernel-0.0.1/ejkernel/kernels/_cuda/__init__.py +15 -0
  11. ejkernel-0.0.1/ejkernel/kernels/_pallas/__init__.py +18 -0
  12. ejkernel-0.0.1/ejkernel/kernels/_pallas/gpu/__init__.py +19 -0
  13. ejkernel-0.0.1/ejkernel/kernels/_pallas/gpu/ragged_decode_attention/__init__.py +18 -0
  14. ejkernel-0.0.1/ejkernel/kernels/_pallas/gpu/ragged_decode_attention/_interface.py +81 -0
  15. ejkernel-0.0.1/ejkernel/kernels/_pallas/gpu/ragged_decode_attention/_pallas_impl_fwd.py +382 -0
  16. ejkernel-0.0.1/ejkernel/kernels/_pallas/gpu/scaled_dot_product_attention/__init__.py +25 -0
  17. ejkernel-0.0.1/ejkernel/kernels/_pallas/gpu/scaled_dot_product_attention/_interface.py +131 -0
  18. ejkernel-0.0.1/ejkernel/kernels/_pallas/tpu/__init__.py +32 -0
  19. ejkernel-0.0.1/ejkernel/kernels/_pallas/tpu/blocksparse_attention/__init__.py +66 -0
  20. ejkernel-0.0.1/ejkernel/kernels/_pallas/tpu/blocksparse_attention/_info.py +975 -0
  21. ejkernel-0.0.1/ejkernel/kernels/_pallas/tpu/blocksparse_attention/_kernel.py +2573 -0
  22. ejkernel-0.0.1/ejkernel/kernels/_pallas/tpu/blocksparse_attention/_masks.py +546 -0
  23. ejkernel-0.0.1/ejkernel/kernels/_pallas/tpu/flash_attention/__init__.py +24 -0
  24. ejkernel-0.0.1/ejkernel/kernels/_pallas/tpu/flash_attention/_interface.py +188 -0
  25. ejkernel-0.0.1/ejkernel/kernels/_pallas/tpu/flash_attention/_pallas_impl_bwd.py +676 -0
  26. ejkernel-0.0.1/ejkernel/kernels/_pallas/tpu/flash_attention/_pallas_impl_fwd.py +451 -0
  27. ejkernel-0.0.1/ejkernel/kernels/_pallas/tpu/flash_attention/_utils.py +373 -0
  28. ejkernel-0.0.1/ejkernel/kernels/_pallas/tpu/grouped_matmul/__init__.py +18 -0
  29. ejkernel-0.0.1/ejkernel/kernels/_pallas/tpu/grouped_matmul/_interface.py +250 -0
  30. ejkernel-0.0.1/ejkernel/kernels/_pallas/tpu/grouped_matmul/_pallas_impl.py +979 -0
  31. ejkernel-0.0.1/ejkernel/kernels/_pallas/tpu/grouped_matmul/_utils.py +191 -0
  32. ejkernel-0.0.1/ejkernel/kernels/_pallas/tpu/page_attention/__init__.py +18 -0
  33. ejkernel-0.0.1/ejkernel/kernels/_pallas/tpu/page_attention/_interface.py +253 -0
  34. ejkernel-0.0.1/ejkernel/kernels/_pallas/tpu/page_attention/_pallas_impl_fwd.py +290 -0
  35. ejkernel-0.0.1/ejkernel/kernels/_pallas/tpu/ragged_decode_attention/__init__.py +18 -0
  36. ejkernel-0.0.1/ejkernel/kernels/_pallas/tpu/ragged_decode_attention/_interface.py +74 -0
  37. ejkernel-0.0.1/ejkernel/kernels/_pallas/tpu/ragged_decode_attention/_pallas_impl_fwd.py +247 -0
  38. ejkernel-0.0.1/ejkernel/kernels/_pallas/tpu/ragged_page_attention/__init__.py +20 -0
  39. ejkernel-0.0.1/ejkernel/kernels/_pallas/tpu/ragged_page_attention/_interface.py +245 -0
  40. ejkernel-0.0.1/ejkernel/kernels/_pallas/tpu/ragged_page_attention/_pallas_impl_fwd.py +590 -0
  41. ejkernel-0.0.1/ejkernel/kernels/_pallas/tpu/ragged_page_attention/_utils.py +437 -0
  42. ejkernel-0.0.1/ejkernel/kernels/_pallas/tpu/ragged_page_attention_v3/__init__.py +20 -0
  43. ejkernel-0.0.1/ejkernel/kernels/_pallas/tpu/ragged_page_attention_v3/_interface.py +115 -0
  44. ejkernel-0.0.1/ejkernel/kernels/_pallas/tpu/ragged_page_attention_v3/_pallas_impl_fwd.py +1230 -0
  45. ejkernel-0.0.1/ejkernel/kernels/_pallas/tpu/ragged_page_attention_v3/_pallas_impl_fwd_h64.py +1184 -0
  46. ejkernel-0.0.1/ejkernel/kernels/_pallas/tpu/ragged_page_attention_v3/_utils.py +4545 -0
  47. ejkernel-0.0.1/ejkernel/kernels/_pallas/tpu/ring_attention/__init__.py +24 -0
  48. ejkernel-0.0.1/ejkernel/kernels/_pallas/tpu/ring_attention/_interface.py +225 -0
  49. ejkernel-0.0.1/ejkernel/kernels/_pallas/tpu/ring_attention/_pallas_impl_bwd.py +1003 -0
  50. ejkernel-0.0.1/ejkernel/kernels/_pallas/tpu/ring_attention/_pallas_impl_fwd.py +769 -0
  51. ejkernel-0.0.1/ejkernel/kernels/_pallas/tpu/ring_attention/_utils.py +192 -0
  52. ejkernel-0.0.1/ejkernel/kernels/_registry.py +377 -0
  53. ejkernel-0.0.1/ejkernel/kernels/_triton/__init__.py +59 -0
  54. ejkernel-0.0.1/ejkernel/kernels/_triton/blocksparse_attention/__init__.py +33 -0
  55. ejkernel-0.0.1/ejkernel/kernels/_triton/blocksparse_attention/_interface.py +465 -0
  56. ejkernel-0.0.1/ejkernel/kernels/_triton/blocksparse_attention/_mask.py +530 -0
  57. ejkernel-0.0.1/ejkernel/kernels/_triton/blocksparse_attention/_triton_impl_bwd.py +1380 -0
  58. ejkernel-0.0.1/ejkernel/kernels/_triton/blocksparse_attention/_triton_impl_fwd.py +695 -0
  59. ejkernel-0.0.1/ejkernel/kernels/_triton/blocksparse_attention/_utilities.py +437 -0
  60. ejkernel-0.0.1/ejkernel/kernels/_triton/flash_attention/__init__.py +24 -0
  61. ejkernel-0.0.1/ejkernel/kernels/_triton/flash_attention/_interface.py +385 -0
  62. ejkernel-0.0.1/ejkernel/kernels/_triton/flash_attention/_triton_impl_bwd.py +1494 -0
  63. ejkernel-0.0.1/ejkernel/kernels/_triton/flash_attention/_triton_impl_fwd.py +827 -0
  64. ejkernel-0.0.1/ejkernel/kernels/_triton/flash_attention/_utilities.py +358 -0
  65. ejkernel-0.0.1/ejkernel/kernels/_triton/flash_mla/__init__.py +38 -0
  66. ejkernel-0.0.1/ejkernel/kernels/_triton/flash_mla/_interface.py +100 -0
  67. ejkernel-0.0.1/ejkernel/kernels/_triton/flash_mla/_triton_impl_bwd.py +15 -0
  68. ejkernel-0.0.1/ejkernel/kernels/_triton/flash_mla/_triton_impl_fwd.py +180 -0
  69. ejkernel-0.0.1/ejkernel/kernels/_triton/flash_mla/_utilities.py +15 -0
  70. ejkernel-0.0.1/ejkernel/kernels/_triton/gla/__init__.py +18 -0
  71. ejkernel-0.0.1/ejkernel/kernels/_triton/gla/_interface.py +136 -0
  72. ejkernel-0.0.1/ejkernel/kernels/_triton/lightning_attn/__init__.py +18 -0
  73. ejkernel-0.0.1/ejkernel/kernels/_triton/lightning_attn/_interface.py +143 -0
  74. ejkernel-0.0.1/ejkernel/kernels/_triton/mean_pooling/__init__.py +30 -0
  75. ejkernel-0.0.1/ejkernel/kernels/_triton/mean_pooling/_interface.py +163 -0
  76. ejkernel-0.0.1/ejkernel/kernels/_triton/mean_pooling/_triton_impl_bwd.py +130 -0
  77. ejkernel-0.0.1/ejkernel/kernels/_triton/mean_pooling/_triton_impl_fwd.py +126 -0
  78. ejkernel-0.0.1/ejkernel/kernels/_triton/native_sparse_attention/__init__.py +32 -0
  79. ejkernel-0.0.1/ejkernel/kernels/_triton/native_sparse_attention/_compression.py +753 -0
  80. ejkernel-0.0.1/ejkernel/kernels/_triton/native_sparse_attention/_interface.py +403 -0
  81. ejkernel-0.0.1/ejkernel/kernels/_triton/native_sparse_attention/_triton_impl_bwd.py +394 -0
  82. ejkernel-0.0.1/ejkernel/kernels/_triton/native_sparse_attention/_triton_impl_fwd.py +428 -0
  83. ejkernel-0.0.1/ejkernel/kernels/_triton/native_sparse_attention/_utilities.py +87 -0
  84. ejkernel-0.0.1/ejkernel/kernels/_triton/page_attention/__init__.py +18 -0
  85. ejkernel-0.0.1/ejkernel/kernels/_triton/page_attention/_interface.py +369 -0
  86. ejkernel-0.0.1/ejkernel/kernels/_triton/page_attention/_triton_impl_fwd.py +257 -0
  87. ejkernel-0.0.1/ejkernel/kernels/_triton/ragged_decode_attention/__init__.py +18 -0
  88. ejkernel-0.0.1/ejkernel/kernels/_triton/ragged_decode_attention/_interface.py +75 -0
  89. ejkernel-0.0.1/ejkernel/kernels/_triton/ragged_decode_attention/_triton_impl_fwd.py +302 -0
  90. ejkernel-0.0.1/ejkernel/kernels/_triton/ragged_page_attention/__init__.py +18 -0
  91. ejkernel-0.0.1/ejkernel/kernels/_triton/ragged_page_attention/_interface.py +226 -0
  92. ejkernel-0.0.1/ejkernel/kernels/_triton/ragged_page_attention/_triton_impl_fwd.py +613 -0
  93. ejkernel-0.0.1/ejkernel/kernels/_triton/recurrent/__init__.py +30 -0
  94. ejkernel-0.0.1/ejkernel/kernels/_triton/recurrent/_interface.py +291 -0
  95. ejkernel-0.0.1/ejkernel/kernels/_triton/recurrent/_triton_impl_bwd.py +354 -0
  96. ejkernel-0.0.1/ejkernel/kernels/_triton/recurrent/_triton_impl_fwd.py +216 -0
  97. ejkernel-0.0.1/ejkernel/kernels/_xla/__init__.py +44 -0
  98. ejkernel-0.0.1/ejkernel/kernels/_xla/attention/__init__.py +18 -0
  99. ejkernel-0.0.1/ejkernel/kernels/_xla/attention/_interface.py +197 -0
  100. ejkernel-0.0.1/ejkernel/kernels/_xla/flash_attention/__init__.py +24 -0
  101. ejkernel-0.0.1/ejkernel/kernels/_xla/flash_attention/_interface.py +376 -0
  102. ejkernel-0.0.1/ejkernel/kernels/_xla/flash_attention/_xla_impl_bwd.py +85 -0
  103. ejkernel-0.0.1/ejkernel/kernels/_xla/flash_attention/_xla_impl_fwd.py +474 -0
  104. ejkernel-0.0.1/ejkernel/kernels/_xla/gla/__init__.py +18 -0
  105. ejkernel-0.0.1/ejkernel/kernels/_xla/gla/_interface.py +123 -0
  106. ejkernel-0.0.1/ejkernel/kernels/_xla/grouped_matmul/__init__.py +18 -0
  107. ejkernel-0.0.1/ejkernel/kernels/_xla/grouped_matmul/_interface.py +124 -0
  108. ejkernel-0.0.1/ejkernel/kernels/_xla/lightning_attn/__init__.py +18 -0
  109. ejkernel-0.0.1/ejkernel/kernels/_xla/lightning_attn/_interface.py +129 -0
  110. ejkernel-0.0.1/ejkernel/kernels/_xla/mean_pooling/__init__.py +18 -0
  111. ejkernel-0.0.1/ejkernel/kernels/_xla/mean_pooling/_interface.py +171 -0
  112. ejkernel-0.0.1/ejkernel/kernels/_xla/native_sparse_attention/__init__.py +32 -0
  113. ejkernel-0.0.1/ejkernel/kernels/_xla/native_sparse_attention/_interface.py +435 -0
  114. ejkernel-0.0.1/ejkernel/kernels/_xla/native_sparse_attention/_xla_impl_bwd.py +138 -0
  115. ejkernel-0.0.1/ejkernel/kernels/_xla/native_sparse_attention/_xla_impl_fwd.py +86 -0
  116. ejkernel-0.0.1/ejkernel/kernels/_xla/page_attention/__init__.py +18 -0
  117. ejkernel-0.0.1/ejkernel/kernels/_xla/page_attention/_interface.py +123 -0
  118. ejkernel-0.0.1/ejkernel/kernels/_xla/page_attention/_xla_impl_fwd.py +97 -0
  119. ejkernel-0.0.1/ejkernel/kernels/_xla/ragged_decode_attention/__init__.py +18 -0
  120. ejkernel-0.0.1/ejkernel/kernels/_xla/ragged_decode_attention/_interface.py +119 -0
  121. ejkernel-0.0.1/ejkernel/kernels/_xla/ragged_decode_attention/_xla_impl_fwd.py +468 -0
  122. ejkernel-0.0.1/ejkernel/kernels/_xla/ragged_page_attention/__init__.py +18 -0
  123. ejkernel-0.0.1/ejkernel/kernels/_xla/ragged_page_attention/_interface.py +109 -0
  124. ejkernel-0.0.1/ejkernel/kernels/_xla/ragged_page_attention/_xla_impl_fwd.py +214 -0
  125. ejkernel-0.0.1/ejkernel/kernels/_xla/recurrent/__init__.py +30 -0
  126. ejkernel-0.0.1/ejkernel/kernels/_xla/recurrent/_interface.py +201 -0
  127. ejkernel-0.0.1/ejkernel/kernels/_xla/recurrent/_xla_impl_bwd.py +163 -0
  128. ejkernel-0.0.1/ejkernel/kernels/_xla/recurrent/_xla_impl_fwd.py +242 -0
  129. ejkernel-0.0.1/ejkernel/kernels/_xla/ring_attention/__init__.py +24 -0
  130. ejkernel-0.0.1/ejkernel/kernels/_xla/ring_attention/_interface.py +212 -0
  131. ejkernel-0.0.1/ejkernel/kernels/_xla/ring_attention/_utils.py +163 -0
  132. ejkernel-0.0.1/ejkernel/kernels/_xla/ring_attention/_xla_impl_bwd.py +354 -0
  133. ejkernel-0.0.1/ejkernel/kernels/_xla/ring_attention/_xla_impl_fwd.py +388 -0
  134. ejkernel-0.0.1/ejkernel/kernels/_xla/scaled_dot_product_attention/__init__.py +24 -0
  135. ejkernel-0.0.1/ejkernel/kernels/_xla/scaled_dot_product_attention/_interface.py +106 -0
  136. ejkernel-0.0.1/ejkernel/loggings.py +482 -0
  137. ejkernel-0.0.1/ejkernel/modules/__init__.py +138 -0
  138. ejkernel-0.0.1/ejkernel/modules/base.py +213 -0
  139. ejkernel-0.0.1/ejkernel/modules/operations/__init__.py +143 -0
  140. ejkernel-0.0.1/ejkernel/modules/operations/attention.py +286 -0
  141. ejkernel-0.0.1/ejkernel/modules/operations/blocksparse_attention.py +937 -0
  142. ejkernel-0.0.1/ejkernel/modules/operations/configs.py +333 -0
  143. ejkernel-0.0.1/ejkernel/modules/operations/flash_attention.py +886 -0
  144. ejkernel-0.0.1/ejkernel/modules/operations/gated_linear_attention.py +301 -0
  145. ejkernel-0.0.1/ejkernel/modules/operations/grouped_matmul.py +300 -0
  146. ejkernel-0.0.1/ejkernel/modules/operations/lightning_attention.py +306 -0
  147. ejkernel-0.0.1/ejkernel/modules/operations/multi_head_latent_attention.py +290 -0
  148. ejkernel-0.0.1/ejkernel/modules/operations/native_sparse_attention.py +374 -0
  149. ejkernel-0.0.1/ejkernel/modules/operations/page_attention.py +444 -0
  150. ejkernel-0.0.1/ejkernel/modules/operations/pooling.py +244 -0
  151. ejkernel-0.0.1/ejkernel/modules/operations/ragged_decode_attention.py +588 -0
  152. ejkernel-0.0.1/ejkernel/modules/operations/ragged_page_attention.py +641 -0
  153. ejkernel-0.0.1/ejkernel/modules/operations/recurrent.py +328 -0
  154. ejkernel-0.0.1/ejkernel/modules/operations/ring_attention.py +617 -0
  155. ejkernel-0.0.1/ejkernel/modules/operations/scaled_dot_product_attention.py +417 -0
  156. ejkernel-0.0.1/ejkernel/ops/__init__.py +150 -0
  157. ejkernel-0.0.1/ejkernel/ops/config/__init__.py +43 -0
  158. ejkernel-0.0.1/ejkernel/ops/config/cache.py +187 -0
  159. ejkernel-0.0.1/ejkernel/ops/config/persistent.py +188 -0
  160. ejkernel-0.0.1/ejkernel/ops/config/selection.py +569 -0
  161. ejkernel-0.0.1/ejkernel/ops/core/__init__.py +36 -0
  162. ejkernel-0.0.1/ejkernel/ops/core/kernel.py +671 -0
  163. ejkernel-0.0.1/ejkernel/ops/core/types.py +50 -0
  164. ejkernel-0.0.1/ejkernel/ops/execution/__init__.py +67 -0
  165. ejkernel-0.0.1/ejkernel/ops/execution/batch.py +191 -0
  166. ejkernel-0.0.1/ejkernel/ops/execution/executor.py +505 -0
  167. ejkernel-0.0.1/ejkernel/ops/execution/offline.py +80 -0
  168. ejkernel-0.0.1/ejkernel/ops/execution/profiler.py +475 -0
  169. ejkernel-0.0.1/ejkernel/ops/execution/tuning.py +1197 -0
  170. ejkernel-0.0.1/ejkernel/ops/registry.py +93 -0
  171. ejkernel-0.0.1/ejkernel/ops/utils/__init__.py +72 -0
  172. ejkernel-0.0.1/ejkernel/ops/utils/datacarrier.py +126 -0
  173. ejkernel-0.0.1/ejkernel/ops/utils/fingerprint.py +342 -0
  174. ejkernel-0.0.1/ejkernel/ops/utils/meta.py +156 -0
  175. ejkernel-0.0.1/ejkernel/ops/utils/serialize.py +97 -0
  176. ejkernel-0.0.1/ejkernel/types/__init__.py +18 -0
  177. ejkernel-0.0.1/ejkernel/types/mask.py +1787 -0
  178. ejkernel-0.0.1/ejkernel/utils.py +677 -0
  179. ejkernel-0.0.1/ejkernel/xla_utils/__init__.py +45 -0
  180. ejkernel-0.0.1/ejkernel/xla_utils/cumsum.py +297 -0
  181. ejkernel-0.0.1/ejkernel/xla_utils/shardings.py +102 -0
  182. ejkernel-0.0.1/ejkernel/xla_utils/utils.py +184 -0
  183. ejkernel-0.0.1/pyproject.toml +126 -0
@@ -0,0 +1,607 @@
1
+ Metadata-Version: 2.3
2
+ Name: ejkernel
3
+ Version: 0.0.1
4
+ Summary: Accelerate, Optimize performance with streamlined training and serving options with JAX.
5
+ Keywords: Deep Learning,Machine Learning,JAX,CUDA,XLA,Triton,Pallas
6
+ Author: Erfan Zare Chavoshi
7
+ Author-email: Erfan Zare Chavoshi <Erfanzare810@gmail.com>
8
+ License: Apache-2.0
9
+ Classifier: Development Status :: 3 - Alpha
10
+ Classifier: Intended Audience :: Developers
11
+ Classifier: Topic :: Scientific/Engineering :: Artificial Intelligence
12
+ Classifier: License :: OSI Approved :: Apache Software License
13
+ Classifier: Programming Language :: Python :: 3.10
14
+ Classifier: Programming Language :: Python :: 3.11
15
+ Classifier: Programming Language :: Python :: 3.12
16
+ Classifier: Programming Language :: Python :: 3.13
17
+ Requires-Dist: beartype>=0.22.2
18
+ Requires-Dist: chex>=0.1.91
19
+ Requires-Dist: einops>=0.8.1
20
+ Requires-Dist: jax>=0.8.0
21
+ Requires-Dist: jaxlib>=0.8.0
22
+ Requires-Dist: jaxtyping>=0.3.2
23
+ Requires-Dist: pydantic>=2.11.10
24
+ Requires-Dist: triton==3.4.0
25
+ Requires-Dist: jax[cuda12]>=0.8.0 ; extra == 'gpu'
26
+ Requires-Dist: xprof>=2.20.6 ; extra == 'profile'
27
+ Requires-Dist: tb-nightly>=2.21.0a20250820 ; extra == 'profile'
28
+ Requires-Dist: xprof-nightly>=2.21.6a20250820 ; extra == 'profile'
29
+ Requires-Dist: jax[tpu]>=0.8.0 ; extra == 'tpu'
30
+ Requires-Python: >=3.11, <3.14
31
+ Project-URL: Documentation, https://ejkernel.readthedocs.io/en/latest/
32
+ Project-URL: Homepage, https://github.com/erfanzar/ejkernel
33
+ Project-URL: Repository, https://github.com/erfanzar/ejkernel
34
+ Provides-Extra: gpu
35
+ Provides-Extra: profile
36
+ Provides-Extra: tpu
37
+ Description-Content-Type: text/markdown
38
+
39
+ # ejKernel: High-Performance JAX Kernels for Deep Learning
40
+
41
+ > *"The best optimization is the one you don't have to think about."*
42
+
43
+ [![License](https://img.shields.io/badge/License-Apache%202.0-blue.svg)](https://opensource.org/licenses/Apache-2.0)
44
+ [![Python 3.11+](https://img.shields.io/badge/python-3.11+-blue.svg)](https://www.python.org/downloads/)
45
+ [![JAX](https://img.shields.io/badge/JAX-0.8.0+-orange.svg)](https://github.com/google/jax)
46
+ [![Documentation](https://img.shields.io/badge/docs-readthedocs-green.svg)](https://ejkernel.readthedocs.io/en/latest/)
47
+
48
+ ejKernel is a production-grade kernel library for JAX that provides highly optimized implementations of deep learning operations with automatic multi-backend support. The library features a sophisticated configuration management system with autotuning, comprehensive type safety, and seamless execution across GPUs, TPUs, and CPUs.
49
+
50
+ ## Table of Contents
51
+
52
+ - [Key Features](#key-features)
53
+ - [Installation](#installation)
54
+ - [Quick Start](#quick-start)
55
+ - [Architecture Overview](#architecture-overview)
56
+ - [Supported Operations](#supported-operations)
57
+ - [Advanced Usage](#advanced-usage)
58
+ - [Performance](#performance)
59
+ - [Development](#development)
60
+ - [Testing](#testing)
61
+ - [Contributing](#contributing)
62
+ - [Citation](#citation)
63
+ - [License](#license)
64
+
65
+ ## Key Features
66
+
67
+ ### Intelligent Kernel Management
68
+
69
+ - **7-Tier Configuration System**: Override → Overlay → Memory Cache → Persistent Cache → Autotune → Heuristics → Error
70
+ - **Automatic Platform Detection**: Seamlessly selects optimal implementation based on hardware
71
+ - **Priority-Based Registry**: Multi-backend support with intelligent fallback mechanisms
72
+ - **Device Fingerprinting**: Hardware-specific configuration caching for optimal performance
73
+
74
+ ### State-of-the-Art Operations
75
+
76
+ - **15+ Attention Mechanisms**: Flash Attention v2, Ring Attention, Page Attention, Block Sparse, GLA, Lightning, and more
77
+ - **Memory Efficiency**: Custom VJP implementations with O(N) memory complexity for attention
78
+ - **Distributed Support**: Full shard_map integration for model and data parallelism
79
+ - **Mixed Precision**: Comprehensive dtype support with automatic gradient conversion
80
+
81
+ ### Production-Ready Infrastructure
82
+
83
+ - **Type Safety**: Full jaxtyping annotations with runtime validation via beartype
84
+ - **Comprehensive Testing**: Cross-backend validation, performance benchmarks, integration tests
85
+ - **Atomic Persistence**: Thread-safe configuration storage with automatic optimization
86
+ - **Profiling Integration**: Built-in support for JAX profiling and performance monitoring
87
+
88
+ ## Installation
89
+
90
+ ### Basic Installation
91
+
92
+ ```bash
93
+ pip install ejkernel
94
+ ```
95
+
96
+ ### Platform-Specific Installation
97
+
98
+ ```bash
99
+ # GPU Support (CUDA/ROCm)
100
+ pip install ejkernel[gpu]
101
+
102
+ # TPU Support
103
+ pip install ejkernel[tpu]
104
+
105
+ # Development Installation
106
+ git clone https://github.com/erfanzar/ejkernel.git
107
+ cd ejkernel
108
+ pip install -e ".[dev]"
109
+ ```
110
+
111
+ ### Dependencies
112
+
113
+ - Python 3.11-3.13
114
+ - JAX >= 0.8.0
115
+ - Triton == 3.4.0 (for GPU)
116
+ - jaxtyping >= 0.3.2
117
+ - beartype >= 0.22.2
118
+
119
+ ## Quick Start
120
+
121
+ ### Simple API with Automatic Optimization
122
+
123
+ ```python
124
+ import jax.numpy as jnp
125
+ from ejkernel.modules import flash_attention
126
+
127
+ # Basic usage - automatic configuration selection
128
+ output = flash_attention(
129
+ query, key, value,
130
+ causal=True,
131
+ dropout_prob=0.1
132
+ )
133
+
134
+ # With advanced features
135
+ output = flash_attention(
136
+ query, key, value,
137
+ causal=True,
138
+ sliding_window=128, # Local attention window
139
+ logits_soft_cap=30.0, # Gemma-2 style soft capping
140
+ attention_mask=mask, # Custom attention pattern
141
+ )
142
+ ```
143
+
144
+ ### Custom Configuration
145
+
146
+ ```python
147
+ from ejkernel.modules import FlashAttentionConfig
148
+ from ejkernel.ops.utils.datacarrier import FwdParams, BwdParams
149
+
150
+ # Create optimized configuration
151
+ config = FlashAttentionConfig(
152
+ fwd_params=FwdParams(
153
+ q_blocksize=256,
154
+ kv_blocksize=256,
155
+ num_warps=8,
156
+ num_stages=2
157
+ ),
158
+ bwd_params=BwdParams(
159
+ q_blocksize=128,
160
+ kv_blocksize=128,
161
+ num_warps=4
162
+ ),
163
+ platform="triton", # Force specific backend
164
+ backend="gpu"
165
+ )
166
+
167
+ output = flash_attention(query, key, value, cfg=config)
168
+ ```
169
+
170
+ ### Direct Kernel Registry Access
171
+
172
+ ```python
173
+ from ejkernel import kernel_registry, Platform, Backend
174
+
175
+ # Get specific implementation
176
+ kernel = kernel_registry.get(
177
+ algorithm="flash_attention",
178
+ platform=Platform.TRITON,
179
+ backend=Backend.GPU
180
+ )
181
+
182
+ # Direct execution
183
+ output = kernel(query, key, value, causal=True)
184
+ ```
185
+
186
+ ### Distributed Execution
187
+
188
+ ```python
189
+ import jax
+ import numpy as np
190
+ from jax.sharding import Mesh, PartitionSpec as P
191
+ from ejkernel.modules import flash_attention
192
+
193
+ # Setup mesh for distributed execution
194
+ devices = np.array(jax.devices()).reshape(-1, 1)  # 2D device grid matching the two mesh axes
195
+ mesh = Mesh(devices, axis_names=("data", "model"))
196
+
197
+ # Run distributed attention
198
+ output = flash_attention(
199
+ query, key, value,
200
+ causal=True,
201
+ mesh=mesh,
202
+ in_specs=(P("data", None), P("data", None), P("data", None)),
203
+ out_specs=P("data", None)
204
+ )
205
+ ```
206
+
207
+ ## Architecture Overview
208
+
209
+ ### System Design
210
+
211
+ ejKernel employs a sophisticated layered architecture that separates concerns while maintaining high performance:
212
+
213
+ ```md
214
+ ┌─────────────────────────────────────────────────────┐
215
+ │ Public API (modules/) │
216
+ │ Simple functions with sensible defaults │
217
+ ├─────────────────────────────────────────────────────┤
218
+ │ Operations Layer (ops/) │
219
+ │ Configuration management, autotuning, caching │
220
+ ├─────────────────────────────────────────────────────┤
221
+ │ Kernel Registry (kernels/) │
222
+ │ Platform routing, signature validation │
223
+ ├─────────────────────────────────────────────────────┤
224
+ │ Backend Implementations (kernels/_*) │
225
+ │ Triton, Pallas, XLA, CUDA kernels │
226
+ └─────────────────────────────────────────────────────┘
227
+ ```
228
+
229
+ ### Project Structure
230
+
231
+ ```md
232
+ ejkernel/
233
+ ├── kernels/
234
+ │ ├── _triton/ # GPU kernels via Triton
235
+ │ ├── _pallas/ # TPU/GPU kernels via Pallas
236
+ │ │ ├── tpu/ # TPU-specific implementations
237
+ │ │ └── gpu/ # GPU Pallas implementations
238
+ │ ├── _xla/ # Universal XLA implementations
239
+ │ └── _cuda/ # Native CUDA kernels
240
+ ├── modules/
241
+ │ └── operations/ # High-level API modules
242
+ ├── ops/
243
+ │ ├── config/ # Configuration management
244
+ │ ├── core/ # Base kernel classes
245
+ │ ├── execution/ # Execution orchestration
246
+ │ └── utils/ # Fingerprinting, utilities
247
+ ├── xla_utils/ # XLA-specific utilities
248
+ └── callib/ # Kernel call utilities (Triton call wrappers, ejit)
249
+ ```
250
+
251
+ ### Core Components
252
+
253
+ #### Kernel Registry
254
+
255
+ The registry provides automatic platform-specific kernel selection:
256
+
257
+ ```python
258
+ @kernel_registry.register("my_operation", Platform.TRITON, Backend.GPU, priority=100)
259
+ def my_operation_gpu(x, y):
260
+ # GPU-optimized implementation
261
+ pass
262
+
263
+ @kernel_registry.register("my_operation", Platform.XLA, Backend.ANY, priority=50)
264
+ def my_operation_fallback(x, y):
265
+ # Universal fallback
266
+ pass
267
+
268
+ # Automatic selection based on available hardware
269
+ impl = kernel_registry.get("my_operation")
270
+ ```
271
+
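+ When both implementations are importable, the higher-priority Triton version wins on GPU hosts, while the XLA entry acts as the universal fallback. A specific implementation can still be requested explicitly, mirroring the registry lookup shown earlier (a usage sketch, assuming the two registrations above are in place):
+
+ ```python
+ # Explicitly pick the universal XLA fallback instead of the GPU kernel.
+ impl = kernel_registry.get("my_operation", platform=Platform.XLA, backend=Backend.ANY)
+ output = impl(x, y)
+ ```
+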
272
+ #### Configuration Management
273
+
274
+ Multi-tier configuration system with intelligent fallback:
275
+
276
+ ```python
277
+ class ConfigSelectorChain:
278
+ """
279
+ Selection hierarchy:
280
+ 1. Override - Explicit user configuration
281
+ 2. Overlay - Temporary context overrides
282
+ 3. Memory Cache - In-memory lookup
283
+ 4. Persistent Cache - Disk-based storage
284
+ 5. Autotune - Performance benchmarking
285
+ 6. Heuristics - Intelligent defaults
286
+ 7. Error - Clear failure message
287
+ """
288
+ ```
289
+
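+ The same hierarchy can be read as ordered lookup with fallthrough. The sketch below is illustrative pseudologic only; `resolve_config` and the `sources.*` helpers are hypothetical names, not part of ejKernel's API:
+
+ ```python
+ # Illustrative only: shows the resolution order, not ejKernel internals.
+ def resolve_config(invocation, sources):
+     for tier in (
+         sources.override,          # 1. explicit user configuration (e.g. cfg=...)
+         sources.overlay,           # 2. temporary context override
+         sources.memory_cache,      # 3. in-memory lookup for this invocation
+         sources.persistent_cache,  # 4. on-disk cache keyed by device fingerprint
+         sources.autotune,          # 5. benchmark candidate configs, keep the best
+         sources.heuristics,        # 6. shape/dtype-based defaults
+     ):
+         cfg = tier(invocation)
+         if cfg is not None:
+             return cfg
+     raise RuntimeError("no configuration could be selected")  # 7. clear failure
+ ```
+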
290
+ #### Custom VJP System
291
+
292
+ All performance-critical kernels implement memory-efficient gradients:
293
+
294
+ ```python
295
+ @jax.custom_vjp
296
+ def kernel_with_custom_grad(inputs):
297
+ return forward(inputs)
298
+
299
+ def kernel_fwd(inputs):
300
+ output, residuals = forward_with_residuals(inputs)
301
+ return output, residuals
302
+
303
+ def kernel_bwd(residuals, grad_output):
304
+ return efficient_backward(residuals, grad_output)
305
+
306
+ kernel_with_custom_grad.defvjp(kernel_fwd, kernel_bwd)
307
+ ```
308
+
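+ For readers new to `jax.custom_vjp`, here is a self-contained toy that exercises the same wiring (this is not one of ejKernel's kernels; the function and its math are purely illustrative):
+
+ ```python
+ import jax
+ import jax.numpy as jnp
+
+ @jax.custom_vjp
+ def scaled_square(x, scale):
+     return scale * x * x
+
+ def scaled_square_fwd(x, scale):
+     out = scale * x * x
+     return out, (x, scale)  # save only what the backward pass needs
+
+ def scaled_square_bwd(residuals, g):
+     x, scale = residuals
+     # One gradient per primal input, in order: (d/dx, d/dscale).
+     return (2.0 * scale * x * g, jnp.sum(x * x * g))
+
+ scaled_square.defvjp(scaled_square_fwd, scaled_square_bwd)
+
+ grads = jax.grad(lambda x: scaled_square(x, 3.0).sum())(jnp.arange(4.0))
+ print(grads)  # [ 0.  6. 12. 18.]
+ ```
+
+ In Flash-style attention backward passes, the residuals are typically the attention output and a per-row log-sum-exp rather than the full score matrix, which is what keeps gradient memory linear in sequence length.
+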
309
+ ## Supported Operations
310
+
311
+ ### Attention Mechanisms
312
+
313
+ | Algorithm | Description | Memory | Key Features |
314
+ |-----------|-------------|--------|--------------|
315
+ | **Flash Attention v2** | Memory-efficient exact attention | O(N) | Causal masking, dropout, sliding windows, soft capping |
316
+ | **Ring Attention** | Distributed sequence parallelism | O(N/P) | Ultra-long sequences, communication overlap |
317
+ | **Page Attention** | KV-cache optimized inference | O(N) | Block-wise memory, continuous batching |
318
+ | **Block Sparse Attention** | Configurable sparse patterns | O(N√N) | Local+global, custom patterns |
319
+ | **GLA** | Gated Linear Attention | O(N) | Linear complexity, gated updates |
320
+ | **Lightning Attention** | Layer-dependent decay | O(N) | Exponential moving average |
321
+ | **MLA** | Multi-head Latent Attention | O(N) | Compressed KV representation |
322
+ | **Ragged Attention** | Variable-length sequences | O(N) | Efficient padding, batched inference |
323
+
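+ The O(N) memory column for the exact-attention kernels comes from Flash-style tiling with an online softmax: key/value blocks are streamed through on-chip memory while only running statistics are kept, so the full N×N score matrix is never materialized. Schematically, after processing score block $i$ for one query row:
+
+ $$
+ m_i = \max\big(m_{i-1},\ \max_j s_{ij}\big), \qquad
+ \ell_i = \ell_{i-1}\,e^{m_{i-1}-m_i} + \sum_j e^{s_{ij}-m_i}, \qquad
+ o_i = o_{i-1}\,e^{m_{i-1}-m_i} + \sum_j e^{s_{ij}-m_i}\,v_j,
+ $$
+
+ and the final output is o/ℓ after the last block, so per-row state is just a running max, a normalizer, and one accumulator vector.
+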
324
+ ### Other Operations
325
+
326
+ - **Recurrent Kernels**: Optimized RNN/LSTM/GRU operations
327
+ - **Mean Pooling**: Variable-length sequence aggregation (reference semantics sketched below)
328
+ - **Grouped MatMul**: Efficient batched matrix operations
329
+ - **Native Sparse**: Block-sparse matrix computations
330
+
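+ As a point of reference for the mean-pooling kernel, the computation it fuses is a length-masked average. A plain `jax.numpy` version (an illustrative reference, not the library's API) looks like this:
+
+ ```python
+ import jax.numpy as jnp
+
+ def mean_pool_reference(x, lengths):
+     """Average over the valid (unpadded) positions of each sequence.
+
+     x:       (batch, seq_len, hidden) padded activations
+     lengths: (batch,) number of valid tokens per sequence
+     """
+     positions = jnp.arange(x.shape[1])[None, :]              # (1, seq_len)
+     mask = (positions < lengths[:, None]).astype(x.dtype)    # (batch, seq_len)
+     summed = jnp.einsum("bsh,bs->bh", x, mask)               # masked sum over time
+     return summed / lengths[:, None].astype(x.dtype)
+ ```
+
+ The library's mean-pooling kernels should produce the same result as this reference.
+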
331
+ ### Platform Support Matrix
332
+
333
+ | Operation | Triton (GPU) | Pallas (TPU) | XLA (Universal) | CUDA |
334
+ |-----------|-------------|--------------|-----------------|------|
335
+ | Flash Attention v2 | ✓ | ✓ | ✓ | Dev |
336
+ | Ring Attention | ✓ | ✓ | ✓ | Dev |
337
+ | Page Attention | ✓ | ✓ | ✓ | Dev |
338
+ | Block Sparse | ✓ | - | ✓ | Dev |
339
+ | GLA | ✓ | Dev | ✓ | - |
340
+ | Lightning | ✓ | - | ✓ | Dev |
341
+ | MLA | ✓ | Dev | - | - |
342
+ | Ragged Attention | ✓ | ✓ | ✓ | Dev |
343
+
344
+ ✓ = Production ready | Dev = Under development | - = Not planned
345
+
346
+ ## Advanced Usage
347
+
348
+ ### Performance Optimization
349
+
350
+ ```python
351
+ # Force autotuning for optimal configuration
352
+ import os
353
+ os.environ["EJKERNEL_AUTOTUNE_POLICY"] = "autotune"
354
+ os.environ["EJKERNEL_LOG_AUTOTUNE"] = "1"
355
+
356
+ # Enable profiling
357
+ os.environ["EJKERNEL_OPS_STAMP"] = "json" # Detailed metadata
358
+ os.environ["EJKERNEL_OPS_RECORD"] = "1" # Record invocations
359
+ ```
360
+
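+ Because autotuning benchmarks candidate configurations the first time a shape is seen, it is common to warm the cache on representative inputs before timing or serving. A minimal sketch (tensor shapes, layout, and dtype below are placeholders; check the API reference for the expected layout):
+
+ ```python
+ import os
+
+ # Must be set before the first kernel invocation in this process.
+ os.environ["EJKERNEL_AUTOTUNE_POLICY"] = "autotune"
+ os.environ["EJKERNEL_LOG_AUTOTUNE"] = "1"
+
+ import jax.numpy as jnp
+ from ejkernel.modules import flash_attention
+
+ # Representative shapes for the target workload (placeholders).
+ batch, seq, heads, head_dim = 4, 2048, 16, 64
+ q = jnp.zeros((batch, seq, heads, head_dim), dtype=jnp.bfloat16)
+ k = jnp.zeros((batch, seq, heads, head_dim), dtype=jnp.bfloat16)
+ v = jnp.zeros((batch, seq, heads, head_dim), dtype=jnp.bfloat16)
+
+ # The first call pays the tuning cost; the winning config is cached in memory
+ # and in the persistent cache for later runs on the same hardware.
+ flash_attention(q, k, v, causal=True).block_until_ready()
+ ```
+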
361
+ ### Custom Kernel Development
362
+
363
+ ```python
364
+ from dataclasses import dataclass
+
+ from jaxtyping import Array
+
+ from ejkernel import kernel_registry
+ from ejkernel.ops.core import Kernel
365
+ from ejkernel.modules.operations.configs import BaseOperationConfig
366
+
367
+ @dataclass
368
+ class MyConfig(BaseOperationConfig):
369
+ param1: int = 128
370
+ param2: float = 0.1
371
+
372
+ class MyKernel(Kernel[MyConfig, Array]):
373
+ def __init__(self):
374
+ super().__init__(op_id="my_kernel")
375
+
376
+ def run(self, x, cfg: MyConfig):
377
+ impl = kernel_registry.get("my_kernel", cfg.platform)
378
+ return impl(x, param1=cfg.param1, param2=cfg.param2)
379
+
380
+ def heuristic_cfg(self, inv):
381
+ # Return default configuration
382
+ return MyConfig(param1=256)
383
+
384
+ def candidate_cfgs(self, inv):
385
+ # Return autotuning candidates
386
+ return [MyConfig(param1=p) for p in [64, 128, 256]]
387
+ ```
388
+
389
+ ### Integration with Models
390
+
391
+ ```python
392
+ import flax.linen as nn
+
+ from ejkernel.modules import flash_attention
393
+
394
+ class TransformerBlock(nn.Module):
395
+ num_heads: int = 8
396
+ head_dim: int = 64
397
+
398
+ @nn.compact
399
+ def __call__(self, x, mask=None):
400
+ # Project to Q, K, V
401
+ q = nn.Dense(self.num_heads * self.head_dim)(x)
402
+ k = nn.Dense(self.num_heads * self.head_dim)(x)
403
+ v = nn.Dense(self.num_heads * self.head_dim)(x)
404
+
405
+ # Reshape for attention
406
+ shape = (x.shape[0], x.shape[1], self.num_heads, self.head_dim)
407
+ q, k, v = map(lambda t: t.reshape(shape), (q, k, v))
408
+
409
+ # Apply ejKernel Flash Attention
410
+ attn_output = flash_attention(
411
+ q, k, v,
412
+ causal=True,
413
+ attention_mask=mask
414
+ )
415
+
416
+ # Project output
417
+         return nn.Dense(x.shape[-1])(attn_output.reshape(x.shape[0], x.shape[1], -1))
418
+ ```
419
+
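+ A quick usage sketch for the block above (the input shape is a placeholder):
+
+ ```python
+ import jax
+ import jax.numpy as jnp
+
+ x = jnp.zeros((2, 128, 512))  # (batch, seq_len, model_dim)
+ block = TransformerBlock()
+ params = block.init(jax.random.PRNGKey(0), x)
+ y = block.apply(params, x)    # (2, 128, 512)
+ ```
+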
420
+ ## Performance
+
+ Benchmark results are published on the [Benchmarks](https://ejkernel.readthedocs.io/en/latest/benchmarks/) page of the documentation, and the `ejkernel/benchmarks.py` module contains the package's benchmarking code.
+
422
+ ## Development
423
+
424
+ ### Setting Up Development Environment
425
+
426
+ ```bash
427
+ # Clone repository
428
+ git clone https://github.com/erfanzar/ejkernel.git
429
+ cd ejkernel
430
+
431
+ # Create virtual environment
432
+ python -m venv .venv
433
+ source .venv/bin/activate # On Windows: .venv\Scripts\activate
434
+
435
+ # Install in development mode
436
+ pip install -e ".[dev]"
437
+
438
+ # Install pre-commit hooks
439
+ pre-commit install
440
+ ```
441
+
442
+ ### Code Style
443
+
444
+ The project uses:
445
+
446
+ - **black** for code formatting (line length: 121)
447
+ - **ruff** for linting
448
+ - **mypy/pyright** for type checking
449
+ - **pre-commit** for automated checks
450
+
451
+ ### Adding New Kernels
452
+
453
+ 1. **Implement the kernel** in appropriate backend directory:
454
+
455
+ ```python
456
+ # ejkernel/kernels/_triton/my_kernel.py
457
+ @kernel_registry.register("my_kernel", Platform.TRITON, Backend.GPU)
458
+ def my_kernel_triton(x, config):
459
+ # Implementation
460
+ pass
461
+ ```
462
+
463
+ 2. **Create module wrapper**:
464
+
465
+ ```python
466
+ # ejkernel/modules/operations/my_kernel.py
467
+ class MyKernel(Kernel[MyKernelConfig, Array]):
468
+ # Module implementation
469
+ pass
470
+ ```
471
+
472
+ 3. **Add tests**:
473
+
474
+ ```python
475
+ # test/kernels/_triton/test_my_kernel.py
476
+ class TestMyKernel(unittest.TestCase):
477
+ # Test implementation
478
+ pass
479
+ ```
480
+
481
+ 4. **Update documentation**
482
+
483
+ ## Testing
484
+
485
+ ### Running Tests
486
+
487
+ ```bash
488
+ # Run all tests
489
+ python test/run_tests.py
490
+
491
+ # Platform-specific tests
492
+ python test/run_tests.py --xla # XLA implementations
493
+ python test/run_tests.py --triton # Triton implementations
494
+ python test/run_tests.py --pallas # Pallas implementations
495
+
496
+ # Cross-platform validation
497
+ python test/run_tests.py --comparison
498
+
499
+ # Specific test patterns
500
+ python test/run_tests.py -k "flash_attention"
501
+ python test/run_tests.py --verbose --failfast
502
+ ```
503
+
504
+ ### Test Categories
505
+
506
+ - **Unit Tests**: Individual component testing
507
+ - **Integration Tests**: End-to-end workflows
508
+ - **Comparison Tests**: Cross-backend consistency
509
+ - **Performance Tests**: Regression detection
510
+ - **Property Tests**: Invariant verification
511
+
512
+ ### Continuous Integration
513
+
514
+ The project uses GitHub Actions for CI with tests across:
515
+
516
+ - Multiple Python versions (3.11, 3.12, 3.13)
517
+ - Multiple platforms (CPU, GPU, TPU)
518
+ - Multiple JAX versions
519
+
520
+ ## Contributing
521
+
522
+ We welcome contributions! See [CONTRIBUTING.md](CONTRIBUTING.md) for guidelines.
523
+
524
+ ### Priority Areas
525
+
526
+ - TPU/Pallas implementations for existing algorithms
527
+ - CUDA native kernels for maximum performance
528
+ - New attention mechanisms from recent papers
529
+ - Performance optimizations and kernel fusion
530
+ - Documentation and examples
531
+
532
+ ### Contribution Process
533
+
534
+ 1. Fork the repository
535
+ 2. Create a feature branch
536
+ 3. Implement your changes with tests
537
+ 4. Ensure all tests pass
538
+ 5. Submit a pull request
539
+
540
+ ## Documentation
541
+
542
+ Comprehensive documentation is available at [ejkernel.readthedocs.io](https://ejkernel.readthedocs.io/en/latest/):
543
+
544
+ - **[API Reference](https://ejkernel.readthedocs.io/en/latest/api/)**: Complete API documentation
545
+ - **[Tutorials](https://ejkernel.readthedocs.io/en/latest/tutorials/)**: Step-by-step guides
546
+ - **[Architecture](https://ejkernel.readthedocs.io/en/latest/architecture/)**: Design documentation
547
+ - **[Benchmarks](https://ejkernel.readthedocs.io/en/latest/benchmarks/)**: Performance analysis
548
+
549
+ ## Citation
550
+
551
+ If you use ejKernel in your research, please cite:
552
+
553
+ ```bibtex
554
+ @software{ejkernel2024,
555
+ author = {Erfan Zare Chavoshi},
556
+ title = {ejKernel: High-Performance JAX Kernels for Deep Learning},
557
+ year = {2024},
558
+ url = {https://github.com/erfanzar/ejkernel},
559
+ note = {Production-grade kernel library with multi-backend support}
560
+ }
561
+ ```
562
+
563
+ ## License
564
+
565
+ ejKernel is licensed under the Apache License 2.0. See [LICENSE](LICENSE) for details.
566
+
567
+ ## Acknowledgments
568
+
569
+ ejKernel builds upon excellent work from:
570
+
571
+ - [JAX](https://github.com/google/jax) - Composable transformations of Python+NumPy programs
572
+ - [Triton](https://github.com/openai/triton) - GPU kernel programming language
573
+ - [Pallas](https://github.com/google/jax/tree/main/jax/experimental/pallas) - JAX kernel language
574
+ - [Flash Attention](https://github.com/Dao-AILab/flash-attention) - Memory-efficient attention
575
+ - [EasyDeL](https://github.com/erfanzar/EasyDeL) - Parent framework for JAX deep learning
576
+
577
+ ## Community
578
+
579
+ - **GitHub Issues**: [Bug reports and feature requests](https://github.com/erfanzar/ejkernel/issues)
580
+ - **Discussions**: [Community forum](https://github.com/erfanzar/ejkernel/discussions)
581
+ - **Email**: <Erfanzare810@gmail.com>
582
+
583
+ ## Roadmap
584
+
585
+ ### Near Term (Q1 2025)
586
+
587
+ - Flash Attention 3 implementation
588
+ - Complete CUDA backend
589
+ - Quantized attention (INT8/INT4)
590
+ - Fused operations (LayerNorm+Attention)
591
+
592
+ ### Medium Term (Q2-Q3 2025)
593
+
594
+ - Speculative decoding support
595
+ - Continuous batching
596
+ - Mamba SSM kernels
597
+
598
+ ### Long Term (Q4 2025+)
599
+
600
+ - Multi-GPU kernel fusion
601
+ - Automatic kernel selection ML model
602
+ - Custom DSL for kernel development
603
+ - Hardware-agnostic IR
604
+
605
+ ---
606
+
607
+ ejKernel - Production-grade kernels for JAX deep learning