ejkernel-0.0.21.tar.gz

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (192)
  1. ejkernel-0.0.21/PKG-INFO +711 -0
  2. ejkernel-0.0.21/README.md +673 -0
  3. ejkernel-0.0.21/ejkernel/__init__.py +51 -0
  4. ejkernel-0.0.21/ejkernel/benchmarks.py +772 -0
  5. ejkernel-0.0.21/ejkernel/callib/__init__.py +43 -0
  6. ejkernel-0.0.21/ejkernel/callib/_ejit.py +521 -0
  7. ejkernel-0.0.21/ejkernel/callib/_pallas_call.py +173 -0
  8. ejkernel-0.0.21/ejkernel/callib/_triton_call.py +891 -0
  9. ejkernel-0.0.21/ejkernel/callib/_utils.py +243 -0
  10. ejkernel-0.0.21/ejkernel/kernels/__init__.py +65 -0
  11. ejkernel-0.0.21/ejkernel/kernels/_cuda/__init__.py +15 -0
  12. ejkernel-0.0.21/ejkernel/kernels/_pallas/__init__.py +18 -0
  13. ejkernel-0.0.21/ejkernel/kernels/_pallas/gpu/__init__.py +19 -0
  14. ejkernel-0.0.21/ejkernel/kernels/_pallas/gpu/ragged_decode_attention/__init__.py +18 -0
  15. ejkernel-0.0.21/ejkernel/kernels/_pallas/gpu/ragged_decode_attention/_interface.py +81 -0
  16. ejkernel-0.0.21/ejkernel/kernels/_pallas/gpu/ragged_decode_attention/_pallas_impl_fwd.py +382 -0
  17. ejkernel-0.0.21/ejkernel/kernels/_pallas/gpu/scaled_dot_product_attention/__init__.py +25 -0
  18. ejkernel-0.0.21/ejkernel/kernels/_pallas/gpu/scaled_dot_product_attention/_interface.py +131 -0
  19. ejkernel-0.0.21/ejkernel/kernels/_pallas/tpu/__init__.py +36 -0
  20. ejkernel-0.0.21/ejkernel/kernels/_pallas/tpu/blocksparse_attention/__init__.py +66 -0
  21. ejkernel-0.0.21/ejkernel/kernels/_pallas/tpu/blocksparse_attention/_info.py +975 -0
  22. ejkernel-0.0.21/ejkernel/kernels/_pallas/tpu/blocksparse_attention/_kernel.py +2573 -0
  23. ejkernel-0.0.21/ejkernel/kernels/_pallas/tpu/blocksparse_attention/_masks.py +546 -0
  24. ejkernel-0.0.21/ejkernel/kernels/_pallas/tpu/flash_attention/__init__.py +24 -0
  25. ejkernel-0.0.21/ejkernel/kernels/_pallas/tpu/flash_attention/_interface.py +188 -0
  26. ejkernel-0.0.21/ejkernel/kernels/_pallas/tpu/flash_attention/_pallas_impl_bwd.py +676 -0
  27. ejkernel-0.0.21/ejkernel/kernels/_pallas/tpu/flash_attention/_pallas_impl_fwd.py +451 -0
  28. ejkernel-0.0.21/ejkernel/kernels/_pallas/tpu/flash_attention/_utils.py +373 -0
  29. ejkernel-0.0.21/ejkernel/kernels/_pallas/tpu/grouped_matmul/__init__.py +18 -0
  30. ejkernel-0.0.21/ejkernel/kernels/_pallas/tpu/grouped_matmul/_interface.py +250 -0
  31. ejkernel-0.0.21/ejkernel/kernels/_pallas/tpu/grouped_matmul/_pallas_impl.py +979 -0
  32. ejkernel-0.0.21/ejkernel/kernels/_pallas/tpu/grouped_matmul/_utils.py +191 -0
  33. ejkernel-0.0.21/ejkernel/kernels/_pallas/tpu/grouped_matmulv2/__init__.py +18 -0
  34. ejkernel-0.0.21/ejkernel/kernels/_pallas/tpu/grouped_matmulv2/_interface.py +250 -0
  35. ejkernel-0.0.21/ejkernel/kernels/_pallas/tpu/grouped_matmulv2/_pallas_impl.py +545 -0
  36. ejkernel-0.0.21/ejkernel/kernels/_pallas/tpu/page_attention/__init__.py +18 -0
  37. ejkernel-0.0.21/ejkernel/kernels/_pallas/tpu/page_attention/_interface.py +253 -0
  38. ejkernel-0.0.21/ejkernel/kernels/_pallas/tpu/page_attention/_pallas_impl_fwd.py +290 -0
  39. ejkernel-0.0.21/ejkernel/kernels/_pallas/tpu/ragged_decode_attention/__init__.py +18 -0
  40. ejkernel-0.0.21/ejkernel/kernels/_pallas/tpu/ragged_decode_attention/_interface.py +74 -0
  41. ejkernel-0.0.21/ejkernel/kernels/_pallas/tpu/ragged_decode_attention/_pallas_impl_fwd.py +247 -0
  42. ejkernel-0.0.21/ejkernel/kernels/_pallas/tpu/ragged_page_attention_v2/__init__.py +20 -0
  43. ejkernel-0.0.21/ejkernel/kernels/_pallas/tpu/ragged_page_attention_v2/_interface.py +257 -0
  44. ejkernel-0.0.21/ejkernel/kernels/_pallas/tpu/ragged_page_attention_v2/_pallas_impl_fwd.py +602 -0
  45. ejkernel-0.0.21/ejkernel/kernels/_pallas/tpu/ragged_page_attention_v2/_utils.py +437 -0
  46. ejkernel-0.0.21/ejkernel/kernels/_pallas/tpu/ragged_page_attention_v3/__init__.py +23 -0
  47. ejkernel-0.0.21/ejkernel/kernels/_pallas/tpu/ragged_page_attention_v3/_interface.py +120 -0
  48. ejkernel-0.0.21/ejkernel/kernels/_pallas/tpu/ragged_page_attention_v3/_pallas_impl_fwd.py +1300 -0
  49. ejkernel-0.0.21/ejkernel/kernels/_pallas/tpu/ragged_page_attention_v3/_pallas_impl_fwd_h64.py +1250 -0
  50. ejkernel-0.0.21/ejkernel/kernels/_pallas/tpu/ragged_page_attention_v3/_utils.py +4545 -0
  51. ejkernel-0.0.21/ejkernel/kernels/_pallas/tpu/ring_attention/__init__.py +24 -0
  52. ejkernel-0.0.21/ejkernel/kernels/_pallas/tpu/ring_attention/_interface.py +225 -0
  53. ejkernel-0.0.21/ejkernel/kernels/_pallas/tpu/ring_attention/_pallas_impl_bwd.py +1003 -0
  54. ejkernel-0.0.21/ejkernel/kernels/_pallas/tpu/ring_attention/_pallas_impl_fwd.py +769 -0
  55. ejkernel-0.0.21/ejkernel/kernels/_pallas/tpu/ring_attention/_utils.py +192 -0
  56. ejkernel-0.0.21/ejkernel/kernels/_registry.py +377 -0
  57. ejkernel-0.0.21/ejkernel/kernels/_triton/__init__.py +59 -0
  58. ejkernel-0.0.21/ejkernel/kernels/_triton/blocksparse_attention/__init__.py +33 -0
  59. ejkernel-0.0.21/ejkernel/kernels/_triton/blocksparse_attention/_interface.py +462 -0
  60. ejkernel-0.0.21/ejkernel/kernels/_triton/blocksparse_attention/_mask.py +530 -0
  61. ejkernel-0.0.21/ejkernel/kernels/_triton/blocksparse_attention/_triton_impl_bwd.py +1380 -0
  62. ejkernel-0.0.21/ejkernel/kernels/_triton/blocksparse_attention/_triton_impl_fwd.py +695 -0
  63. ejkernel-0.0.21/ejkernel/kernels/_triton/blocksparse_attention/_utilities.py +437 -0
  64. ejkernel-0.0.21/ejkernel/kernels/_triton/flash_attention/__init__.py +24 -0
  65. ejkernel-0.0.21/ejkernel/kernels/_triton/flash_attention/_interface.py +380 -0
  66. ejkernel-0.0.21/ejkernel/kernels/_triton/flash_attention/_triton_impl_bwd.py +1494 -0
  67. ejkernel-0.0.21/ejkernel/kernels/_triton/flash_attention/_triton_impl_fwd.py +827 -0
  68. ejkernel-0.0.21/ejkernel/kernels/_triton/flash_attention/_utilities.py +358 -0
  69. ejkernel-0.0.21/ejkernel/kernels/_triton/flash_mla/__init__.py +38 -0
  70. ejkernel-0.0.21/ejkernel/kernels/_triton/flash_mla/_interface.py +100 -0
  71. ejkernel-0.0.21/ejkernel/kernels/_triton/flash_mla/_triton_impl_bwd.py +15 -0
  72. ejkernel-0.0.21/ejkernel/kernels/_triton/flash_mla/_triton_impl_fwd.py +180 -0
  73. ejkernel-0.0.21/ejkernel/kernels/_triton/flash_mla/_utilities.py +15 -0
  74. ejkernel-0.0.21/ejkernel/kernels/_triton/gla/__init__.py +18 -0
  75. ejkernel-0.0.21/ejkernel/kernels/_triton/gla/_interface.py +136 -0
  76. ejkernel-0.0.21/ejkernel/kernels/_triton/lightning_attn/__init__.py +18 -0
  77. ejkernel-0.0.21/ejkernel/kernels/_triton/lightning_attn/_interface.py +143 -0
  78. ejkernel-0.0.21/ejkernel/kernels/_triton/mean_pooling/__init__.py +30 -0
  79. ejkernel-0.0.21/ejkernel/kernels/_triton/mean_pooling/_interface.py +163 -0
  80. ejkernel-0.0.21/ejkernel/kernels/_triton/mean_pooling/_triton_impl_bwd.py +130 -0
  81. ejkernel-0.0.21/ejkernel/kernels/_triton/mean_pooling/_triton_impl_fwd.py +126 -0
  82. ejkernel-0.0.21/ejkernel/kernels/_triton/native_sparse_attention/__init__.py +32 -0
  83. ejkernel-0.0.21/ejkernel/kernels/_triton/native_sparse_attention/_compression.py +753 -0
  84. ejkernel-0.0.21/ejkernel/kernels/_triton/native_sparse_attention/_interface.py +403 -0
  85. ejkernel-0.0.21/ejkernel/kernels/_triton/native_sparse_attention/_triton_impl_bwd.py +394 -0
  86. ejkernel-0.0.21/ejkernel/kernels/_triton/native_sparse_attention/_triton_impl_fwd.py +428 -0
  87. ejkernel-0.0.21/ejkernel/kernels/_triton/native_sparse_attention/_utilities.py +87 -0
  88. ejkernel-0.0.21/ejkernel/kernels/_triton/page_attention/__init__.py +18 -0
  89. ejkernel-0.0.21/ejkernel/kernels/_triton/page_attention/_interface.py +369 -0
  90. ejkernel-0.0.21/ejkernel/kernels/_triton/page_attention/_triton_impl_fwd.py +257 -0
  91. ejkernel-0.0.21/ejkernel/kernels/_triton/ragged_decode_attention/__init__.py +18 -0
  92. ejkernel-0.0.21/ejkernel/kernels/_triton/ragged_decode_attention/_interface.py +75 -0
  93. ejkernel-0.0.21/ejkernel/kernels/_triton/ragged_decode_attention/_triton_impl_fwd.py +302 -0
  94. ejkernel-0.0.21/ejkernel/kernels/_triton/ragged_page_attention_v2/__init__.py +18 -0
  95. ejkernel-0.0.21/ejkernel/kernels/_triton/ragged_page_attention_v2/_interface.py +226 -0
  96. ejkernel-0.0.21/ejkernel/kernels/_triton/ragged_page_attention_v2/_triton_impl_fwd.py +613 -0
  97. ejkernel-0.0.21/ejkernel/kernels/_triton/ragged_page_attention_v3/_triton_impl_fwd.py +222 -0
  98. ejkernel-0.0.21/ejkernel/kernels/_triton/recurrent/__init__.py +30 -0
  99. ejkernel-0.0.21/ejkernel/kernels/_triton/recurrent/_interface.py +291 -0
  100. ejkernel-0.0.21/ejkernel/kernels/_triton/recurrent/_triton_impl_bwd.py +354 -0
  101. ejkernel-0.0.21/ejkernel/kernels/_triton/recurrent/_triton_impl_fwd.py +216 -0
  102. ejkernel-0.0.21/ejkernel/kernels/_xla/__init__.py +46 -0
  103. ejkernel-0.0.21/ejkernel/kernels/_xla/attention/__init__.py +18 -0
  104. ejkernel-0.0.21/ejkernel/kernels/_xla/attention/_interface.py +193 -0
  105. ejkernel-0.0.21/ejkernel/kernels/_xla/flash_attention/__init__.py +24 -0
  106. ejkernel-0.0.21/ejkernel/kernels/_xla/flash_attention/_interface.py +376 -0
  107. ejkernel-0.0.21/ejkernel/kernels/_xla/flash_attention/_xla_impl_bwd.py +85 -0
  108. ejkernel-0.0.21/ejkernel/kernels/_xla/flash_attention/_xla_impl_fwd.py +474 -0
  109. ejkernel-0.0.21/ejkernel/kernels/_xla/gla/__init__.py +18 -0
  110. ejkernel-0.0.21/ejkernel/kernels/_xla/gla/_interface.py +123 -0
  111. ejkernel-0.0.21/ejkernel/kernels/_xla/grouped_matmul/__init__.py +18 -0
  112. ejkernel-0.0.21/ejkernel/kernels/_xla/grouped_matmul/_interface.py +140 -0
  113. ejkernel-0.0.21/ejkernel/kernels/_xla/lightning_attn/__init__.py +18 -0
  114. ejkernel-0.0.21/ejkernel/kernels/_xla/lightning_attn/_interface.py +129 -0
  115. ejkernel-0.0.21/ejkernel/kernels/_xla/mean_pooling/__init__.py +18 -0
  116. ejkernel-0.0.21/ejkernel/kernels/_xla/mean_pooling/_interface.py +171 -0
  117. ejkernel-0.0.21/ejkernel/kernels/_xla/native_sparse_attention/__init__.py +32 -0
  118. ejkernel-0.0.21/ejkernel/kernels/_xla/native_sparse_attention/_interface.py +435 -0
  119. ejkernel-0.0.21/ejkernel/kernels/_xla/native_sparse_attention/_xla_impl_bwd.py +138 -0
  120. ejkernel-0.0.21/ejkernel/kernels/_xla/native_sparse_attention/_xla_impl_fwd.py +86 -0
  121. ejkernel-0.0.21/ejkernel/kernels/_xla/page_attention/__init__.py +18 -0
  122. ejkernel-0.0.21/ejkernel/kernels/_xla/page_attention/_interface.py +123 -0
  123. ejkernel-0.0.21/ejkernel/kernels/_xla/page_attention/_xla_impl_fwd.py +97 -0
  124. ejkernel-0.0.21/ejkernel/kernels/_xla/ragged_decode_attention/__init__.py +18 -0
  125. ejkernel-0.0.21/ejkernel/kernels/_xla/ragged_decode_attention/_interface.py +119 -0
  126. ejkernel-0.0.21/ejkernel/kernels/_xla/ragged_decode_attention/_xla_impl_fwd.py +468 -0
  127. ejkernel-0.0.21/ejkernel/kernels/_xla/ragged_page_attention_v2/__init__.py +18 -0
  128. ejkernel-0.0.21/ejkernel/kernels/_xla/ragged_page_attention_v2/_interface.py +109 -0
  129. ejkernel-0.0.21/ejkernel/kernels/_xla/ragged_page_attention_v2/_xla_impl_fwd.py +191 -0
  130. ejkernel-0.0.21/ejkernel/kernels/_xla/ragged_page_attention_v3/__init__.py +18 -0
  131. ejkernel-0.0.21/ejkernel/kernels/_xla/ragged_page_attention_v3/_interface.py +98 -0
  132. ejkernel-0.0.21/ejkernel/kernels/_xla/ragged_page_attention_v3/_kernel.py +427 -0
  133. ejkernel-0.0.21/ejkernel/kernels/_xla/recurrent/__init__.py +30 -0
  134. ejkernel-0.0.21/ejkernel/kernels/_xla/recurrent/_interface.py +201 -0
  135. ejkernel-0.0.21/ejkernel/kernels/_xla/recurrent/_xla_impl_bwd.py +163 -0
  136. ejkernel-0.0.21/ejkernel/kernels/_xla/recurrent/_xla_impl_fwd.py +242 -0
  137. ejkernel-0.0.21/ejkernel/kernels/_xla/ring_attention/__init__.py +24 -0
  138. ejkernel-0.0.21/ejkernel/kernels/_xla/ring_attention/_interface.py +212 -0
  139. ejkernel-0.0.21/ejkernel/kernels/_xla/ring_attention/_utils.py +163 -0
  140. ejkernel-0.0.21/ejkernel/kernels/_xla/ring_attention/_xla_impl_bwd.py +354 -0
  141. ejkernel-0.0.21/ejkernel/kernels/_xla/ring_attention/_xla_impl_fwd.py +388 -0
  142. ejkernel-0.0.21/ejkernel/kernels/_xla/scaled_dot_product_attention/__init__.py +24 -0
  143. ejkernel-0.0.21/ejkernel/kernels/_xla/scaled_dot_product_attention/_interface.py +106 -0
  144. ejkernel-0.0.21/ejkernel/loggings.py +482 -0
  145. ejkernel-0.0.21/ejkernel/modules/__init__.py +144 -0
  146. ejkernel-0.0.21/ejkernel/modules/base.py +213 -0
  147. ejkernel-0.0.21/ejkernel/modules/operations/__init__.py +148 -0
  148. ejkernel-0.0.21/ejkernel/modules/operations/attention.py +286 -0
  149. ejkernel-0.0.21/ejkernel/modules/operations/blocksparse_attention.py +937 -0
  150. ejkernel-0.0.21/ejkernel/modules/operations/configs.py +354 -0
  151. ejkernel-0.0.21/ejkernel/modules/operations/flash_attention.py +886 -0
  152. ejkernel-0.0.21/ejkernel/modules/operations/gated_linear_attention.py +301 -0
  153. ejkernel-0.0.21/ejkernel/modules/operations/grouped_matmul.py +426 -0
  154. ejkernel-0.0.21/ejkernel/modules/operations/lightning_attention.py +306 -0
  155. ejkernel-0.0.21/ejkernel/modules/operations/multi_head_latent_attention.py +290 -0
  156. ejkernel-0.0.21/ejkernel/modules/operations/native_sparse_attention.py +374 -0
  157. ejkernel-0.0.21/ejkernel/modules/operations/page_attention.py +444 -0
  158. ejkernel-0.0.21/ejkernel/modules/operations/pooling.py +244 -0
  159. ejkernel-0.0.21/ejkernel/modules/operations/ragged_decode_attention.py +588 -0
  160. ejkernel-0.0.21/ejkernel/modules/operations/ragged_page_attention_v2.py +702 -0
  161. ejkernel-0.0.21/ejkernel/modules/operations/ragged_page_attention_v3.py +1091 -0
  162. ejkernel-0.0.21/ejkernel/modules/operations/recurrent.py +328 -0
  163. ejkernel-0.0.21/ejkernel/modules/operations/ring_attention.py +617 -0
  164. ejkernel-0.0.21/ejkernel/modules/operations/scaled_dot_product_attention.py +417 -0
  165. ejkernel-0.0.21/ejkernel/ops/__init__.py +150 -0
  166. ejkernel-0.0.21/ejkernel/ops/config/__init__.py +43 -0
  167. ejkernel-0.0.21/ejkernel/ops/config/cache.py +187 -0
  168. ejkernel-0.0.21/ejkernel/ops/config/persistent.py +188 -0
  169. ejkernel-0.0.21/ejkernel/ops/config/selection.py +585 -0
  170. ejkernel-0.0.21/ejkernel/ops/core/__init__.py +36 -0
  171. ejkernel-0.0.21/ejkernel/ops/core/kernel.py +671 -0
  172. ejkernel-0.0.21/ejkernel/ops/core/types.py +50 -0
  173. ejkernel-0.0.21/ejkernel/ops/execution/__init__.py +67 -0
  174. ejkernel-0.0.21/ejkernel/ops/execution/batch.py +191 -0
  175. ejkernel-0.0.21/ejkernel/ops/execution/executor.py +514 -0
  176. ejkernel-0.0.21/ejkernel/ops/execution/offline.py +80 -0
  177. ejkernel-0.0.21/ejkernel/ops/execution/profiler.py +475 -0
  178. ejkernel-0.0.21/ejkernel/ops/execution/tuning.py +1194 -0
  179. ejkernel-0.0.21/ejkernel/ops/registry.py +93 -0
  180. ejkernel-0.0.21/ejkernel/ops/utils/__init__.py +72 -0
  181. ejkernel-0.0.21/ejkernel/ops/utils/datacarrier.py +126 -0
  182. ejkernel-0.0.21/ejkernel/ops/utils/fingerprint.py +342 -0
  183. ejkernel-0.0.21/ejkernel/ops/utils/meta.py +156 -0
  184. ejkernel-0.0.21/ejkernel/ops/utils/serialize.py +97 -0
  185. ejkernel-0.0.21/ejkernel/types/__init__.py +18 -0
  186. ejkernel-0.0.21/ejkernel/types/mask.py +2248 -0
  187. ejkernel-0.0.21/ejkernel/utils.py +837 -0
  188. ejkernel-0.0.21/ejkernel/xla_utils/__init__.py +45 -0
  189. ejkernel-0.0.21/ejkernel/xla_utils/cumsum.py +297 -0
  190. ejkernel-0.0.21/ejkernel/xla_utils/shardings.py +102 -0
  191. ejkernel-0.0.21/ejkernel/xla_utils/utils.py +184 -0
  192. ejkernel-0.0.21/pyproject.toml +126 -0
@@ -0,0 +1,711 @@
+ Metadata-Version: 2.3
+ Name: ejkernel
+ Version: 0.0.21
+ Summary: Accelerate and optimize performance with streamlined training and serving options with JAX.
+ Keywords: Deep Learning,Machine Learning,JAX,CUDA,XLA,Triton,Pallas
+ Author: Erfan Zare Chavoshi
+ Author-email: Erfan Zare Chavoshi <Erfanzare810@gmail.com>
+ License: Apache-2.0
+ Classifier: Development Status :: 3 - Alpha
+ Classifier: Intended Audience :: Developers
+ Classifier: Topic :: Scientific/Engineering :: Artificial Intelligence
+ Classifier: License :: OSI Approved :: Apache Software License
+ Classifier: Programming Language :: Python :: 3.10
+ Classifier: Programming Language :: Python :: 3.11
+ Classifier: Programming Language :: Python :: 3.12
+ Classifier: Programming Language :: Python :: 3.13
+ Requires-Dist: beartype>=0.22.2
+ Requires-Dist: chex>=0.1.91
+ Requires-Dist: einops>=0.8.1
+ Requires-Dist: jax>=0.8.0
+ Requires-Dist: jaxlib>=0.8.0
+ Requires-Dist: jaxtyping>=0.3.2
+ Requires-Dist: pydantic>=2.11.10
+ Requires-Dist: triton==3.4.0
+ Requires-Dist: jax[cuda12]>=0.8.0 ; extra == 'gpu'
+ Requires-Dist: xprof>=2.20.6 ; extra == 'profile'
+ Requires-Dist: tb-nightly>=2.21.0a20250820 ; extra == 'profile'
+ Requires-Dist: xprof-nightly>=2.21.6a20250820 ; extra == 'profile'
+ Requires-Dist: jax[tpu]>=0.8.0 ; extra == 'tpu'
+ Requires-Python: >=3.11, <3.14
+ Project-URL: Documentation, https://ejkernel.readthedocs.io/en/latest/
+ Project-URL: Homepage, https://github.com/erfanzar/ejkernel
+ Project-URL: Repository, https://github.com/erfanzar/ejkernel
+ Provides-Extra: gpu
+ Provides-Extra: profile
+ Provides-Extra: tpu
+ Description-Content-Type: text/markdown
+ 
+ # ejKernel: High-Performance JAX Kernels for Deep Learning
+ 
+ > *"The best optimization is the one you don't have to think about."*
+ 
+ [![License](https://img.shields.io/badge/License-Apache%202.0-blue.svg)](https://opensource.org/licenses/Apache-2.0)
+ [![Python 3.11+](https://img.shields.io/badge/python-3.11+-blue.svg)](https://www.python.org/downloads/)
+ [![JAX](https://img.shields.io/badge/JAX-0.8.0+-orange.svg)](https://github.com/google/jax)
+ [![Documentation](https://img.shields.io/badge/docs-readthedocs-green.svg)](https://ejkernel.readthedocs.io/en/latest/)
+ 
+ ejKernel is a production-grade kernel library for JAX that provides highly optimized implementations of deep learning operations with automatic multi-backend support. The library features a 7-tier configuration management system with autotuning, comprehensive type safety, and seamless execution across GPUs, TPUs, and CPUs.
+ 
+ ## Table of Contents
+ 
+ - [Key Features](#key-features)
+ - [Installation](#installation)
+ - [Quick Start](#quick-start)
+ - [Architecture Overview](#architecture-overview)
+ - [Supported Operations](#supported-operations)
+ - [Advanced Usage](#advanced-usage)
+ - [Development](#development)
+ - [Testing](#testing)
+ - [Benchmarking](#benchmarking)
+ - [Contributing](#contributing)
+ - [Documentation](#documentation)
+ - [Citation](#citation)
+ - [License](#license)
+ - [Acknowledgments](#acknowledgments)
+ - [Community](#community)
+ 
+ ## Key Features
+ 
+ ### Intelligent Kernel Management
+ 
+ - **7-Tier Configuration System**: Override → Overlay → Memory Cache → Persistent Cache → Autotune → Heuristics → Error
+ - **Automatic Platform Detection**: Seamlessly selects the optimal implementation for the available hardware
+ - **Priority-Based Registry**: Multi-backend support with intelligent fallback mechanisms
+ - **Device Fingerprinting**: Hardware-specific configuration caching for optimal performance
+ 
+ ### State-of-the-Art Operations
+ 
+ - **15+ Attention Mechanisms**: Flash Attention v2, Ring Attention, Page Attention, Block Sparse, GLA, Lightning, Ragged Page Attention, and more
+ - **Memory Efficiency**: Custom VJP implementations with O(N) memory complexity for attention
+ - **Distributed Support**: Full shard_map integration for model and data parallelism
+ - **Mixed Precision**: Comprehensive dtype support with automatic gradient conversion (see the sketch below)
+ 
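+ A minimal sketch of the mixed-precision bullet above. The [batch, seq_len, num_heads, head_dim] input layout is an assumption for illustration, not taken from the docs:
+ 
+ ```python
+ import jax
+ import jax.numpy as jnp
+ from ejkernel.modules import flash_attention
+ 
+ kq, kk, kv = jax.random.split(jax.random.PRNGKey(0), 3)
+ shape = (2, 512, 8, 64)  # assumed [batch, seq_len, num_heads, head_dim]
+ q, k, v = (jax.random.normal(key, shape, dtype=jnp.bfloat16) for key in (kq, kk, kv))
+ 
+ # Forward pass runs in bfloat16; gradients flow through the custom VJP
+ out = flash_attention(q, k, v, causal=True)
+ grads = jax.grad(lambda q_: flash_attention(q_, k, v, causal=True).sum())(q)
+ ```
+ 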
+ ### Production-Ready Infrastructure
+ 
+ - **Type Safety**: Full jaxtyping annotations with runtime validation via beartype
+ - **Comprehensive Testing**: Cross-backend validation, performance benchmarks, integration tests
+ - **Atomic Persistence**: Thread-safe configuration storage with automatic optimization
+ - **Profiling Integration**: Built-in support for JAX profiling and performance monitoring
+ 
+ ## Installation
+ 
+ ### Basic Installation
+ 
+ ```bash
+ pip install ejkernel
+ ```
+ 
+ ### Platform-Specific Installation
+ 
+ ```bash
+ # GPU Support (CUDA/ROCm)
+ pip install "ejkernel[gpu]"
+ 
+ # TPU Support
+ pip install "ejkernel[tpu]"
+ 
+ # Development Installation
+ git clone https://github.com/erfanzar/ejkernel.git
+ cd ejkernel
+ pip install -e ".[dev]"
+ ```
+ 
+ ### Dependencies
+ 
+ - Python 3.11-3.13
+ - JAX >= 0.8.0
+ - Triton == 3.4.0 (for GPU)
+ - jaxtyping >= 0.3.2
+ - beartype >= 0.22.2
+ 
+ ## Quick Start
+ 
+ ### Simple API with Automatic Optimization
+ 
+ ```python
+ import jax
+ import jax.numpy as jnp
+ from ejkernel.modules import flash_attention
+ 
+ # Example inputs (shapes assumed: [batch, seq_len, num_heads, head_dim])
+ kq, kk, kv = jax.random.split(jax.random.PRNGKey(0), 3)
+ query = jax.random.normal(kq, (2, 1024, 8, 64))
+ key = jax.random.normal(kk, (2, 1024, 8, 64))
+ value = jax.random.normal(kv, (2, 1024, 8, 64))
+ 
+ # Basic usage - automatic configuration selection
+ output = flash_attention(
+     query, key, value,
+     causal=True,
+     dropout_prob=0.1
+ )
+ 
+ # With advanced features
+ mask = jnp.ones((2, 1, 1024, 1024), dtype=bool)  # assumed mask layout
+ output = flash_attention(
+     query, key, value,
+     causal=True,
+     sliding_window=128,      # Local attention window
+     logits_soft_cap=30.0,    # Gemma-2 style soft capping
+     attention_mask=mask,     # Custom attention pattern
+ )
+ ```
+ 
+ ### Custom Configuration
+ 
+ ```python
+ from ejkernel.modules import FlashAttentionConfig
+ from ejkernel.ops.utils.datacarrier import FwdParams, BwdParams
+ 
+ # Create optimized configuration
+ config = FlashAttentionConfig(
+     fwd_params=FwdParams(
+         q_blocksize=256,
+         kv_blocksize=256,
+         num_warps=8,
+         num_stages=2
+     ),
+     bwd_params=BwdParams(
+         q_blocksize=128,
+         kv_blocksize=128,
+         num_warps=4
+     ),
+     platform="triton",  # Force specific backend
+     backend="gpu"
+ )
+ 
+ output = flash_attention(query, key, value, cfg=config)
+ ```
+ 
+ ### Direct Kernel Registry Access
+ 
+ ```python
+ from ejkernel import kernel_registry, Platform, Backend
+ 
+ # Get specific implementation
+ kernel = kernel_registry.get(
+     algorithm="flash_attention",
+     platform=Platform.TRITON,
+     backend=Backend.GPU
+ )
+ 
+ # Direct execution
+ output = kernel(query, key, value, causal=True)
+ ```
+ 
+ ### Distributed Execution
+ 
+ ```python
+ import jax
+ import numpy as np
+ from jax.sharding import Mesh, PartitionSpec as P
+ from ejkernel.modules import flash_attention
+ 
+ # Setup mesh for distributed execution
+ # (Mesh expects an ndarray of devices with one dimension per axis name)
+ devices = np.array(jax.devices()).reshape(-1, 1)
+ mesh = Mesh(devices, axis_names=("data", "model"))
+ 
+ # Run distributed attention
+ output = flash_attention(
+     query, key, value,
+     causal=True,
+     mesh=mesh,
+     in_specs=(P("data", None), P("data", None), P("data", None)),
+     out_specs=P("data", None)
+ )
+ ```
+ 
+ ## Architecture Overview
+ 
+ ### System Design
+ 
+ ejKernel employs a layered architecture that separates concerns while maintaining high performance:
+ 
+ ```md
+ ┌─────────────────────────────────────────────────────┐
+ │                Public API (modules/)                │
+ │       Simple functions with sensible defaults       │
+ ├─────────────────────────────────────────────────────┤
+ │               Operations Layer (ops/)               │
+ │    Configuration management, autotuning, caching    │
+ ├─────────────────────────────────────────────────────┤
+ │             Kernel Registry (kernels/)              │
+ │       Platform routing, signature validation        │
+ ├─────────────────────────────────────────────────────┤
+ │        Backend Implementations (kernels/_*)         │
+ │          Triton, Pallas, XLA, CUDA kernels          │
+ └─────────────────────────────────────────────────────┘
+ ```
+ 
+ ### Project Structure
+ 
+ ```md
+ ejkernel/
+ ├── kernels/                  # Low-level kernel implementations
+ │   ├── _triton/              # Triton kernels (GPU)
+ │   │   ├── flash_attention/
+ │   │   ├── page_attention/
+ │   │   ├── ragged_page_attention_v2/
+ │   │   ├── gla/
+ │   │   ├── lightning_attn/
+ │   │   ├── mean_pooling/
+ │   │   ├── native_sparse_attention/
+ │   │   ├── recurrent/
+ │   │   └── blocksparse_attention/
+ │   ├── _pallas/
+ │   │   ├── tpu/              # TPU-specific implementations
+ │   │   │   ├── flash_attention/
+ │   │   │   ├── ring_attention/
+ │   │   │   ├── page_attention/
+ │   │   │   ├── ragged_page_attention_v2/
+ │   │   │   ├── ragged_page_attention_v3/
+ │   │   │   ├── blocksparse_attention/
+ │   │   │   ├── grouped_matmul/
+ │   │   │   └── ragged_decode_attention/
+ │   │   └── gpu/              # GPU Pallas implementations
+ │   ├── _xla/                 # XLA implementations (universal)
+ │   │   ├── attention/
+ │   │   ├── flash_attention/
+ │   │   ├── gla/
+ │   │   ├── grouped_matmul/
+ │   │   ├── lightning_attn/
+ │   │   ├── mean_pooling/
+ │   │   ├── native_sparse_attention/
+ │   │   ├── page_attention/
+ │   │   ├── ragged_decode_attention/
+ │   │   ├── ragged_page_attention_v2/
+ │   │   ├── ragged_page_attention_v3/
+ │   │   ├── recurrent/
+ │   │   ├── ring_attention/
+ │   │   └── scaled_dot_product_attention/
+ │   ├── _cuda/                # CUDA implementations (dev)
+ │   └── _registry.py          # Kernel registry system
+ 
+ ├── modules/                  # High-level API
+ │   └── operations/
+ │       ├── flash_attention.py
+ │       ├── ring_attention.py
+ │       ├── page_attention.py
+ │       ├── ragged_page_attention_v2.py
+ │       ├── ragged_page_attention_v3.py
+ │       ├── blocksparse_attention.py
+ │       ├── gated_linear_attention.py
+ │       ├── lightning_attention.py
+ │       ├── native_sparse_attention.py
+ │       ├── recurrent.py
+ │       ├── grouped_matmul.py
+ │       ├── pooling.py
+ │       ├── attention.py
+ │       ├── multi_head_latent_attention.py
+ │       ├── ragged_decode_attention.py
+ │       ├── scaled_dot_product_attention.py
+ │       └── configs.py
+ 
+ ├── ops/                      # Configuration & execution framework
+ │   ├── config/               # Configuration management
+ │   │   ├── cache.py          # In-memory config cache
+ │   │   ├── persistent.py     # Disk-based persistence
+ │   │   └── selection.py      # Config selection chain
+ │   ├── core/                 # Base kernel class
+ │   ├── execution/            # Execution orchestration
+ │   │   └── tuning.py         # Autotuning framework
+ │   ├── registry.py           # Operation invocation tracking
+ │   └── utils/                # Utilities (fingerprinting, etc)
+ 
+ ├── xla_utils/                # XLA-specific utilities
+ │   ├── cumsum.py             # Cumulative sum operations
+ │   ├── shardings.py          # Sharding utilities
+ │   └── utils.py              # Sequence length utilities
+ 
+ ├── types/                    # Type definitions
+ │   └── mask.py               # MaskInfo for attention masking
+ 
+ ├── callib/                   # Calling library
+ │   ├── _ejit.py              # Enhanced JIT
+ │   ├── _triton_call.py       # Triton kernel calling
+ │   └── _pallas_call.py       # Pallas kernel calling
+ 
+ └── utils.py                  # General utilities
+ ```
+ 
+ ### Core Components
+ 
+ #### Kernel Registry
+ 
+ The registry provides automatic platform-specific kernel selection:
+ 
+ ```python
+ from ejkernel import kernel_registry, Platform, Backend
+ 
+ @kernel_registry.register("my_operation", Platform.TRITON, Backend.GPU, priority=100)
+ def my_operation_gpu(x, y):
+     # GPU-optimized implementation
+     pass
+ 
+ @kernel_registry.register("my_operation", Platform.XLA, Backend.ANY, priority=50)
+ def my_operation_fallback(x, y):
+     # Universal fallback
+     pass
+ 
+ # Automatic selection based on available hardware
+ impl = kernel_registry.get("my_operation")
+ ```
+ 
+ #### Configuration Management
+ 
+ Multi-tier configuration system with intelligent fallback:
+ 
+ ```python
+ class ConfigSelectorChain:
+     """
+     Selection hierarchy:
+     1. Override - Explicit user configuration
+     2. Overlay - Temporary context overrides
+     3. Memory Cache - In-memory lookup
+     4. Persistent Cache - Disk-based storage
+     5. Autotune - Performance benchmarking
+     6. Heuristics - Intelligent defaults
+     7. Error - Clear failure message
+     """
+ ```
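+ 
+ In practice the chain is driven from the public API. A sketch, reusing query, key, value, and config from the examples above (the policy value is the one documented in the Performance Optimization section below):
+ 
+ ```python
+ import os
+ 
+ # Tier 5: force autotuning instead of cached or heuristic configs
+ os.environ["EJKERNEL_AUTOTUNE_POLICY"] = "autotune"
+ 
+ from ejkernel.modules import flash_attention
+ 
+ # Tier 1: an explicit cfg= argument short-circuits the chain
+ output = flash_attention(query, key, value, causal=True, cfg=config)
+ ```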
+ 
+ #### Custom VJP System
+ 
+ All performance-critical kernels implement memory-efficient gradients:
+ 
+ ```python
+ import jax
+ 
+ @jax.custom_vjp
+ def kernel_with_custom_grad(inputs):
+     return forward(inputs)
+ 
+ def kernel_fwd(inputs):
+     output, residuals = forward_with_residuals(inputs)
+     return output, residuals
+ 
+ def kernel_bwd(residuals, grad_output):
+     # Must return a tuple with one gradient per primal argument
+     return efficient_backward(residuals, grad_output)
+ 
+ kernel_with_custom_grad.defvjp(kernel_fwd, kernel_bwd)
+ ```
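+ 
+ For concreteness, here is a tiny self-contained instance of the same pattern (illustrative only, not one of ejKernel's kernels): a softmax whose backward pass keeps only the forward output as its residual, the same memory-saving idea the attention kernels rely on.
+ 
+ ```python
+ import jax
+ import jax.numpy as jnp
+ 
+ @jax.custom_vjp
+ def softmax(x):
+     e = jnp.exp(x - jnp.max(x, axis=-1, keepdims=True))
+     return e / jnp.sum(e, axis=-1, keepdims=True)
+ 
+ def softmax_fwd(x):
+     y = softmax(x)
+     return y, y  # residual is just the output; x need not be kept
+ 
+ def softmax_bwd(y, g):
+     # VJP of softmax: y * (g - sum(g * y))
+     return (y * (g - jnp.sum(g * y, axis=-1, keepdims=True)),)
+ 
+ softmax.defvjp(softmax_fwd, softmax_bwd)
+ 
+ grad = jax.grad(lambda x: softmax(x)[0])(jnp.array([1.0, 2.0, 3.0]))
+ ```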
+ 
+ ## Supported Operations
+ 
+ ### Attention Mechanisms
+ 
+ | Algorithm | Description | Memory | Key Features |
+ |-----------|-------------|--------|--------------|
+ | **Flash Attention v2** | Memory-efficient exact attention | O(N) | Causal masking, dropout, sliding windows, soft capping |
+ | **Ring Attention** | Distributed sequence parallelism | O(N/P) | Ultra-long sequences, communication overlap |
+ | **Page Attention** | KV-cache optimized inference | O(N) | Block-wise memory, continuous batching |
+ | **Block Sparse Attention** | Configurable sparse patterns | O(N√N) | Local+global, custom patterns |
+ | **GLA** | Gated Linear Attention | O(N) | Linear complexity, gated updates |
+ | **Lightning Attention** | Layer-dependent decay | O(N) | Exponential moving average |
+ | **MLA** | Multi-head Latent Attention | O(N) | Compressed KV representation |
+ | **Ragged Page Attention v2** | Variable-length paged attention | O(N) | Ragged sequences with page caching |
+ | **Ragged Page Attention v3** | Enhanced ragged page attention | O(N) | Attention sinks support, improved handling |
+ | **Ragged Decode Attention** | Variable-length decoding | O(N) | Efficient batched inference |
+ | **Scaled Dot-Product Attention** | Standard attention | O(N²) | Basic reference implementation |
+ 
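+ As a quick sanity check of the table above, the O(N) Flash Attention path should agree numerically with the O(N²) reference. A sketch, assuming scaled_dot_product_attention is exported from ejkernel.modules alongside flash_attention and accepts the same arguments:
+ 
+ ```python
+ import jax
+ import jax.numpy as jnp
+ from ejkernel.modules import flash_attention, scaled_dot_product_attention
+ 
+ kq, kk, kv = jax.random.split(jax.random.PRNGKey(0), 3)
+ shape = (1, 256, 4, 64)  # assumed [batch, seq_len, num_heads, head_dim]
+ q, k, v = (jax.random.normal(key, shape) for key in (kq, kk, kv))
+ 
+ reference = scaled_dot_product_attention(q, k, v, causal=True)
+ fast = flash_attention(q, k, v, causal=True)
+ assert jnp.allclose(reference, fast, atol=1e-3)
+ ```
+ 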
+ ### Other Operations
+ 
+ | Operation | Description | Use Case |
+ |-----------|-------------|----------|
+ | **Grouped MatMul** | Efficient batched matrix operations | Expert models, MoE |
+ | **Grouped MatMul v2** | Enhanced with shard_map support | Distributed expert models |
+ | **Mean Pooling** | Variable-length sequence aggregation | Sentence embeddings |
+ | **Recurrent** | Optimized RNN/LSTM/GRU operations | Sequential modeling |
+ | **Native Sparse** | Block-sparse matrix computations | Sparse attention patterns |
+ 
+ ### Platform Support Matrix
+ 
+ | Operation | Triton (GPU) | Pallas (TPU) | XLA (Universal) |
+ |-----------|:------------:|:------------:|:---------------:|
+ | Flash Attention v2 | ✅ | ✅ | ✅ |
+ | Ring Attention | ✅ | ✅ | ✅ |
+ | Page Attention | ✅ | ✅ | ✅ |
+ | Block Sparse Attention | ✅ | ✅ | ✅ |
+ | Ragged Page Attention v2 | ✅ | ✅ | ✅ |
+ | Ragged Page Attention v3 | - | ✅ | ✅ |
+ | Ragged Decode Attention | ✅ | ✅ | ✅ |
+ | GLA | ✅ | - | ✅ |
+ | Lightning Attention | ✅ | - | ✅ |
+ | MLA | ✅ | 🚧 | - |
+ | Recurrent | ✅ | - | ✅ |
+ | Mean Pooling | ✅ | - | ✅ |
+ | Grouped MatMul | - | ✅ | ✅ |
+ | Grouped MatMul v2 | - | ✅ | - |
+ | Native Sparse Attention | ✅ | - | ✅ |
+ 
+ ✅ = Production ready | 🚧 = Under development | - = Not available
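+ 
+ Any filled cell can be requested explicitly through the registry. A sketch (Platform.PALLAS is assumed by analogy with Platform.TRITON, and the algorithm string is assumed to match the kernel directory name):
+ 
+ ```python
+ from ejkernel import kernel_registry, Platform, Backend
+ 
+ # Pallas/TPU column of the Ragged Page Attention v3 row
+ impl = kernel_registry.get(
+     algorithm="ragged_page_attention_v3",
+     platform=Platform.PALLAS,
+     backend=Backend.TPU,
+ )
+ ```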
+ 
+ ## Advanced Usage
+ 
+ ### Page Attention for KV-Cache Inference
+ 
+ ```python
+ from ejkernel.modules import page_attention, PageAttentionConfig
+ 
+ # Configure paged attention for inference
+ config = PageAttentionConfig(
+     platform="auto",
+     backend="gpu"
+ )
+ 
+ output = page_attention(
+     query=q,
+     key_cache=k_cache,
+     value_cache=v_cache,
+     block_table=block_table,
+     cache_seqlens=cache_seqlens,
+     cfg=config
+ )
+ ```
+ 
+ ### Ragged Page Attention for Variable-Length Batches
+ 
+ ```python
+ from ejkernel.modules import ragged_page_attention_v3, RaggedPageAttentionv3Config
+ 
+ # For variable-length sequences with attention sinks
+ config = RaggedPageAttentionv3Config(
+     platform="pallas",
+     backend="tpu"
+ )
+ 
+ output = ragged_page_attention_v3(
+     query=q,
+     key_pages=k_pages,
+     value_pages=v_pages,
+     lengths=seq_lengths,
+     page_indices=page_indices,
+     cfg=config
+ )
+ ```
+ 
+ ### Performance Optimization
+ 
+ ```python
+ # Force autotuning for optimal configuration
+ import os
+ os.environ["EJKERNEL_AUTOTUNE_POLICY"] = "autotune"
+ os.environ["EJKERNEL_LOG_AUTOTUNE"] = "1"
+ 
+ # Enable profiling
+ os.environ["EJKERNEL_OPS_STAMP"] = "json"  # Detailed metadata
+ os.environ["EJKERNEL_OPS_RECORD"] = "1"    # Record invocations
+ ```
+ 
+ ### Custom Kernel Development
+ 
+ ```python
+ from dataclasses import dataclass
+ 
+ from jaxtyping import Array
+ 
+ from ejkernel import kernel_registry
+ from ejkernel.ops.core import Kernel
+ from ejkernel.modules.operations.configs import BaseOperationConfig
+ 
+ @dataclass
+ class MyConfig(BaseOperationConfig):
+     param1: int = 128
+     param2: float = 0.1
+ 
+ class MyKernel(Kernel[MyConfig, Array]):
+     def __init__(self):
+         super().__init__(op_id="my_kernel")
+ 
+     def run(self, x, cfg: MyConfig):
+         impl = kernel_registry.get("my_kernel", cfg.platform)
+         return impl(x, param1=cfg.param1, param2=cfg.param2)
+ 
+     def heuristic_cfg(self, inv):
+         # Return default configuration
+         return MyConfig(param1=256)
+ 
+     def candidate_cfgs(self, inv):
+         # Return autotuning candidates
+         return [MyConfig(param1=p) for p in [64, 128, 256]]
+ ```
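+ 
+ During autotuning, the executor benchmarks each configuration returned by candidate_cfgs and persists the winner through the selection chain described earlier; heuristic_cfg supplies the tier-6 default when autotuning is skipped.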
+ 
+ ### Integration with Flax Models
+ 
+ ```python
+ import flax.linen as nn
+ from ejkernel.modules import flash_attention
+ 
+ class TransformerBlock(nn.Module):
+     num_heads: int = 8
+     head_dim: int = 64
+ 
+     @nn.compact
+     def __call__(self, x, mask=None):
+         # Project to Q, K, V
+         q = nn.Dense(self.num_heads * self.head_dim)(x)
+         k = nn.Dense(self.num_heads * self.head_dim)(x)
+         v = nn.Dense(self.num_heads * self.head_dim)(x)
+ 
+         # Reshape for attention
+         shape = (x.shape[0], x.shape[1], self.num_heads, self.head_dim)
+         q, k, v = map(lambda t: t.reshape(shape), (q, k, v))
+ 
+         # Apply ejKernel Flash Attention
+         attn_output = flash_attention(
+             q, k, v,
+             causal=True,
+             attention_mask=mask
+         )
+ 
+         # Merge heads, then project back to the input feature size
+         attn_output = attn_output.reshape(x.shape[0], x.shape[1], -1)
+         return nn.Dense(x.shape[-1])(attn_output)
+ ```
+ 
+ ## Development
+ 
+ ### Setting Up Development Environment
+ 
+ ```bash
+ # Clone repository
+ git clone https://github.com/erfanzar/ejkernel.git
+ cd ejkernel
+ 
+ # Create virtual environment
+ python -m venv .venv
+ source .venv/bin/activate  # On Windows: .venv\Scripts\activate
+ 
+ # Install in development mode
+ pip install -e ".[dev]"
+ 
+ # Install pre-commit hooks
+ pre-commit install
+ ```
+ 
+ ### Code Style
+ 
+ The project uses:
+ 
+ - **black** for code formatting (line length: 121)
+ - **ruff** for linting
+ - **mypy/pyright** for type checking
+ - **pre-commit** for automated checks
+ 
+ ### Adding New Kernels
+ 
+ 1. **Implement the kernel** in the appropriate backend directory:
+ 
+    ```python
+    # ejkernel/kernels/_triton/my_kernel/_interface.py
+    @kernel_registry.register("my_kernel", Platform.TRITON, Backend.GPU)
+    def my_kernel_triton(x, config):
+        # Implementation
+        pass
+    ```
+ 
+ 1. **Create module wrapper**:
+ 
+    ```python
+    # ejkernel/modules/operations/my_kernel.py
+    class MyKernel(Kernel[MyKernelConfig, Array]):
+        # Module implementation
+        pass
+    ```
+ 
+ 1. **Add tests**:
+ 
+    ```python
+    # test/kernels/_triton/test_my_kernel.py
+    class TestMyKernel(unittest.TestCase):
+        # Test implementation
+        pass
+    ```
+ 
+ 1. **Update documentation**
+ 
+ ## Testing
+ 
+ ### Running Tests
+ 
+ ```bash
+ # Run all tests
+ pytest test/
+ 
+ # Platform-specific tests
+ pytest test/kernels/_xla/     # XLA implementations
+ pytest test/kernels/_triton/  # Triton implementations
+ pytest test/kernels/_pallas/  # Pallas implementations
+ 
+ # Specific test patterns
+ pytest -k "flash_attention"
+ pytest --verbose --exitfirst
+ 
+ # Module operations tests
+ pytest test/test_module_operations.py
+ ```
+ 
+ ### Test Categories
+ 
+ - **Unit Tests**: Individual component testing
+ - **Integration Tests**: End-to-end workflows
+ - **Comparison Tests**: Cross-backend consistency
+ - **Performance Tests**: Regression detection
+ 
+ ## Benchmarking
+ 
+ Run benchmarks to compare performance across backends:
+ 
+ ```bash
+ # General attention benchmarks
+ python benchmarks/benchmark_attention.py
+ 
+ # Ragged page attention benchmarks
+ python benchmarks/benchmark_ragged_page_attn.py
+ ```
+ 
+ ## Contributing
+ 
+ We welcome contributions! See [CONTRIBUTING.md](CONTRIBUTING.md) for guidelines.
+ 
+ ### Priority Areas
+ 
+ - TPU/Pallas implementations for existing algorithms
+ - CUDA native kernels for maximum performance
+ - New attention mechanisms from recent papers
+ - Performance optimizations and kernel fusion
+ - Documentation and examples
+ 
+ ### Contribution Process
+ 
+ 1. Fork the repository
+ 1. Create a feature branch
+ 1. Implement your changes with tests
+ 1. Ensure all tests pass
+ 1. Submit a pull request
+ 
+ ## Documentation
+ 
+ Comprehensive documentation is available at [ejkernel.readthedocs.io](https://ejkernel.readthedocs.io/en/latest/):
+ 
+ - **[API Reference](https://ejkernel.readthedocs.io/en/latest/api/)**: Complete API documentation
+ - **[Tutorials](https://ejkernel.readthedocs.io/en/latest/tutorials/)**: Step-by-step guides
+ - **[Architecture](https://ejkernel.readthedocs.io/en/latest/architecture/)**: Design documentation
+ - **[Benchmarks](https://ejkernel.readthedocs.io/en/latest/benchmarks/)**: Performance analysis
+ 
+ ## Citation
+ 
+ If you use ejKernel in your research, please cite:
+ 
+ ```bibtex
+ @software{ejkernel2024,
+   author = {Erfan Zare Chavoshi},
+   title = {ejKernel: High-Performance JAX Kernels for Deep Learning},
+   year = {2024},
+   url = {https://github.com/erfanzar/ejkernel},
+   note = {Production-grade kernel library with multi-backend support}
+ }
+ ```
+ 
+ ## License
+ 
+ ejKernel is licensed under the Apache License 2.0. See [LICENSE](LICENSE) for details.
+ 
+ ## Acknowledgments
+ 
+ ejKernel builds upon excellent work from:
+ 
+ - [JAX](https://github.com/google/jax) - Composable transformations of Python+NumPy programs
+ - [Triton](https://github.com/openai/triton) - GPU kernel programming language
+ - [Pallas](https://github.com/google/jax/tree/main/jax/experimental/pallas) - JAX kernel language
+ - [Flash Attention](https://github.com/Dao-AILab/flash-attention) - Memory-efficient attention
+ - [EasyDeL](https://github.com/erfanzar/EasyDeL) - Parent framework for JAX deep learning
+ 
+ ## Community
+ 
+ - **GitHub Issues**: [Bug reports and feature requests](https://github.com/erfanzar/ejkernel/issues)
+ - **Discussions**: [Community forum](https://github.com/erfanzar/ejkernel/discussions)
+ - **Email**: <Erfanzare810@gmail.com>
+ 
+ ---
+ 
+ **ejKernel** - Production-grade kernels for JAX deep learning