areno 0.0.0.dev0__tar.gz

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (146)
  1. areno-0.0.0.dev0/PKG-INFO +19 -0
  2. areno-0.0.0.dev0/areno/__init__.py +66 -0
  3. areno-0.0.0.dev0/areno/accel/__init__.py +42 -0
  4. areno-0.0.0.dev0/areno/accel/_extension.py +24 -0
  5. areno-0.0.0.dev0/areno/accel/activations.py +172 -0
  6. areno-0.0.0.dev0/areno/accel/conv.py +141 -0
  7. areno-0.0.0.dev0/areno/accel/embedding.py +49 -0
  8. areno-0.0.0.dev0/areno/accel/kernels/__init__.py +1 -0
  9. areno-0.0.0.dev0/areno/accel/kernels/fused_moe.py +426 -0
  10. areno-0.0.0.dev0/areno/accel/kernels/group_rmsnorm.py +155 -0
  11. areno-0.0.0.dev0/areno/accel/kernels/seg_la.py +1168 -0
  12. areno-0.0.0.dev0/areno/accel/linear.py +116 -0
  13. areno-0.0.0.dev0/areno/accel/moe.py +155 -0
  14. areno-0.0.0.dev0/areno/accel/normalization.py +137 -0
  15. areno-0.0.0.dev0/areno/accel/ops.py +79 -0
  16. areno-0.0.0.dev0/areno/accel/router.py +42 -0
  17. areno-0.0.0.dev0/areno/accel/routing.py +47 -0
  18. areno-0.0.0.dev0/areno/accel/topk.py +36 -0
  19. areno-0.0.0.dev0/areno/api/__init__.py +46 -0
  20. areno-0.0.0.dev0/areno/api/advantages.py +74 -0
  21. areno-0.0.0.dev0/areno/api/algorithms.py +207 -0
  22. areno-0.0.0.dev0/areno/api/backend/__init__.py +7 -0
  23. areno-0.0.0.dev0/areno/api/backend/areno/__init__.py +9 -0
  24. areno-0.0.0.dev0/areno/api/backend/areno/backend.py +406 -0
  25. areno-0.0.0.dev0/areno/api/backend/base.py +128 -0
  26. areno-0.0.0.dev0/areno/api/config.py +57 -0
  27. areno-0.0.0.dev0/areno/api/context.py +31 -0
  28. areno-0.0.0.dev0/areno/api/data.py +50 -0
  29. areno-0.0.0.dev0/areno/api/data_utils.py +69 -0
  30. areno-0.0.0.dev0/areno/api/defaults.py +3 -0
  31. areno-0.0.0.dev0/areno/api/loss_fns/__init__.py +20 -0
  32. areno-0.0.0.dev0/areno/api/loss_fns/dpo.py +77 -0
  33. areno-0.0.0.dev0/areno/api/loss_fns/grpo.py +47 -0
  34. areno-0.0.0.dev0/areno/api/loss_fns/gspo.py +65 -0
  35. areno-0.0.0.dev0/areno/api/loss_fns/layout.py +161 -0
  36. areno-0.0.0.dev0/areno/api/loss_fns/ppo.py +100 -0
  37. areno-0.0.0.dev0/areno/api/loss_fns/sft.py +37 -0
  38. areno-0.0.0.dev0/areno/api/metrics.py +174 -0
  39. areno-0.0.0.dev0/areno/api/models.py +76 -0
  40. areno-0.0.0.dev0/areno/api/rewards.py +48 -0
  41. areno-0.0.0.dev0/areno/api/roles.py +33 -0
  42. areno-0.0.0.dev0/areno/api/tokenizer.py +74 -0
  43. areno-0.0.0.dev0/areno/api/trainer.py +258 -0
  44. areno-0.0.0.dev0/areno/api/trainer_config.py +179 -0
  45. areno-0.0.0.dev0/areno/api/trainer_factory.py +16 -0
  46. areno-0.0.0.dev0/areno/api/trainers/__init__.py +16 -0
  47. areno-0.0.0.dev0/areno/api/trainers/dpo.py +199 -0
  48. areno-0.0.0.dev0/areno/api/trainers/policy_only.py +213 -0
  49. areno-0.0.0.dev0/areno/api/trainers/ppo.py +355 -0
  50. areno-0.0.0.dev0/areno/api/trainers/sft.py +168 -0
  51. areno-0.0.0.dev0/areno/cli/__init__.py +2 -0
  52. areno-0.0.0.dev0/areno/cli/main.py +50 -0
  53. areno-0.0.0.dev0/areno/cli/model_refs.py +46 -0
  54. areno-0.0.0.dev0/areno/cli/serve.py +759 -0
  55. areno-0.0.0.dev0/areno/cli/train.py +506 -0
  56. areno-0.0.0.dev0/areno/engine/__init__.py +17 -0
  57. areno-0.0.0.dev0/areno/engine/api.py +582 -0
  58. areno-0.0.0.dev0/areno/engine/checkpoints/__init__.py +9 -0
  59. areno-0.0.0.dev0/areno/engine/checkpoints/common.py +1158 -0
  60. areno-0.0.0.dev0/areno/engine/checkpoints/io.py +664 -0
  61. areno-0.0.0.dev0/areno/engine/config.py +208 -0
  62. areno-0.0.0.dev0/areno/engine/data/__init__.py +12 -0
  63. areno-0.0.0.dev0/areno/engine/data/batch.py +83 -0
  64. areno-0.0.0.dev0/areno/engine/data/rollout_state.py +196 -0
  65. areno-0.0.0.dev0/areno/engine/data/sampling.py +250 -0
  66. areno-0.0.0.dev0/areno/engine/data/tokenizer.py +30 -0
  67. areno-0.0.0.dev0/areno/engine/inference.py +941 -0
  68. areno-0.0.0.dev0/areno/engine/layers/__init__.py +24 -0
  69. areno-0.0.0.dev0/areno/engine/layers/attention.py +134 -0
  70. areno-0.0.0.dev0/areno/engine/layers/attention_backend/__init__.py +25 -0
  71. areno-0.0.0.dev0/areno/engine/layers/attention_backend/common.py +107 -0
  72. areno-0.0.0.dev0/areno/engine/layers/attention_backend/infer.py +301 -0
  73. areno-0.0.0.dev0/areno/engine/layers/attention_backend/train.py +203 -0
  74. areno-0.0.0.dev0/areno/engine/layers/linear.py +242 -0
  75. areno-0.0.0.dev0/areno/engine/layers/mlp.py +52 -0
  76. areno-0.0.0.dev0/areno/engine/layers/norm.py +104 -0
  77. areno-0.0.0.dev0/areno/engine/layers/rotary.py +130 -0
  78. areno-0.0.0.dev0/areno/engine/layers/vocab.py +73 -0
  79. areno-0.0.0.dev0/areno/engine/log.py +36 -0
  80. areno-0.0.0.dev0/areno/engine/modeling.py +99 -0
  81. areno-0.0.0.dev0/areno/engine/optim/__init__.py +11 -0
  82. areno-0.0.0.dev0/areno/engine/optim/adamw_8bit.py +274 -0
  83. areno-0.0.0.dev0/areno/engine/optim/adamw_fp32_master.py +502 -0
  84. areno-0.0.0.dev0/areno/engine/parallel/__init__.py +12 -0
  85. areno-0.0.0.dev0/areno/engine/parallel/collectives.py +243 -0
  86. areno-0.0.0.dev0/areno/engine/parallel/context.py +142 -0
  87. areno-0.0.0.dev0/areno/engine/protocol.py +445 -0
  88. areno-0.0.0.dev0/areno/engine/roles.py +557 -0
  89. areno-0.0.0.dev0/areno/engine/runtime/__init__.py +1 -0
  90. areno-0.0.0.dev0/areno/engine/runtime/common.py +153 -0
  91. areno-0.0.0.dev0/areno/engine/runtime/decode_graph.py +172 -0
  92. areno-0.0.0.dev0/areno/engine/runtime/logprobs.py +204 -0
  93. areno-0.0.0.dev0/areno/engine/runtime/metadata.py +45 -0
  94. areno-0.0.0.dev0/areno/engine/runtime/recompute.py +51 -0
  95. areno-0.0.0.dev0/areno/engine/runtime/rollout.py +127 -0
  96. areno-0.0.0.dev0/areno/engine/runtime/train_step.py +292 -0
  97. areno-0.0.0.dev0/areno/engine/training.py +190 -0
  98. areno-0.0.0.dev0/areno/engine/worker.py +220 -0
  99. areno-0.0.0.dev0/areno/experimental/__init__.py +1 -0
  100. areno-0.0.0.dev0/areno/models/__init__.py +31 -0
  101. areno-0.0.0.dev0/areno/models/_shared/__init__.py +2 -0
  102. areno-0.0.0.dev0/areno/models/_shared/dynamo_wrappers.py +133 -0
  103. areno-0.0.0.dev0/areno/models/bailing/__init__.py +13 -0
  104. areno-0.0.0.dev0/areno/models/bailing/checkpoint.py +92 -0
  105. areno-0.0.0.dev0/areno/models/bailing/model.py +1304 -0
  106. areno-0.0.0.dev0/areno/models/base.py +70 -0
  107. areno-0.0.0.dev0/areno/models/gemma4/__init__.py +13 -0
  108. areno-0.0.0.dev0/areno/models/gemma4/checkpoint.py +295 -0
  109. areno-0.0.0.dev0/areno/models/gemma4/model.py +1005 -0
  110. areno-0.0.0.dev0/areno/models/llama/__init__.py +13 -0
  111. areno-0.0.0.dev0/areno/models/llama/checkpoint.py +68 -0
  112. areno-0.0.0.dev0/areno/models/llama/model.py +70 -0
  113. areno-0.0.0.dev0/areno/models/minicpmv46/__init__.py +11 -0
  114. areno-0.0.0.dev0/areno/models/minicpmv46/checkpoint.py +298 -0
  115. areno-0.0.0.dev0/areno/models/minicpmv46/model.py +686 -0
  116. areno-0.0.0.dev0/areno/models/qwen3/__init__.py +14 -0
  117. areno-0.0.0.dev0/areno/models/qwen3/checkpoint.py +131 -0
  118. areno-0.0.0.dev0/areno/models/qwen3/model.py +526 -0
  119. areno-0.0.0.dev0/areno/models/qwen3_5/__init__.py +5 -0
  120. areno-0.0.0.dev0/areno/models/qwen3_5/checkpoint.py +609 -0
  121. areno-0.0.0.dev0/areno/models/qwen3_5/model.py +891 -0
  122. areno-0.0.0.dev0/areno/models/registry.py +129 -0
  123. areno-0.0.0.dev0/areno.egg-info/PKG-INFO +19 -0
  124. areno-0.0.0.dev0/areno.egg-info/SOURCES.txt +144 -0
  125. areno-0.0.0.dev0/areno.egg-info/dependency_links.txt +1 -0
  126. areno-0.0.0.dev0/areno.egg-info/entry_points.txt +2 -0
  127. areno-0.0.0.dev0/areno.egg-info/requires.txt +14 -0
  128. areno-0.0.0.dev0/areno.egg-info/top_level.txt +1 -0
  129. areno-0.0.0.dev0/pyproject.toml +32 -0
  130. areno-0.0.0.dev0/setup.cfg +4 -0
  131. areno-0.0.0.dev0/setup.py +55 -0
  132. areno-0.0.0.dev0/tests/test_algorithms_cpu.py +75 -0
  133. areno-0.0.0.dev0/tests/test_cli_model_refs_cpu.py +59 -0
  134. areno-0.0.0.dev0/tests/test_config_data_cpu.py +123 -0
  135. areno-0.0.0.dev0/tests/test_logprobs_cpu.py +95 -0
  136. areno-0.0.0.dev0/tests/test_losses_rewards_cpu.py +131 -0
  137. areno-0.0.0.dev0/tests/test_metrics_cpu.py +66 -0
  138. areno-0.0.0.dev0/tests/test_more_losses_cpu.py +101 -0
  139. areno-0.0.0.dev0/tests/test_protocol_cpu.py +87 -0
  140. areno-0.0.0.dev0/tests/test_recompute_cpu.py +64 -0
  141. areno-0.0.0.dev0/tests/test_registry_cpu.py +174 -0
  142. areno-0.0.0.dev0/tests/test_runtime_utils_cpu.py +158 -0
  143. areno-0.0.0.dev0/tests/test_sampling_cpu.py +78 -0
  144. areno-0.0.0.dev0/tests/test_tokenizer_api_cpu.py +110 -0
  145. areno-0.0.0.dev0/tests/test_trainer_api_cpu.py +148 -0
  146. areno-0.0.0.dev0/tests/test_trainer_dataset_utils_cpu.py +114 -0
