areno 0.0.0.dev0__tar.gz → 0.0.0.dev1__tar.gz

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (158)
  1. areno-0.0.0.dev1/MANIFEST.in +1 -0
  2. {areno-0.0.0.dev0 → areno-0.0.0.dev1}/PKG-INFO +1 -1
  3. areno-0.0.0.dev1/areno/accel/csrc/activation.cu +360 -0
  4. areno-0.0.0.dev1/areno/accel/csrc/atomic_utils.cuh +25 -0
  5. areno-0.0.0.dev1/areno/accel/csrc/conv.cu +438 -0
  6. areno-0.0.0.dev1/areno/accel/csrc/embedding.cu +104 -0
  7. areno-0.0.0.dev1/areno/accel/csrc/extension.cpp +163 -0
  8. areno-0.0.0.dev1/areno/accel/csrc/linear.cu +447 -0
  9. areno-0.0.0.dev1/areno/accel/csrc/moe_align_kernel.cu +320 -0
  10. areno-0.0.0.dev1/areno/accel/csrc/moe_permute.cu +349 -0
  11. areno-0.0.0.dev1/areno/accel/csrc/normalization.cu +393 -0
  12. areno-0.0.0.dev1/areno/accel/csrc/router.cu +177 -0
  13. areno-0.0.0.dev1/areno/accel/csrc/topk.cu +257 -0
  14. {areno-0.0.0.dev0 → areno-0.0.0.dev1}/areno.egg-info/PKG-INFO +1 -1
  15. {areno-0.0.0.dev0 → areno-0.0.0.dev1}/areno.egg-info/SOURCES.txt +12 -0
  16. {areno-0.0.0.dev0 → areno-0.0.0.dev1}/pyproject.toml +1 -1
  17. {areno-0.0.0.dev0 → areno-0.0.0.dev1}/areno/__init__.py +0 -0
  18. {areno-0.0.0.dev0 → areno-0.0.0.dev1}/areno/accel/__init__.py +0 -0
  19. {areno-0.0.0.dev0 → areno-0.0.0.dev1}/areno/accel/_extension.py +0 -0
  20. {areno-0.0.0.dev0 → areno-0.0.0.dev1}/areno/accel/activations.py +0 -0
  21. {areno-0.0.0.dev0 → areno-0.0.0.dev1}/areno/accel/conv.py +0 -0
  22. {areno-0.0.0.dev0 → areno-0.0.0.dev1}/areno/accel/embedding.py +0 -0
  23. {areno-0.0.0.dev0 → areno-0.0.0.dev1}/areno/accel/kernels/__init__.py +0 -0
  24. {areno-0.0.0.dev0 → areno-0.0.0.dev1}/areno/accel/kernels/fused_moe.py +0 -0
  25. {areno-0.0.0.dev0 → areno-0.0.0.dev1}/areno/accel/kernels/group_rmsnorm.py +0 -0
  26. {areno-0.0.0.dev0 → areno-0.0.0.dev1}/areno/accel/kernels/seg_la.py +0 -0
  27. {areno-0.0.0.dev0 → areno-0.0.0.dev1}/areno/accel/linear.py +0 -0
  28. {areno-0.0.0.dev0 → areno-0.0.0.dev1}/areno/accel/moe.py +0 -0
  29. {areno-0.0.0.dev0 → areno-0.0.0.dev1}/areno/accel/normalization.py +0 -0
  30. {areno-0.0.0.dev0 → areno-0.0.0.dev1}/areno/accel/ops.py +0 -0
  31. {areno-0.0.0.dev0 → areno-0.0.0.dev1}/areno/accel/router.py +0 -0
  32. {areno-0.0.0.dev0 → areno-0.0.0.dev1}/areno/accel/routing.py +0 -0
  33. {areno-0.0.0.dev0 → areno-0.0.0.dev1}/areno/accel/topk.py +0 -0
  34. {areno-0.0.0.dev0 → areno-0.0.0.dev1}/areno/api/__init__.py +0 -0
  35. {areno-0.0.0.dev0 → areno-0.0.0.dev1}/areno/api/advantages.py +0 -0
  36. {areno-0.0.0.dev0 → areno-0.0.0.dev1}/areno/api/algorithms.py +0 -0
  37. {areno-0.0.0.dev0 → areno-0.0.0.dev1}/areno/api/backend/__init__.py +0 -0
  38. {areno-0.0.0.dev0 → areno-0.0.0.dev1}/areno/api/backend/areno/__init__.py +0 -0
  39. {areno-0.0.0.dev0 → areno-0.0.0.dev1}/areno/api/backend/areno/backend.py +0 -0
  40. {areno-0.0.0.dev0 → areno-0.0.0.dev1}/areno/api/backend/base.py +0 -0
  41. {areno-0.0.0.dev0 → areno-0.0.0.dev1}/areno/api/config.py +0 -0
  42. {areno-0.0.0.dev0 → areno-0.0.0.dev1}/areno/api/context.py +0 -0
  43. {areno-0.0.0.dev0 → areno-0.0.0.dev1}/areno/api/data.py +0 -0
  44. {areno-0.0.0.dev0 → areno-0.0.0.dev1}/areno/api/data_utils.py +0 -0
  45. {areno-0.0.0.dev0 → areno-0.0.0.dev1}/areno/api/defaults.py +0 -0
  46. {areno-0.0.0.dev0 → areno-0.0.0.dev1}/areno/api/loss_fns/__init__.py +0 -0
  47. {areno-0.0.0.dev0 → areno-0.0.0.dev1}/areno/api/loss_fns/dpo.py +0 -0
  48. {areno-0.0.0.dev0 → areno-0.0.0.dev1}/areno/api/loss_fns/grpo.py +0 -0
  49. {areno-0.0.0.dev0 → areno-0.0.0.dev1}/areno/api/loss_fns/gspo.py +0 -0
  50. {areno-0.0.0.dev0 → areno-0.0.0.dev1}/areno/api/loss_fns/layout.py +0 -0
  51. {areno-0.0.0.dev0 → areno-0.0.0.dev1}/areno/api/loss_fns/ppo.py +0 -0
  52. {areno-0.0.0.dev0 → areno-0.0.0.dev1}/areno/api/loss_fns/sft.py +0 -0
  53. {areno-0.0.0.dev0 → areno-0.0.0.dev1}/areno/api/metrics.py +0 -0
  54. {areno-0.0.0.dev0 → areno-0.0.0.dev1}/areno/api/models.py +0 -0
  55. {areno-0.0.0.dev0 → areno-0.0.0.dev1}/areno/api/rewards.py +0 -0
  56. {areno-0.0.0.dev0 → areno-0.0.0.dev1}/areno/api/roles.py +0 -0
  57. {areno-0.0.0.dev0 → areno-0.0.0.dev1}/areno/api/tokenizer.py +0 -0
  58. {areno-0.0.0.dev0 → areno-0.0.0.dev1}/areno/api/trainer.py +0 -0
  59. {areno-0.0.0.dev0 → areno-0.0.0.dev1}/areno/api/trainer_config.py +0 -0
  60. {areno-0.0.0.dev0 → areno-0.0.0.dev1}/areno/api/trainer_factory.py +0 -0
  61. {areno-0.0.0.dev0 → areno-0.0.0.dev1}/areno/api/trainers/__init__.py +0 -0
  62. {areno-0.0.0.dev0 → areno-0.0.0.dev1}/areno/api/trainers/dpo.py +0 -0
  63. {areno-0.0.0.dev0 → areno-0.0.0.dev1}/areno/api/trainers/policy_only.py +0 -0
  64. {areno-0.0.0.dev0 → areno-0.0.0.dev1}/areno/api/trainers/ppo.py +0 -0
  65. {areno-0.0.0.dev0 → areno-0.0.0.dev1}/areno/api/trainers/sft.py +0 -0
  66. {areno-0.0.0.dev0 → areno-0.0.0.dev1}/areno/cli/__init__.py +0 -0
  67. {areno-0.0.0.dev0 → areno-0.0.0.dev1}/areno/cli/main.py +0 -0
  68. {areno-0.0.0.dev0 → areno-0.0.0.dev1}/areno/cli/model_refs.py +0 -0
  69. {areno-0.0.0.dev0 → areno-0.0.0.dev1}/areno/cli/serve.py +0 -0
  70. {areno-0.0.0.dev0 → areno-0.0.0.dev1}/areno/cli/train.py +0 -0
  71. {areno-0.0.0.dev0 → areno-0.0.0.dev1}/areno/engine/__init__.py +0 -0
  72. {areno-0.0.0.dev0 → areno-0.0.0.dev1}/areno/engine/api.py +0 -0
  73. {areno-0.0.0.dev0 → areno-0.0.0.dev1}/areno/engine/checkpoints/__init__.py +0 -0
  74. {areno-0.0.0.dev0 → areno-0.0.0.dev1}/areno/engine/checkpoints/common.py +0 -0
  75. {areno-0.0.0.dev0 → areno-0.0.0.dev1}/areno/engine/checkpoints/io.py +0 -0
  76. {areno-0.0.0.dev0 → areno-0.0.0.dev1}/areno/engine/config.py +0 -0
  77. {areno-0.0.0.dev0 → areno-0.0.0.dev1}/areno/engine/data/__init__.py +0 -0
  78. {areno-0.0.0.dev0 → areno-0.0.0.dev1}/areno/engine/data/batch.py +0 -0
  79. {areno-0.0.0.dev0 → areno-0.0.0.dev1}/areno/engine/data/rollout_state.py +0 -0
  80. {areno-0.0.0.dev0 → areno-0.0.0.dev1}/areno/engine/data/sampling.py +0 -0
  81. {areno-0.0.0.dev0 → areno-0.0.0.dev1}/areno/engine/data/tokenizer.py +0 -0
  82. {areno-0.0.0.dev0 → areno-0.0.0.dev1}/areno/engine/inference.py +0 -0
  83. {areno-0.0.0.dev0 → areno-0.0.0.dev1}/areno/engine/layers/__init__.py +0 -0
  84. {areno-0.0.0.dev0 → areno-0.0.0.dev1}/areno/engine/layers/attention.py +0 -0
  85. {areno-0.0.0.dev0 → areno-0.0.0.dev1}/areno/engine/layers/attention_backend/__init__.py +0 -0
  86. {areno-0.0.0.dev0 → areno-0.0.0.dev1}/areno/engine/layers/attention_backend/common.py +0 -0
  87. {areno-0.0.0.dev0 → areno-0.0.0.dev1}/areno/engine/layers/attention_backend/infer.py +0 -0
  88. {areno-0.0.0.dev0 → areno-0.0.0.dev1}/areno/engine/layers/attention_backend/train.py +0 -0
  89. {areno-0.0.0.dev0 → areno-0.0.0.dev1}/areno/engine/layers/linear.py +0 -0
  90. {areno-0.0.0.dev0 → areno-0.0.0.dev1}/areno/engine/layers/mlp.py +0 -0
  91. {areno-0.0.0.dev0 → areno-0.0.0.dev1}/areno/engine/layers/norm.py +0 -0
  92. {areno-0.0.0.dev0 → areno-0.0.0.dev1}/areno/engine/layers/rotary.py +0 -0
  93. {areno-0.0.0.dev0 → areno-0.0.0.dev1}/areno/engine/layers/vocab.py +0 -0
  94. {areno-0.0.0.dev0 → areno-0.0.0.dev1}/areno/engine/log.py +0 -0
  95. {areno-0.0.0.dev0 → areno-0.0.0.dev1}/areno/engine/modeling.py +0 -0
  96. {areno-0.0.0.dev0 → areno-0.0.0.dev1}/areno/engine/optim/__init__.py +0 -0
  97. {areno-0.0.0.dev0 → areno-0.0.0.dev1}/areno/engine/optim/adamw_8bit.py +0 -0
  98. {areno-0.0.0.dev0 → areno-0.0.0.dev1}/areno/engine/optim/adamw_fp32_master.py +0 -0
  99. {areno-0.0.0.dev0 → areno-0.0.0.dev1}/areno/engine/parallel/__init__.py +0 -0
  100. {areno-0.0.0.dev0 → areno-0.0.0.dev1}/areno/engine/parallel/collectives.py +0 -0
  101. {areno-0.0.0.dev0 → areno-0.0.0.dev1}/areno/engine/parallel/context.py +0 -0
  102. {areno-0.0.0.dev0 → areno-0.0.0.dev1}/areno/engine/protocol.py +0 -0
  103. {areno-0.0.0.dev0 → areno-0.0.0.dev1}/areno/engine/roles.py +0 -0
  104. {areno-0.0.0.dev0 → areno-0.0.0.dev1}/areno/engine/runtime/__init__.py +0 -0
  105. {areno-0.0.0.dev0 → areno-0.0.0.dev1}/areno/engine/runtime/common.py +0 -0
  106. {areno-0.0.0.dev0 → areno-0.0.0.dev1}/areno/engine/runtime/decode_graph.py +0 -0
  107. {areno-0.0.0.dev0 → areno-0.0.0.dev1}/areno/engine/runtime/logprobs.py +0 -0
  108. {areno-0.0.0.dev0 → areno-0.0.0.dev1}/areno/engine/runtime/metadata.py +0 -0
  109. {areno-0.0.0.dev0 → areno-0.0.0.dev1}/areno/engine/runtime/recompute.py +0 -0
  110. {areno-0.0.0.dev0 → areno-0.0.0.dev1}/areno/engine/runtime/rollout.py +0 -0
  111. {areno-0.0.0.dev0 → areno-0.0.0.dev1}/areno/engine/runtime/train_step.py +0 -0
  112. {areno-0.0.0.dev0 → areno-0.0.0.dev1}/areno/engine/training.py +0 -0
  113. {areno-0.0.0.dev0 → areno-0.0.0.dev1}/areno/engine/worker.py +0 -0
  114. {areno-0.0.0.dev0 → areno-0.0.0.dev1}/areno/experimental/__init__.py +0 -0
  115. {areno-0.0.0.dev0 → areno-0.0.0.dev1}/areno/models/__init__.py +0 -0
  116. {areno-0.0.0.dev0 → areno-0.0.0.dev1}/areno/models/_shared/__init__.py +0 -0
  117. {areno-0.0.0.dev0 → areno-0.0.0.dev1}/areno/models/_shared/dynamo_wrappers.py +0 -0
  118. {areno-0.0.0.dev0 → areno-0.0.0.dev1}/areno/models/bailing/__init__.py +0 -0
  119. {areno-0.0.0.dev0 → areno-0.0.0.dev1}/areno/models/bailing/checkpoint.py +0 -0
  120. {areno-0.0.0.dev0 → areno-0.0.0.dev1}/areno/models/bailing/model.py +0 -0
  121. {areno-0.0.0.dev0 → areno-0.0.0.dev1}/areno/models/base.py +0 -0
  122. {areno-0.0.0.dev0 → areno-0.0.0.dev1}/areno/models/gemma4/__init__.py +0 -0
  123. {areno-0.0.0.dev0 → areno-0.0.0.dev1}/areno/models/gemma4/checkpoint.py +0 -0
  124. {areno-0.0.0.dev0 → areno-0.0.0.dev1}/areno/models/gemma4/model.py +0 -0
  125. {areno-0.0.0.dev0 → areno-0.0.0.dev1}/areno/models/llama/__init__.py +0 -0
  126. {areno-0.0.0.dev0 → areno-0.0.0.dev1}/areno/models/llama/checkpoint.py +0 -0
  127. {areno-0.0.0.dev0 → areno-0.0.0.dev1}/areno/models/llama/model.py +0 -0
  128. {areno-0.0.0.dev0 → areno-0.0.0.dev1}/areno/models/minicpmv46/__init__.py +0 -0
  129. {areno-0.0.0.dev0 → areno-0.0.0.dev1}/areno/models/minicpmv46/checkpoint.py +0 -0
  130. {areno-0.0.0.dev0 → areno-0.0.0.dev1}/areno/models/minicpmv46/model.py +0 -0
  131. {areno-0.0.0.dev0 → areno-0.0.0.dev1}/areno/models/qwen3/__init__.py +0 -0
  132. {areno-0.0.0.dev0 → areno-0.0.0.dev1}/areno/models/qwen3/checkpoint.py +0 -0
  133. {areno-0.0.0.dev0 → areno-0.0.0.dev1}/areno/models/qwen3/model.py +0 -0
  134. {areno-0.0.0.dev0 → areno-0.0.0.dev1}/areno/models/qwen3_5/__init__.py +0 -0
  135. {areno-0.0.0.dev0 → areno-0.0.0.dev1}/areno/models/qwen3_5/checkpoint.py +0 -0
  136. {areno-0.0.0.dev0 → areno-0.0.0.dev1}/areno/models/qwen3_5/model.py +0 -0
  137. {areno-0.0.0.dev0 → areno-0.0.0.dev1}/areno/models/registry.py +0 -0
  138. {areno-0.0.0.dev0 → areno-0.0.0.dev1}/areno.egg-info/dependency_links.txt +0 -0
  139. {areno-0.0.0.dev0 → areno-0.0.0.dev1}/areno.egg-info/entry_points.txt +0 -0
  140. {areno-0.0.0.dev0 → areno-0.0.0.dev1}/areno.egg-info/requires.txt +0 -0
  141. {areno-0.0.0.dev0 → areno-0.0.0.dev1}/areno.egg-info/top_level.txt +0 -0
  142. {areno-0.0.0.dev0 → areno-0.0.0.dev1}/setup.cfg +0 -0
  143. {areno-0.0.0.dev0 → areno-0.0.0.dev1}/setup.py +0 -0
  144. {areno-0.0.0.dev0 → areno-0.0.0.dev1}/tests/test_algorithms_cpu.py +0 -0
  145. {areno-0.0.0.dev0 → areno-0.0.0.dev1}/tests/test_cli_model_refs_cpu.py +0 -0
  146. {areno-0.0.0.dev0 → areno-0.0.0.dev1}/tests/test_config_data_cpu.py +0 -0
  147. {areno-0.0.0.dev0 → areno-0.0.0.dev1}/tests/test_logprobs_cpu.py +0 -0
  148. {areno-0.0.0.dev0 → areno-0.0.0.dev1}/tests/test_losses_rewards_cpu.py +0 -0
  149. {areno-0.0.0.dev0 → areno-0.0.0.dev1}/tests/test_metrics_cpu.py +0 -0
  150. {areno-0.0.0.dev0 → areno-0.0.0.dev1}/tests/test_more_losses_cpu.py +0 -0
  151. {areno-0.0.0.dev0 → areno-0.0.0.dev1}/tests/test_protocol_cpu.py +0 -0
  152. {areno-0.0.0.dev0 → areno-0.0.0.dev1}/tests/test_recompute_cpu.py +0 -0
  153. {areno-0.0.0.dev0 → areno-0.0.0.dev1}/tests/test_registry_cpu.py +0 -0
  154. {areno-0.0.0.dev0 → areno-0.0.0.dev1}/tests/test_runtime_utils_cpu.py +0 -0
  155. {areno-0.0.0.dev0 → areno-0.0.0.dev1}/tests/test_sampling_cpu.py +0 -0
  156. {areno-0.0.0.dev0 → areno-0.0.0.dev1}/tests/test_tokenizer_api_cpu.py +0 -0
  157. {areno-0.0.0.dev0 → areno-0.0.0.dev1}/tests/test_trainer_api_cpu.py +0 -0
  158. {areno-0.0.0.dev0 → areno-0.0.0.dev1}/tests/test_trainer_dataset_utils_cpu.py +0 -0