@@ -0,0 +1,19 @@
1
+ Metadata-Version: 2.4
2
+ Name: areno
3
+ Version: 0.0.0.dev0
4
+ Summary: Local LLM post-training stack with the areno engine and model plugins.
5
+ Requires-Python: >=3.10
6
+ Requires-Dist: torch>=2.6
7
+ Requires-Dist: flash-attn>=2.7
8
+ Requires-Dist: flash-linear-attention>=0.2
9
+ Requires-Dist: safetensors>=0.4
10
+ Requires-Dist: transformers>=4.56
11
+ Requires-Dist: huggingface-hub>=0.25
12
+ Requires-Dist: datasets>=3.3.0
13
+ Requires-Dist: numpy
14
+ Requires-Dist: tensorboard
15
+ Requires-Dist: einops
16
+ Requires-Dist: fastapi>=0.110
17
+ Requires-Dist: click>=8.1
18
+ Requires-Dist: tqdm>=4.66
19
+ Requires-Dist: uvicorn>=0.27
@@ -0,0 +1,66 @@
1
+ """Top-level areno package.
2
+
3
+ Sets process-wide knobs that must be in place before any CUDA/Triton kernel or
4
+ torch.compile call runs: a single CUDA stream for collectives and a generous
5
+ TorchDynamo cache so the engine can compile many specialized graphs (per shape
6
+ bucket, prefill vs decode, train vs infer) without thrashing.
7
+
8
+ Exposes the user-facing surface: configuration dataclasses, the rollout
9
+ output container, and the `ArenoEngine` coordinator.
10
+ """
11
+
12
+ from __future__ import annotations
13
+
14
+ import os
15
+
16
# Limit each CUDA device to one host-to-device hardware work-queue connection so
# kernels issued on different streams (NCCL collectives vs. compute) are
# scheduled in issue order, which is what areno's TP/DP all-reduce + all-gather
# patterns assume. ``setdefault`` keeps any value the user already exported.
# This runs at package import, ahead of any CUDA work, as the module docstring
# requires for these process-wide knobs.
os.environ.setdefault("CUDA_DEVICE_MAX_CONNECTIONS", "1")

try:
    import torch._dynamo as _dynamo
except ModuleNotFoundError:
    # torch (or its dynamo subpackage) is not installed; skip compile-cache tuning.
    _dynamo = None

if _dynamo is not None:
    # Train, prefill, decode, scoring and multiple shape buckets all produce
    # distinct compiled artifacts; raise the cache limits so recompilation does
    # not evict graphs that will be replayed across RL steps. ``max`` keeps any
    # larger limit that was already configured.
    _dynamo.config.cache_size_limit = max(_dynamo.config.cache_size_limit, 64)
    try:
        _dynamo.config.accumulated_cache_size_limit = max(_dynamo.config.accumulated_cache_size_limit, 256)
    except AttributeError:
        # Some torch versions do not expose this knob; it is optional tuning.
        pass

from areno.engine.log import configure_default_logging

# Install areno's default logging configuration as an import-time side effect.
configure_default_logging()
38
+
39
+
40
def __getattr__(name: str):
    """Resolve public engine names on first access (PEP 562 module hook).

    Importing ``areno`` stays cheap because the kernel-heavy engine modules are
    only imported when one of their symbols is actually requested:

    * ``ArenoEngine`` comes from ``areno.engine``;
    * the configuration dataclasses come from ``areno.engine.config``;
    * the rollout/sampling/train containers come from ``areno.engine.data``.

    Raises:
        AttributeError: for any other name, as the module protocol requires.
    """
    if name in {"EngineConfig", "ModelConfig", "OptimizerConfig", "RuntimeConfig"}:
        from areno.engine import config as _provider
    elif name in {"RolloutOutput", "SamplingParams", "TrainStats"}:
        from areno.engine import data as _provider
    elif name == "ArenoEngine":
        from areno.engine import ArenoEngine as _engine_cls

        return _engine_cls
    else:
        raise AttributeError(f"module {__name__!r} has no attribute {name!r}")
    return getattr(_provider, name)