@@ -0,0 +1 @@
1
+ recursive-include areno/accel/csrc *.cpp *.cu *.cuh *.h *.hpp
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.4
2
2
  Name: areno
3
- Version: 0.0.0.dev0
3
+ Version: 0.0.0.dev1
4
4
  Summary: Local LLM post-training stack with the areno engine and model plugins.
5
5
  Requires-Python: >=3.10
6
6
  Requires-Dist: torch>=2.6
@@ -0,0 +1,360 @@
1
+ #include <ATen/ATen.h>
2
+ #include <ATen/cuda/CUDAContext.h>
3
+ #include <c10/cuda/CUDAGuard.h>
4
+ #include <torch/extension.h>
5
+
6
+ namespace areno_accel {
7
+
8
+ template <typename scalar_t>
9
+ __device__ __forceinline__ float read_as_float(scalar_t value) {
10
+ return static_cast<float>(value);
11
+ }
12
+
13
+ template <typename scalar_t>
14
+ __device__ __forceinline__ scalar_t cast_from_float(float value) {
15
+ return static_cast<scalar_t>(value);
16
+ }
17
+
18
+ struct SiluMul {
19
+ __device__ __forceinline__ float operator()(float gate, float up) const {
20
+ return gate / (1.0f + expf(-gate)) * up;
21
+ }
22
+ };
23
+
24
+ struct GeluTanhMul {
25
+ __device__ __forceinline__ float operator()(float gate, float up) const {
26
+ constexpr float cubic = 0.044715f;
27
+ constexpr float scale = 0.7978845608028654f;
28
+ float cdf = 0.5f * (1.0f + tanhf(scale * (gate + cubic * gate * gate * gate)));
29
+ return gate * cdf * up;
30
+ }
31
+ };
32
+
33
+ struct SiluMulGrad {
34
+ __device__ __forceinline__ void operator()(float gate, float up, float grad, float* dgate, float* dup) const {
35
+ float sigmoid = 1.0f / (1.0f + expf(-gate));
36
+ float silu = gate * sigmoid;
37
+ *dgate = grad * up * sigmoid * (1.0f + gate * (1.0f - sigmoid));
38
+ *dup = grad * silu;
39
+ }
40
+ };
41
+
42
+ struct GeluTanhMulGrad {
43
+ __device__ __forceinline__ void operator()(float gate, float up, float grad, float* dgate, float* dup) const {
44
+ constexpr float cubic = 0.044715f;
45
+ constexpr float scale = 0.7978845608028654f;
46
+ float gate2 = gate * gate;
47
+ float inner = scale * (gate + cubic * gate * gate2);
48
+ float tanh_inner = tanhf(inner);
49
+ float cdf = 0.5f * (1.0f + tanh_inner);
50
+ float sech2 = 1.0f - tanh_inner * tanh_inner;
51
+ float d_inner = scale * (1.0f + 3.0f * cubic * gate2);
52
+ float d_gelu = cdf + 0.5f * gate * sech2 * d_inner;
53
+ *dgate = grad * up * d_gelu;
54
+ *dup = grad * gate * cdf;
55
+ }
56
+ };
57
+
58
+ template <typename scalar_t, typename Activation>
59
+ __global__ void gated_product_kernel(scalar_t* __restrict__ output, const scalar_t* __restrict__ input, int half_width) {
60
+ const int64_t row = blockIdx.x;
61
+ const int64_t in_base = row * half_width * 2;
62
+ const int64_t out_base = row * half_width;
63
+ Activation activation;
64
+ for (int col = threadIdx.x; col < half_width; col += blockDim.x) {
65
+ float gate = read_as_float(input[in_base + col]);
66
+ float up = read_as_float(input[in_base + half_width + col]);
67
+ output[out_base + col] = cast_from_float<scalar_t>(activation(gate, up));
68
+ }
69
+ }
70
+
71
+ template <typename scalar_t, typename ActivationGrad>
72
+ __global__ void gated_product_backward_kernel(
73
+ scalar_t* __restrict__ grad_input,
74
+ const scalar_t* __restrict__ grad_output,
75
+ const scalar_t* __restrict__ input,
76
+ int half_width) {
77
+ const int64_t row = blockIdx.x;
78
+ const int64_t in_base = row * half_width * 2;
79
+ const int64_t out_base = row * half_width;
80
+ ActivationGrad activation_grad;
81
+ for (int col = threadIdx.x; col < half_width; col += blockDim.x) {
82
+ float gate = read_as_float(input[in_base + col]);
83
+ float up = read_as_float(input[in_base + half_width + col]);
84
+ float grad = read_as_float(grad_output[out_base + col]);
85
+ float dgate;
86
+ float dup;
87
+ activation_grad(gate, up, grad, &dgate, &dup);
88
+ grad_input[in_base + col] = cast_from_float<scalar_t>(dgate);
89
+ grad_input[in_base + half_width + col] = cast_from_float<scalar_t>(dup);
90
+ }
91
+ }
92
+
93
+ template <typename scalar_t>
94
+ __global__ void silu_kernel(scalar_t* __restrict__ output, const scalar_t* __restrict__ input, int64_t elements) {
95
+ for (int64_t idx = blockIdx.x * blockDim.x + threadIdx.x; idx < elements; idx += static_cast<int64_t>(blockDim.x) * gridDim.x) {
96
+ float x = read_as_float(input[idx]);
97
+ output[idx] = cast_from_float<scalar_t>(x / (1.0f + expf(-x)));
98
+ }
99
+ }
100
+
101
+ template <typename scalar_t>
102
+ __global__ void sigmoid_kernel(scalar_t* __restrict__ output, const scalar_t* __restrict__ input, int64_t elements) {
103
+ for (int64_t idx = blockIdx.x * blockDim.x + threadIdx.x; idx < elements; idx += static_cast<int64_t>(blockDim.x) * gridDim.x) {
104
+ float x = read_as_float(input[idx]);
105
+ output[idx] = cast_from_float<scalar_t>(1.0f / (1.0f + expf(-x)));
106
+ }
107
+ }
108
+
109
+ template <typename scalar_t>
110
+ __global__ void sigmoid_backward_kernel(
111
+ scalar_t* __restrict__ grad_input,
112
+ const scalar_t* __restrict__ grad_output,
113
+ const scalar_t* __restrict__ output,
114
+ int64_t elements) {
115
+ for (int64_t idx = blockIdx.x * blockDim.x + threadIdx.x; idx < elements; idx += static_cast<int64_t>(blockDim.x) * gridDim.x) {
116
+ float y = read_as_float(output[idx]);
117
+ grad_input[idx] = cast_from_float<scalar_t>(read_as_float(grad_output[idx]) * y * (1.0f - y));
118
+ }
119
+ }
120
+
121
+ template <typename scalar_t>
122
+ __global__ void softplus_kernel(scalar_t* __restrict__ output, const scalar_t* __restrict__ input, int64_t elements) {
123
+ for (int64_t idx = blockIdx.x * blockDim.x + threadIdx.x; idx < elements; idx += static_cast<int64_t>(blockDim.x) * gridDim.x) {
124
+ float x = read_as_float(input[idx]);
125
+ output[idx] = cast_from_float<scalar_t>(x > 20.0f ? x : log1pf(expf(x)));
126
+ }
127
+ }
128
+
129
+ template <typename scalar_t>
130
+ __global__ void softplus_backward_kernel(
131
+ scalar_t* __restrict__ grad_input,
132
+ const scalar_t* __restrict__ grad_output,
133
+ const scalar_t* __restrict__ input,
134
+ int64_t elements) {
135
+ for (int64_t idx = blockIdx.x * blockDim.x + threadIdx.x; idx < elements; idx += static_cast<int64_t>(blockDim.x) * gridDim.x) {
136
+ float x = read_as_float(input[idx]);
137
+ float sigmoid = x > 20.0f ? 1.0f : 1.0f / (1.0f + expf(-x));
138
+ grad_input[idx] = cast_from_float<scalar_t>(read_as_float(grad_output[idx]) * sigmoid);
139
+ }
140
+ }
141
+
142
+ template <typename scalar_t>
143
+ __global__ void silu_backward_kernel(
144
+ scalar_t* __restrict__ grad_input,
145
+ const scalar_t* __restrict__ grad_output,
146
+ const scalar_t* __restrict__ input,
147
+ int64_t elements) {
148
+ for (int64_t idx = blockIdx.x * blockDim.x + threadIdx.x; idx < elements; idx += static_cast<int64_t>(blockDim.x) * gridDim.x) {
149
+ float x = read_as_float(input[idx]);
150
+ float sigmoid = 1.0f / (1.0f + expf(-x));
151
+ float grad = read_as_float(grad_output[idx]) * sigmoid * (1.0f + x * (1.0f - sigmoid));
152
+ grad_input[idx] = cast_from_float<scalar_t>(grad);
153
+ }
154
+ }
155
+
156
+ template <typename Kernel>
157
+ void run_unary(torch::Tensor output, torch::Tensor input, const char* op_name, Kernel kernel) {
158
+ TORCH_CHECK(input.is_cuda(), op_name, " input must be CUDA");
159
+ TORCH_CHECK(output.is_cuda(), op_name, " output must be CUDA");
160
+ TORCH_CHECK(input.scalar_type() == output.scalar_type(), op_name, " dtype mismatch");
161
+ TORCH_CHECK(input.numel() == output.numel(), op_name, " shape mismatch");
162
+
163
+ const int threads = 256;
164
+ const int blocks = static_cast<int>(std::min<int64_t>((input.numel() + threads - 1) / threads, 4096));
165
+ const cudaStream_t stream = at::cuda::getCurrentCUDAStream();
166
+ const at::cuda::OptionalCUDAGuard guard(device_of(input));
167
+
168
+ AT_DISPATCH_FLOATING_TYPES_AND2(at::kHalf, at::kBFloat16, input.scalar_type(), "areno_unary", [&] {
169
+ kernel.template operator()<scalar_t>(output.data_ptr<scalar_t>(), input.data_ptr<scalar_t>(), input.numel(), stream, blocks, threads);
170
+ });
171
+ }
172
+
173
+ struct SiluLauncher {
174
+ template <typename scalar_t>
175
+ void operator()(scalar_t* output, const scalar_t* input, int64_t elements, cudaStream_t stream, int blocks, int threads) const {
176
+ silu_kernel<scalar_t><<<blocks, threads, 0, stream>>>(output, input, elements);
177
+ }
178
+ };
179
+
180
+ struct SigmoidLauncher {
181
+ template <typename scalar_t>
182
+ void operator()(scalar_t* output, const scalar_t* input, int64_t elements, cudaStream_t stream, int blocks, int threads) const {
183
+ sigmoid_kernel<scalar_t><<<blocks, threads, 0, stream>>>(output, input, elements);
184
+ }
185
+ };
186
+
187
+ struct SoftplusLauncher {
188
+ template <typename scalar_t>
189
+ void operator()(scalar_t* output, const scalar_t* input, int64_t elements, cudaStream_t stream, int blocks, int threads) const {
190
+ softplus_kernel<scalar_t><<<blocks, threads, 0, stream>>>(output, input, elements);
191
+ }
192
+ };
193
+
194
+ void run_silu(torch::Tensor output, torch::Tensor input, const char* op_name) {
195
+ run_unary(output, input, op_name, SiluLauncher{});
196
+ }
197
+
198
+ void run_sigmoid(torch::Tensor output, torch::Tensor input, const char* op_name) {
199
+ run_unary(output, input, op_name, SigmoidLauncher{});
200
+ }
201
+
202
+ void run_softplus(torch::Tensor output, torch::Tensor input, const char* op_name) {
203
+ run_unary(output, input, op_name, SoftplusLauncher{});
204
+ }
205
+
206
+ void run_silu_backward(torch::Tensor grad_input, torch::Tensor grad_output, torch::Tensor input, const char* op_name) {
207
+ TORCH_CHECK(input.is_cuda(), op_name, " input must be CUDA");
208
+ TORCH_CHECK(grad_output.is_cuda(), op_name, " grad_output must be CUDA");
209
+ TORCH_CHECK(grad_input.is_cuda(), op_name, " grad_input must be CUDA");
210
+ TORCH_CHECK(input.scalar_type() == grad_output.scalar_type(), op_name, " dtype mismatch");
211
+ TORCH_CHECK(input.scalar_type() == grad_input.scalar_type(), op_name, " dtype mismatch");
212
+ TORCH_CHECK(input.numel() == grad_output.numel(), op_name, " shape mismatch");
213
+ TORCH_CHECK(input.numel() == grad_input.numel(), op_name, " shape mismatch");
214
+
215
+ const int threads = 256;
216
+ const int blocks = static_cast<int>(std::min<int64_t>((input.numel() + threads - 1) / threads, 4096));
217
+ const cudaStream_t stream = at::cuda::getCurrentCUDAStream();
218
+ const at::cuda::OptionalCUDAGuard guard(device_of(input));
219
+
220
+ AT_DISPATCH_FLOATING_TYPES_AND2(at::kHalf, at::kBFloat16, input.scalar_type(), "areno_silu_backward", [&] {
221
+ silu_backward_kernel<scalar_t><<<blocks, threads, 0, stream>>>(
222
+ grad_input.data_ptr<scalar_t>(), grad_output.data_ptr<scalar_t>(), input.data_ptr<scalar_t>(), input.numel());
223
+ });
224
+ }
225
+
226
+ void run_sigmoid_backward(torch::Tensor grad_input, torch::Tensor grad_output, torch::Tensor output, const char* op_name) {
227
+ TORCH_CHECK(output.is_cuda(), op_name, " output must be CUDA");
228
+ TORCH_CHECK(grad_output.is_cuda(), op_name, " grad_output must be CUDA");
229
+ TORCH_CHECK(grad_input.is_cuda(), op_name, " grad_input must be CUDA");
230
+ TORCH_CHECK(output.scalar_type() == grad_output.scalar_type(), op_name, " dtype mismatch");
231
+ TORCH_CHECK(output.scalar_type() == grad_input.scalar_type(), op_name, " dtype mismatch");
232
+ TORCH_CHECK(output.numel() == grad_output.numel(), op_name, " shape mismatch");
233
+ TORCH_CHECK(output.numel() == grad_input.numel(), op_name, " shape mismatch");
234
+
235
+ const int threads = 256;
236
+ const int blocks = static_cast<int>(std::min<int64_t>((output.numel() + threads - 1) / threads, 4096));
237
+ const cudaStream_t stream = at::cuda::getCurrentCUDAStream();
238
+ const at::cuda::OptionalCUDAGuard guard(device_of(output));
239
+
240
+ AT_DISPATCH_FLOATING_TYPES_AND2(at::kHalf, at::kBFloat16, output.scalar_type(), "areno_sigmoid_backward", [&] {
241
+ sigmoid_backward_kernel<scalar_t><<<blocks, threads, 0, stream>>>(
242
+ grad_input.data_ptr<scalar_t>(), grad_output.data_ptr<scalar_t>(), output.data_ptr<scalar_t>(), output.numel());
243
+ });
244
+ }
245
+
246
+ void run_softplus_backward(torch::Tensor grad_input, torch::Tensor grad_output, torch::Tensor input, const char* op_name) {
247
+ TORCH_CHECK(input.is_cuda(), op_name, " input must be CUDA");
248
+ TORCH_CHECK(grad_output.is_cuda(), op_name, " grad_output must be CUDA");
249
+ TORCH_CHECK(grad_input.is_cuda(), op_name, " grad_input must be CUDA");
250
+ TORCH_CHECK(input.scalar_type() == grad_output.scalar_type(), op_name, " dtype mismatch");
251
+ TORCH_CHECK(input.scalar_type() == grad_input.scalar_type(), op_name, " dtype mismatch");
252
+ TORCH_CHECK(input.numel() == grad_output.numel(), op_name, " shape mismatch");
253
+ TORCH_CHECK(input.numel() == grad_input.numel(), op_name, " shape mismatch");
254
+
255
+ const int threads = 256;
256
+ const int blocks = static_cast<int>(std::min<int64_t>((input.numel() + threads - 1) / threads, 4096));
257
+ const cudaStream_t stream = at::cuda::getCurrentCUDAStream();
258
+ const at::cuda::OptionalCUDAGuard guard(device_of(input));
259
+
260
+ AT_DISPATCH_FLOATING_TYPES_AND2(at::kHalf, at::kBFloat16, input.scalar_type(), "areno_softplus_backward", [&] {
261
+ softplus_backward_kernel<scalar_t><<<blocks, threads, 0, stream>>>(
262
+ grad_input.data_ptr<scalar_t>(), grad_output.data_ptr<scalar_t>(), input.data_ptr<scalar_t>(), input.numel());
263
+ });
264
+ }
265
+
266
+ template <typename Activation>
267
+ void run_gated_product(torch::Tensor output, torch::Tensor input, const char* op_name) {
268
+ TORCH_CHECK(input.is_cuda(), op_name, " input must be CUDA");
269
+ TORCH_CHECK(output.is_cuda(), op_name, " output must be CUDA");
270
+ TORCH_CHECK(input.dim() > 0, op_name, " input must have at least one dimension");
271
+ TORCH_CHECK(input.size(-1) % 2 == 0, op_name, " input last dimension must be even");
272
+ TORCH_CHECK(output.numel() * 2 == input.numel(), op_name, " output shape does not match input");
273
+
274
+ const int half_width = input.size(-1) / 2;
275
+ const int64_t rows = input.numel() / input.size(-1);
276
+ const int threads = std::min(1024, std::max(32, ((half_width + 31) / 32) * 32));
277
+ const cudaStream_t stream = at::cuda::getCurrentCUDAStream();
278
+ const at::cuda::OptionalCUDAGuard guard(device_of(input));
279
+
280
+ AT_DISPATCH_FLOATING_TYPES_AND2(at::kHalf, at::kBFloat16, input.scalar_type(), "areno_gated_product", [&] {
281
+ gated_product_kernel<scalar_t, Activation><<<rows, threads, 0, stream>>>(
282
+ output.data_ptr<scalar_t>(), input.data_ptr<scalar_t>(), half_width);
283
+ });
284
+ }
285
+
286
+ template <typename ActivationGrad>
287
+ void run_gated_product_backward(torch::Tensor grad_input, torch::Tensor grad_output, torch::Tensor input, const char* op_name) {
288
+ TORCH_CHECK(input.is_cuda(), op_name, " input must be CUDA");
289
+ TORCH_CHECK(grad_output.is_cuda(), op_name, " grad_output must be CUDA");
290
+ TORCH_CHECK(grad_input.is_cuda(), op_name, " grad_input must be CUDA");
291
+ TORCH_CHECK(input.size(-1) % 2 == 0, op_name, " input last dimension must be even");
292
+ TORCH_CHECK(grad_input.numel() == input.numel(), op_name, " grad_input shape does not match input");
293
+ TORCH_CHECK(grad_output.numel() * 2 == input.numel(), op_name, " grad_output shape does not match input");
294
+
295
+ const int half_width = input.size(-1) / 2;
296
+ const int64_t rows = input.numel() / input.size(-1);
297
+ const int threads = std::min(1024, std::max(32, ((half_width + 31) / 32) * 32));
298
+ const cudaStream_t stream = at::cuda::getCurrentCUDAStream();
299
+ const at::cuda::OptionalCUDAGuard guard(device_of(input));
300
+
301
+ AT_DISPATCH_FLOATING_TYPES_AND2(at::kHalf, at::kBFloat16, input.scalar_type(), "areno_gated_product_backward", [&] {
302
+ gated_product_backward_kernel<scalar_t, ActivationGrad><<<rows, threads, 0, stream>>>(
303
+ grad_input.data_ptr<scalar_t>(), grad_output.data_ptr<scalar_t>(), input.data_ptr<scalar_t>(), half_width);
304
+ });
305
+ }
306
+
307
+ } // namespace areno_accel
308
+
309
+ void areno_silu_and_mul_cuda(torch::Tensor output, torch::Tensor input) {
310
+ areno_accel::run_gated_product<areno_accel::SiluMul>(output, input, "areno_silu_and_mul");
311
+ }
312
+
313
+ void areno_gelu_tanh_and_mul_cuda(torch::Tensor output, torch::Tensor input) {
314
+ areno_accel::run_gated_product<areno_accel::GeluTanhMul>(output, input, "areno_gelu_tanh_and_mul");
315
+ }
316
+
317
+ torch::Tensor areno_silu_cuda(torch::Tensor input) {
318
+ auto output = torch::empty_like(input);
319
+ areno_accel::run_silu(output, input, "areno_silu");
320
+ return output;
321
+ }
322
+
323
+ torch::Tensor areno_sigmoid_cuda(torch::Tensor input) {
324
+ auto output = torch::empty_like(input);
325
+ areno_accel::run_sigmoid(output, input, "areno_sigmoid");
326
+ return output;
327
+ }
328
+
329
+ torch::Tensor areno_softplus_cuda(torch::Tensor input) {
330
+ auto output = torch::empty_like(input);
331
+ areno_accel::run_softplus(output, input, "areno_softplus");
332
+ return output;
333
+ }
334
+
335
+ void areno_d_silu_and_mul_cuda(torch::Tensor grad_input, torch::Tensor grad_output, torch::Tensor input) {
336
+ areno_accel::run_gated_product_backward<areno_accel::SiluMulGrad>(grad_input, grad_output, input, "areno_d_silu_and_mul");
337
+ }
338
+
339
+ void areno_d_gelu_tanh_and_mul_cuda(torch::Tensor grad_input, torch::Tensor grad_output, torch::Tensor input) {
340
+ areno_accel::run_gated_product_backward<areno_accel::GeluTanhMulGrad>(
341
+ grad_input, grad_output, input, "areno_d_gelu_tanh_and_mul");
342
+ }
343
+
344
+ torch::Tensor areno_d_silu_cuda(torch::Tensor grad_output, torch::Tensor input) {
345
+ auto grad_input = torch::empty_like(input);
346
+ areno_accel::run_silu_backward(grad_input, grad_output, input, "areno_d_silu");
347
+ return grad_input;
348
+ }
349
+
350
+ torch::Tensor areno_d_sigmoid_cuda(torch::Tensor grad_output, torch::Tensor output) {
351
+ auto grad_input = torch::empty_like(output);
352
+ areno_accel::run_sigmoid_backward(grad_input, grad_output, output, "areno_d_sigmoid");
353
+ return grad_input;
354
+ }
355
+
356
+ torch::Tensor areno_d_softplus_cuda(torch::Tensor grad_output, torch::Tensor input) {
357
+ auto grad_input = torch::empty_like(input);
358
+ areno_accel::run_softplus_backward(grad_input, grad_output, input, "areno_d_softplus");
359
+ return grad_input;
360
+ }
@@ -0,0 +1,25 @@
1
+ #pragma once
2
+
3
+ #include <cuda_bf16.h>
4
+ #include <cuda_fp16.h>
5
+ #include <c10/util/BFloat16.h>
6
+ #include <c10/util/Half.h>
7
+
8
+ namespace areno_accel {
9
+
10
+ template <typename scalar_t>
11
+ __device__ inline void atomic_add(scalar_t* address, scalar_t value) {
12
+ atomicAdd(address, value);
13
+ }
14
+
15
+ template <>
16
+ __device__ inline void atomic_add<c10::Half>(c10::Half* address, c10::Half value) {
17
+ atomicAdd(reinterpret_cast<__half*>(address), static_cast<__half>(value));
18
+ }
19
+
20
+ template <>
21
+ __device__ inline void atomic_add<c10::BFloat16>(c10::BFloat16* address, c10::BFloat16 value) {
22
+ atomicAdd(reinterpret_cast<__nv_bfloat16*>(address), static_cast<__nv_bfloat16>(value));
23
+ }
24
+
25
+ } // namespace areno_accel