56
+
57
# Public surface of the package. Every name here is served lazily by the
# module-level ``__getattr__``, so ``from areno import *`` will trigger the
# corresponding engine imports.
__all__ = [
    "ArenoEngine",
    "EngineConfig",
    "ModelConfig",
    "OptimizerConfig",
    "RolloutOutput",
    "RuntimeConfig",
    "SamplingParams",
    "TrainStats",
]
@@ -0,0 +1,42 @@
1
"""Public entry point for the ARENO acceleration shims.

Re-exports thin Python wrappers around the ``areno.accel._areno_accel`` C++/CUDA
extension. Each submodule defines a ``torch.autograd.Function`` (where backward
is needed) and a ``@torch._dynamo.disable``-decorated user-facing function that
performs argument validation, dispatches to the fused CUDA kernel and exposes a
PyTorch-friendly signature. The kernels themselves live in ``csrc/`` and are
built by ``setup.py``. The shims obtain the compiled module through
``areno.accel._extension.extension()``, which imports it on first use rather
than at package import; if the extension was not built, that first kernel call
raises ``ModuleNotFoundError``.
"""
from areno.accel.activations import areno_gelu_tanh_and_mul, areno_sigmoid, areno_silu, areno_silu_and_mul, areno_softplus
from areno.accel.conv import areno_depthwise_causal_conv1d_silu, areno_depthwise_causal_conv1d_silu_decode, areno_packed_depthwise_causal_conv1d_silu
from areno.accel.embedding import areno_vocab_embedding
from areno.accel.linear import areno_grouped_linear, areno_linear
from areno.accel.moe import areno_moe_permute, areno_moe_topk_permute, areno_moe_unpermute
from areno.accel.normalization import areno_optional_scale_rmsnorm, areno_rmsnorm, areno_rmsnorm_silu_gate
from areno.accel.router import areno_grouped_topk_router
from areno.accel.routing import areno_moe_align
from areno.accel.topk import areno_topk_softmax

# Public API of ``areno.accel``; keep in sync with the re-exports above.
__all__ = [
    "areno_depthwise_causal_conv1d_silu",
    "areno_depthwise_causal_conv1d_silu_decode",
    "areno_packed_depthwise_causal_conv1d_silu",
    "areno_gelu_tanh_and_mul",
    "areno_grouped_topk_router",
    "areno_grouped_linear",
    "areno_linear",
    "areno_moe_align",
    "areno_moe_permute",
    "areno_moe_topk_permute",
    "areno_moe_unpermute",
    "areno_optional_scale_rmsnorm",
    "areno_rmsnorm",
    "areno_rmsnorm_silu_gate",
    "areno_sigmoid",
    "areno_silu",
    "areno_silu_and_mul",
    "areno_softplus",
    "areno_topk_softmax",
    "areno_vocab_embedding",
]
@@ -0,0 +1,24 @@
1
+ """Lazy loader for the compiled ``areno.accel._areno_accel`` C++/CUDA extension.
2
+
3
+ The extension module is imported on first use rather than at package import
4
+ time so that ``import areno.accel`` succeeds in environments where only the
5
+ Python shims are needed (e.g. for type checking). Each shim calls
6
+ ``extension()`` to obtain the compiled module and dispatch into the fused
7
+ kernel. There is no pure-Python fallback: if the extension was not built the
8
+ ``importlib.import_module`` call below raises ``ModuleNotFoundError``.
9
+ """
10
+ from __future__ import annotations
11
+
12
+ import importlib
13
+ from types import ModuleType
14
+
15
+ # Cached reference to the compiled extension; populated on first call.
16
+ _EXT: ModuleType | None = None
17
+
18
+
19
+ def extension() -> ModuleType:
20
+ """Return the compiled C++/CUDA extension module, importing it lazily."""
21
+ global _EXT
22
+ if _EXT is None:
23
+ _EXT = importlib.import_module("areno.accel._areno_accel")
24
+ return _EXT
@@ -0,0 +1,172 @@
1
+ """Fused activation kernels (SiLU, GELU-tanh, sigmoid, softplus, gated variants).
2
+
3
+ Each wrapper dispatches to the ARENO CUDA kernel via the compiled extension and
4
+ preserves autograd by routing through ``torch.autograd.Function`` whenever the
5
+ input requires gradients. The ``*_and_mul`` variants implement the common
6
+ "gated MLP" pattern where the last input dimension is split in two and the
7
+ first half is activated then element-wise multiplied with the second half,
8
+ producing an output with half the last-dimension size.
9
+ """
10
+ import torch
11
+
12
+ from areno.accel._extension import extension as _extension
13
+
14
+
15
+ def _activation_out(x: torch.Tensor, out: torch.Tensor | None) -> torch.Tensor:
16
+ """Allocate or validate the half-width output tensor for ``*_and_mul`` ops."""
17
+ if x.shape[-1] % 2 != 0:
18
+ raise ValueError(f"activation input last dimension must be even, got {x.shape[-1]}")
19
+ hidden = x.shape[-1] // 2
20
+ expected_shape = (*x.shape[:-1], hidden)
21
+ if out is None:
22
+ return torch.empty(expected_shape, device=x.device, dtype=x.dtype)
23
+ if tuple(out.shape) != expected_shape:
24
+ raise ValueError(f"activation output shape must be {expected_shape}, got {tuple(out.shape)}")
25
+ return out
26
+
27
+
28
def _can_use_cuda_extension(x: torch.Tensor) -> bool:
    """Return ``True`` for CUDA tensors and raise otherwise.

    The activation kernels have no CPU implementation, so a non-CUDA input is
    rejected here with ``RuntimeError`` rather than failing inside the extension.
    """
    if x.is_cuda:
        return True
    raise RuntimeError("ARENO activation kernels require CUDA tensors")
33
+
34
+
35
class _SiluMul(torch.autograd.Function):
    """Autograd glue for fused SiLU(x[..., :H]) * x[..., H:].

    Forward writes into a freshly allocated ``(..., H)`` output. Backward fills
    a gradient for the full ``(..., 2H)`` input with a single kernel call.
    """

    @staticmethod
    def forward(ctx, x: torch.Tensor) -> torch.Tensor:
        out = _activation_out(x, None)
        _extension().areno_silu_and_mul(out, x.contiguous())
        ctx.save_for_backward(x)
        return out

    @staticmethod
    def backward(ctx, grad_output: torch.Tensor) -> tuple[torch.Tensor]:
        (x,) = ctx.saved_tensors
        # ``empty_like`` defaults to ``preserve_format`` and would copy the
        # strides of a dense but non-contiguous ``x`` (e.g. a transposed view).
        # The kernel is handed contiguous inputs and presumably writes
        # ``grad_input`` in the same row-major order, so allocate the gradient
        # buffer contiguously to keep the two layouts in agreement.
        grad_input = torch.empty_like(x, memory_format=torch.contiguous_format)
        _extension().areno_d_silu_and_mul(grad_input, grad_output.contiguous(), x.contiguous())
        return (grad_input,)
51
+
52
+
53
class _Silu(torch.autograd.Function):
    """Autograd wrapper pairing the SiLU kernel with its derivative kernel.

    The raw input is kept for backward because ``areno_d_silu`` takes ``x``,
    not the activated value.
    """

    @staticmethod
    def forward(ctx, x: torch.Tensor) -> torch.Tensor:
        activated = _extension().areno_silu(x.contiguous())
        ctx.save_for_backward(x)
        return activated

    @staticmethod
    def backward(ctx, grad_output: torch.Tensor) -> tuple[torch.Tensor]:
        saved_input = ctx.saved_tensors[0]
        grad_input = _extension().areno_d_silu(grad_output.contiguous(), saved_input.contiguous())
        return (grad_input,)
66
+
67
+
68
class _Sigmoid(torch.autograd.Function):
    """Autograd wrapper pairing the sigmoid kernel with its derivative kernel.

    The *output* ``s = sigmoid(x)`` is what gets saved: ``areno_d_sigmoid``
    is fed ``s`` rather than ``x``.
    """

    @staticmethod
    def forward(ctx, x: torch.Tensor) -> torch.Tensor:
        probs = _extension().areno_sigmoid(x.contiguous())
        ctx.save_for_backward(probs)
        return probs

    @staticmethod
    def backward(ctx, grad_output: torch.Tensor) -> tuple[torch.Tensor]:
        probs = ctx.saved_tensors[0]
        grad_input = _extension().areno_d_sigmoid(grad_output.contiguous(), probs.contiguous())
        return (grad_input,)
81
+
82
+
83
class _Softplus(torch.autograd.Function):
    """Autograd wrapper pairing the softplus kernel with its derivative kernel.

    The input ``x`` is saved because ``areno_d_softplus`` consumes it directly.
    """

    @staticmethod
    def forward(ctx, x: torch.Tensor) -> torch.Tensor:
        smoothed = _extension().areno_softplus(x.contiguous())
        ctx.save_for_backward(x)
        return smoothed

    @staticmethod
    def backward(ctx, grad_output: torch.Tensor) -> tuple[torch.Tensor]:
        saved_input = ctx.saved_tensors[0]
        grad_input = _extension().areno_d_softplus(grad_output.contiguous(), saved_input.contiguous())
        return (grad_input,)
96
+
97
+
98
class _GeluTanhMul(torch.autograd.Function):
    """Autograd glue for fused tanh-approx GELU(x[..., :H]) * x[..., H:].

    Forward writes into a freshly allocated ``(..., H)`` output. Backward fills
    a gradient for the full ``(..., 2H)`` input with a single kernel call.
    """

    @staticmethod
    def forward(ctx, x: torch.Tensor) -> torch.Tensor:
        out = _activation_out(x, None)
        _extension().areno_gelu_tanh_and_mul(out, x.contiguous())
        ctx.save_for_backward(x)
        return out

    @staticmethod
    def backward(ctx, grad_output: torch.Tensor) -> tuple[torch.Tensor]:
        (x,) = ctx.saved_tensors
        # ``empty_like`` defaults to ``preserve_format`` and would copy the
        # strides of a dense but non-contiguous ``x`` (e.g. a transposed view).
        # The kernel is handed contiguous inputs and presumably writes
        # ``grad_input`` in the same row-major order, so allocate the gradient
        # buffer contiguously to keep the two layouts in agreement.
        grad_input = torch.empty_like(x, memory_format=torch.contiguous_format)
        _extension().areno_d_gelu_tanh_and_mul(grad_input, grad_output.contiguous(), x.contiguous())
        return (grad_input,)
114
+
115
+
116
@torch._dynamo.disable
def areno_silu_and_mul(x: torch.Tensor, out: torch.Tensor | None = None) -> torch.Tensor:
    """Compute ``silu(x[..., :H]) * x[..., H:]`` with the fused ARENO CUDA kernel.

    Args:
        x: CUDA tensor whose last dimension ``2H`` is even.
        out: optional preallocated ``(..., H)`` destination. Passing it always
            selects the plain kernel path: the kernel writes into ``out``,
            which is then returned, and no autograd graph is recorded.

    Returns:
        A ``(..., H)`` tensor. When ``out`` is ``None``, grad mode is on and
        ``x`` requires grad, the result is produced through ``_SiluMul`` so it
        participates in autograd.

    Raises:
        ValueError: if the last dimension is odd or ``out`` is mis-shaped.
        RuntimeError: if ``x`` is not a CUDA tensor.
    """
    destination = _activation_out(x, out)
    _can_use_cuda_extension(x)
    wants_grad = out is None and torch.is_grad_enabled() and x.requires_grad
    if not wants_grad:
        _extension().areno_silu_and_mul(destination, x.contiguous())
        return destination
    return _SiluMul.apply(x)
131
+
132
+
133
@torch._dynamo.disable
def areno_silu(x: torch.Tensor) -> torch.Tensor:
    """Element-wise ``silu(x)`` via the ARENO CUDA kernel.

    The autograd ``Function`` is used only when grad mode is on and ``x``
    requires grad; otherwise the raw kernel result is returned.

    Raises:
        RuntimeError: if ``x`` is not a CUDA tensor.
    """
    _can_use_cuda_extension(x)
    track_grad = torch.is_grad_enabled() and x.requires_grad
    if not track_grad:
        return _extension().areno_silu(x.contiguous())
    return _Silu.apply(x)
140
+
141
+
142
@torch._dynamo.disable
def areno_sigmoid(x: torch.Tensor) -> torch.Tensor:
    """Element-wise ``sigmoid(x)`` via the ARENO CUDA kernel.

    The autograd ``Function`` is used only when grad mode is on and ``x``
    requires grad; otherwise the raw kernel result is returned.

    Raises:
        RuntimeError: if ``x`` is not a CUDA tensor.
    """
    _can_use_cuda_extension(x)
    track_grad = torch.is_grad_enabled() and x.requires_grad
    if not track_grad:
        return _extension().areno_sigmoid(x.contiguous())
    return _Sigmoid.apply(x)
149
+
150
+
151
@torch._dynamo.disable
def areno_softplus(x: torch.Tensor) -> torch.Tensor:
    """Element-wise softplus(beta=1, threshold=20) via the ARENO CUDA kernel.

    The autograd ``Function`` is used only when grad mode is on and ``x``
    requires grad; otherwise the raw kernel result is returned.

    Raises:
        RuntimeError: if ``x`` is not a CUDA tensor.
    """
    _can_use_cuda_extension(x)
    track_grad = torch.is_grad_enabled() and x.requires_grad
    if not track_grad:
        return _extension().areno_softplus(x.contiguous())
    return _Softplus.apply(x)
158
+
159
+
160
@torch._dynamo.disable
def areno_gelu_tanh_and_mul(x: torch.Tensor, out: torch.Tensor | None = None) -> torch.Tensor:
    """Compute ``gelu_tanh(x[..., :H]) * x[..., H:]`` (GeGLU, as in Gemma-family MLPs).

    Args:
        x: CUDA tensor whose last dimension ``2H`` is even.
        out: optional preallocated ``(..., H)`` destination; passing it always
            selects the plain, non-autograd kernel path.

    Returns:
        A ``(..., H)`` tensor, routed through ``_GeluTanhMul`` when ``out`` is
        ``None``, grad mode is on and ``x`` requires grad.

    Raises:
        ValueError: if the last dimension is odd or ``out`` is mis-shaped.
        RuntimeError: if ``x`` is not a CUDA tensor.
    """
    destination = _activation_out(x, out)
    _can_use_cuda_extension(x)
    wants_grad = out is None and torch.is_grad_enabled() and x.requires_grad
    if not wants_grad:
        _extension().areno_gelu_tanh_and_mul(destination, x.contiguous())
        return destination
    return _GeluTanhMul.apply(x)
@@ -0,0 +1,141 @@
1
+ """Fused depthwise causal Conv1d + SiLU kernels for short-range token mixing.
2
+
3
+ Used by the linear-attention / Mamba-style layers in areno to mix recent
4
+ tokens within each channel. The kernel performs the convolution with explicit
5
+ causal padding (no future leak) and applies SiLU in the same pass; weights are
6
+ always coerced to ``float32`` so the CUDA path matches the reference layout.
7
+ The ``*_decode`` entry point handles the single-token autoregressive case
8
+ using a separately maintained history cache.
9
+ """
10
+ import torch
11
+
12
+ from areno.accel._extension import extension as _extension
13
+
14
+
15
def _check_weight_shape(weight: torch.Tensor) -> None:
    """Raise ``ValueError`` unless ``weight`` has the depthwise layout ``(channels, 1, kernel)``."""
    is_depthwise = weight.dim() == 3 and weight.shape[1] == 1
    if not is_depthwise:
        raise ValueError(f"weight must have shape (channels, 1, kernel), got {tuple(weight.shape)}")
19
+
20
+
21
def _kernel_weight(weight: torch.Tensor) -> torch.Tensor:
    """Validate the depthwise layout and hand back ``weight`` in float32.

    A tensor that is already float32 is returned as the same object; any other
    dtype goes through ``Tensor.float()``, which is a differentiable cast.
    """
    _check_weight_shape(weight)
    if weight.dtype != torch.float32:
        return weight.float()
    return weight
25
+
26
+
27
class _DepthwiseCausalConv1dSilu(torch.autograd.Function):
    """Autograd bridge for the fused depthwise causal conv1d + SiLU kernel.

    The forward kernel returns both the activated output and the
    pre-activation tensor. The latter is saved with the inputs so the backward
    kernel can form the SiLU derivative cheaply.
    """

    @staticmethod
    def forward(ctx, x: torch.Tensor, weight: torch.Tensor) -> torch.Tensor:
        kernel = _extension().areno_depthwise_causal_conv1d_silu_forward
        activated, pre_activation = kernel(x.contiguous(), weight.contiguous())
        ctx.save_for_backward(x, weight, pre_activation)
        return activated

    @staticmethod
    def backward(ctx, grad_output: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]:
        x, weight, pre_activation = ctx.saved_tensors
        kernel = _extension().areno_depthwise_causal_conv1d_silu_backward
        grad_x, grad_w = kernel(
            grad_output.contiguous(),
            x.contiguous(),
            weight.contiguous(),
            pre_activation,
        )
        return grad_x, grad_w
50
+
51
+
52
class _PackedDepthwiseCausalConv1dSilu(torch.autograd.Function):
    """Autograd bridge for depthwise causal conv1d + SiLU over packed varlen input.

    ``cu_seqlens`` (presumably cumulative sequence boundaries along the packed
    token axis; confirm against callers) is normalized to a contiguous int32
    tensor on ``x``'s device. That normalized copy is both fed to the kernel
    and saved for backward. It is not differentiable, so backward returns
    ``None`` in its slot.
    """

    @staticmethod
    def forward(ctx, x: torch.Tensor, weight: torch.Tensor, cu_seqlens: torch.Tensor) -> torch.Tensor:
        boundaries = cu_seqlens.to(device=x.device, dtype=torch.int32).contiguous()
        kernel = _extension().areno_packed_depthwise_causal_conv1d_silu_forward
        activated, pre_activation = kernel(x.contiguous(), weight.contiguous(), boundaries)
        ctx.save_for_backward(x, weight, boundaries, pre_activation)
        return activated

    @staticmethod
    def backward(ctx, grad_output: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor, None]:
        x, weight, boundaries, pre_activation = ctx.saved_tensors
        kernel = _extension().areno_packed_depthwise_causal_conv1d_silu_backward
        grad_x, grad_w = kernel(
            grad_output.contiguous(),
            x.contiguous(),
            weight.contiguous(),
            boundaries,
            pre_activation,
        )
        return grad_x, grad_w, None
77
+
78
+
79
+ @torch._dynamo.disable
80
+ def areno_depthwise_causal_conv1d_silu(x: torch.Tensor, weight: torch.Tensor) -> torch.Tensor:
81
+ """Apply depthwise causal conv1d followed by SiLU for (batch, seqlen, channels) tensors.
82
+
83
+ The kernel handles arbitrary ``seqlen`` and uses left-padding to keep the
84
+ convolution causal. Falls through to the inference path when autograd is
85
+ disabled to avoid stashing the pre-activation tensor.
86
+ """
87
+ if not x.is_cuda or not weight.is_cuda:
88
+ raise RuntimeError("areno_depthwise_causal_conv1d_silu requires CUDA input and weight")
89
+ if x.dim() != 3:
90
+ raise ValueError(f"input must have shape (batch, seqlen, channels), got {tuple(x.shape)}")
91
+ weight = _kernel_weight(weight)
92
+ if x.shape[-1] != weight.shape[0]:
93
+ raise ValueError(f"channel mismatch: input={x.shape[-1]} weight={weight.shape[0]}")
94
+ if torch.is_grad_enabled() and (x.requires_grad or weight.requires_grad):
95
+ return _DepthwiseCausalConv1dSilu.apply(x, weight)
96
+ out, _ = _extension().areno_depthwise_causal_conv1d_silu_forward(x.contiguous(), weight.contiguous())
97
+ return out
98
+
99
+
100
+ @torch._dynamo.disable
101
+ def areno_packed_depthwise_causal_conv1d_silu(x: torch.Tensor, weight: torch.Tensor, cu_seqlens: torch.Tensor) -> torch.Tensor:
102
+ """Apply depthwise causal conv1d followed by SiLU to packed (1, tokens, channels) tensors."""
103
+ if not x.is_cuda or not weight.is_cuda or not cu_seqlens.is_cuda:
104
+ raise RuntimeError("areno_packed_depthwise_causal_conv1d_silu requires CUDA tensors")
105
+ if x.dim() != 3 or x.shape[0] != 1:
106
+ raise ValueError(f"input must have shape (1, tokens, channels), got {tuple(x.shape)}")
107
+ weight = _kernel_weight(weight)
108
+ if x.shape[-1] != weight.shape[0]:
109
+ raise ValueError(f"channel mismatch: input={x.shape[-1]} weight={weight.shape[0]}")
110
+ if torch.is_grad_enabled() and (x.requires_grad or weight.requires_grad):
111
+ return _PackedDepthwiseCausalConv1dSilu.apply(x, weight, cu_seqlens)
112
+ out, _ = _extension().areno_packed_depthwise_causal_conv1d_silu_forward(
113
+ x.contiguous(),
114
+ weight.contiguous(),
115
+ cu_seqlens.to(device=x.device, dtype=torch.int32).contiguous(),
116
+ )
117
+ return out
118
+
119
+
120
+ @torch._dynamo.disable
121
+ def areno_depthwise_causal_conv1d_silu_decode(current: torch.Tensor, history: torch.Tensor, weight: torch.Tensor) -> torch.Tensor:
122
+ """Apply one-token depthwise causal conv1d followed by SiLU for decode.
123
+
124
+ ``current`` is the new token activations with shape ``(rows, channels)``
125
+ and ``history`` provides the previous ``kernel_size - 1`` tokens shaped
126
+ ``(rows, channels, kernel - 1)``. Returns the activated single-step output;
127
+ callers are responsible for shifting ``history``.
128
+ """
129
+ if not current.is_cuda or not history.is_cuda or not weight.is_cuda:
130
+ raise RuntimeError("areno_depthwise_causal_conv1d_silu_decode requires CUDA tensors")
131
+ if current.dim() != 2:
132
+ raise ValueError(f"current must have shape (rows, channels), got {tuple(current.shape)}")
133
+ if history.dim() != 3:
134
+ raise ValueError(f"history must have shape (rows, channels, kernel - 1), got {tuple(history.shape)}")
135
+ weight = _kernel_weight(weight)
136
+ out, _ = _extension().areno_depthwise_causal_conv1d_silu_decode(
137
+ current.contiguous(),
138
+ history.contiguous(),
139
+ weight.contiguous(),
140
+ )
141
+ return out
@@ -0,0 +1,49 @@
1
+ """Tensor-parallel vocab embedding gather backed by an ARENO CUDA kernel.
2
+
3
+ In TP-sharded embedding tables each rank only holds a contiguous slice of the
4
+ vocabulary ``[vocab_start, vocab_end)``. The kernel gathers rows for ids that
5
+ fall inside that range and writes zeros for out-of-range ids so the subsequent
6
+ all-reduce reconstructs the full embedding.
7
+ """
8
+ import torch
9
+
10
+ from areno.accel._extension import extension as _extension
11
+
12
+
13
class _VocabEmbedding(torch.autograd.Function):
    """Autograd bridge for the rank-local (vocab-parallel) embedding gather.

    The integer range bounds are stored on ``ctx`` as plain ints. Backward
    returns a gradient only for ``weight``; the ids and the bounds get ``None``.
    """

    @staticmethod
    def forward(ctx, input_ids: torch.Tensor, weight: torch.Tensor, vocab_start: int, vocab_end: int) -> torch.Tensor:
        gather = _extension().areno_vocab_embedding_forward
        start, end = int(vocab_start), int(vocab_end)
        embedded = gather(input_ids.contiguous(), weight.contiguous(), start, end)
        ctx.save_for_backward(input_ids, weight)
        ctx.vocab_start = start
        ctx.vocab_end = end
        return embedded

    @staticmethod
    def backward(ctx, grad_output: torch.Tensor) -> tuple[None, torch.Tensor, None, None]:
        input_ids, weight = ctx.saved_tensors
        scatter = _extension().areno_vocab_embedding_backward
        # ``weight`` is passed as saved (not made contiguous); presumably the
        # kernel only reads its shape/dtype to size the gradient — confirm in csrc.
        grad_weight = scatter(
            grad_output.contiguous(),
            input_ids.contiguous(),
            weight,
            ctx.vocab_start,
            ctx.vocab_end,
        )
        return None, grad_weight, None, None
35
+
36
+
37
+ @torch._dynamo.disable
38
+ def areno_vocab_embedding(input_ids: torch.Tensor, weight: torch.Tensor, vocab_start: int, vocab_end: int) -> torch.Tensor:
39
+ """Gather rank-local vocab embeddings and zero out non-local token ids.
40
+
41
+ ``input_ids`` must be int64 on CUDA. ``weight`` is the local shard with
42
+ shape ``(vocab_end - vocab_start, hidden)``. Returns embeddings with shape
43
+ ``(*input_ids.shape, hidden)`` ready for tensor-parallel reduction.
44
+ """
45
+ if not input_ids.is_cuda or not weight.is_cuda:
46
+ raise RuntimeError("areno_vocab_embedding requires CUDA input_ids and weight")
47
+ if input_ids.dtype != torch.long:
48
+ raise TypeError("areno_vocab_embedding input_ids must be int64")
49
+ return _VocabEmbedding.apply(input_ids, weight, int(vocab_start), int(vocab_end))
@@ -0,0 +1 @@
1
+ """Triton kernels that belong to the areno.accel acceleration surface."""