areno 0.0.0.dev0__tar.gz
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- areno-0.0.0.dev0/PKG-INFO +19 -0
- areno-0.0.0.dev0/areno/__init__.py +66 -0
- areno-0.0.0.dev0/areno/accel/__init__.py +42 -0
- areno-0.0.0.dev0/areno/accel/_extension.py +24 -0
- areno-0.0.0.dev0/areno/accel/activations.py +172 -0
- areno-0.0.0.dev0/areno/accel/conv.py +141 -0
- areno-0.0.0.dev0/areno/accel/embedding.py +49 -0
- areno-0.0.0.dev0/areno/accel/kernels/__init__.py +1 -0
- areno-0.0.0.dev0/areno/accel/kernels/fused_moe.py +426 -0
- areno-0.0.0.dev0/areno/accel/kernels/group_rmsnorm.py +155 -0
- areno-0.0.0.dev0/areno/accel/kernels/seg_la.py +1168 -0
- areno-0.0.0.dev0/areno/accel/linear.py +116 -0
- areno-0.0.0.dev0/areno/accel/moe.py +155 -0
- areno-0.0.0.dev0/areno/accel/normalization.py +137 -0
- areno-0.0.0.dev0/areno/accel/ops.py +79 -0
- areno-0.0.0.dev0/areno/accel/router.py +42 -0
- areno-0.0.0.dev0/areno/accel/routing.py +47 -0
- areno-0.0.0.dev0/areno/accel/topk.py +36 -0
- areno-0.0.0.dev0/areno/api/__init__.py +46 -0
- areno-0.0.0.dev0/areno/api/advantages.py +74 -0
- areno-0.0.0.dev0/areno/api/algorithms.py +207 -0
- areno-0.0.0.dev0/areno/api/backend/__init__.py +7 -0
- areno-0.0.0.dev0/areno/api/backend/areno/__init__.py +9 -0
- areno-0.0.0.dev0/areno/api/backend/areno/backend.py +406 -0
- areno-0.0.0.dev0/areno/api/backend/base.py +128 -0
- areno-0.0.0.dev0/areno/api/config.py +57 -0
- areno-0.0.0.dev0/areno/api/context.py +31 -0
- areno-0.0.0.dev0/areno/api/data.py +50 -0
- areno-0.0.0.dev0/areno/api/data_utils.py +69 -0
- areno-0.0.0.dev0/areno/api/defaults.py +3 -0
- areno-0.0.0.dev0/areno/api/loss_fns/__init__.py +20 -0
- areno-0.0.0.dev0/areno/api/loss_fns/dpo.py +77 -0
- areno-0.0.0.dev0/areno/api/loss_fns/grpo.py +47 -0
- areno-0.0.0.dev0/areno/api/loss_fns/gspo.py +65 -0
- areno-0.0.0.dev0/areno/api/loss_fns/layout.py +161 -0
- areno-0.0.0.dev0/areno/api/loss_fns/ppo.py +100 -0
- areno-0.0.0.dev0/areno/api/loss_fns/sft.py +37 -0
- areno-0.0.0.dev0/areno/api/metrics.py +174 -0
- areno-0.0.0.dev0/areno/api/models.py +76 -0
- areno-0.0.0.dev0/areno/api/rewards.py +48 -0
- areno-0.0.0.dev0/areno/api/roles.py +33 -0
- areno-0.0.0.dev0/areno/api/tokenizer.py +74 -0
- areno-0.0.0.dev0/areno/api/trainer.py +258 -0
- areno-0.0.0.dev0/areno/api/trainer_config.py +179 -0
- areno-0.0.0.dev0/areno/api/trainer_factory.py +16 -0
- areno-0.0.0.dev0/areno/api/trainers/__init__.py +16 -0
- areno-0.0.0.dev0/areno/api/trainers/dpo.py +199 -0
- areno-0.0.0.dev0/areno/api/trainers/policy_only.py +213 -0
- areno-0.0.0.dev0/areno/api/trainers/ppo.py +355 -0
- areno-0.0.0.dev0/areno/api/trainers/sft.py +168 -0
- areno-0.0.0.dev0/areno/cli/__init__.py +2 -0
- areno-0.0.0.dev0/areno/cli/main.py +50 -0
- areno-0.0.0.dev0/areno/cli/model_refs.py +46 -0
- areno-0.0.0.dev0/areno/cli/serve.py +759 -0
- areno-0.0.0.dev0/areno/cli/train.py +506 -0
- areno-0.0.0.dev0/areno/engine/__init__.py +17 -0
- areno-0.0.0.dev0/areno/engine/api.py +582 -0
- areno-0.0.0.dev0/areno/engine/checkpoints/__init__.py +9 -0
- areno-0.0.0.dev0/areno/engine/checkpoints/common.py +1158 -0
- areno-0.0.0.dev0/areno/engine/checkpoints/io.py +664 -0
- areno-0.0.0.dev0/areno/engine/config.py +208 -0
- areno-0.0.0.dev0/areno/engine/data/__init__.py +12 -0
- areno-0.0.0.dev0/areno/engine/data/batch.py +83 -0
- areno-0.0.0.dev0/areno/engine/data/rollout_state.py +196 -0
- areno-0.0.0.dev0/areno/engine/data/sampling.py +250 -0
- areno-0.0.0.dev0/areno/engine/data/tokenizer.py +30 -0
- areno-0.0.0.dev0/areno/engine/inference.py +941 -0
- areno-0.0.0.dev0/areno/engine/layers/__init__.py +24 -0
- areno-0.0.0.dev0/areno/engine/layers/attention.py +134 -0
- areno-0.0.0.dev0/areno/engine/layers/attention_backend/__init__.py +25 -0
- areno-0.0.0.dev0/areno/engine/layers/attention_backend/common.py +107 -0
- areno-0.0.0.dev0/areno/engine/layers/attention_backend/infer.py +301 -0
- areno-0.0.0.dev0/areno/engine/layers/attention_backend/train.py +203 -0
- areno-0.0.0.dev0/areno/engine/layers/linear.py +242 -0
- areno-0.0.0.dev0/areno/engine/layers/mlp.py +52 -0
- areno-0.0.0.dev0/areno/engine/layers/norm.py +104 -0
- areno-0.0.0.dev0/areno/engine/layers/rotary.py +130 -0
- areno-0.0.0.dev0/areno/engine/layers/vocab.py +73 -0
- areno-0.0.0.dev0/areno/engine/log.py +36 -0
- areno-0.0.0.dev0/areno/engine/modeling.py +99 -0
- areno-0.0.0.dev0/areno/engine/optim/__init__.py +11 -0
- areno-0.0.0.dev0/areno/engine/optim/adamw_8bit.py +274 -0
- areno-0.0.0.dev0/areno/engine/optim/adamw_fp32_master.py +502 -0
- areno-0.0.0.dev0/areno/engine/parallel/__init__.py +12 -0
- areno-0.0.0.dev0/areno/engine/parallel/collectives.py +243 -0
- areno-0.0.0.dev0/areno/engine/parallel/context.py +142 -0
- areno-0.0.0.dev0/areno/engine/protocol.py +445 -0
- areno-0.0.0.dev0/areno/engine/roles.py +557 -0
- areno-0.0.0.dev0/areno/engine/runtime/__init__.py +1 -0
- areno-0.0.0.dev0/areno/engine/runtime/common.py +153 -0
- areno-0.0.0.dev0/areno/engine/runtime/decode_graph.py +172 -0
- areno-0.0.0.dev0/areno/engine/runtime/logprobs.py +204 -0
- areno-0.0.0.dev0/areno/engine/runtime/metadata.py +45 -0
- areno-0.0.0.dev0/areno/engine/runtime/recompute.py +51 -0
- areno-0.0.0.dev0/areno/engine/runtime/rollout.py +127 -0
- areno-0.0.0.dev0/areno/engine/runtime/train_step.py +292 -0
- areno-0.0.0.dev0/areno/engine/training.py +190 -0
- areno-0.0.0.dev0/areno/engine/worker.py +220 -0
- areno-0.0.0.dev0/areno/experimental/__init__.py +1 -0
- areno-0.0.0.dev0/areno/models/__init__.py +31 -0
- areno-0.0.0.dev0/areno/models/_shared/__init__.py +2 -0
- areno-0.0.0.dev0/areno/models/_shared/dynamo_wrappers.py +133 -0
- areno-0.0.0.dev0/areno/models/bailing/__init__.py +13 -0
- areno-0.0.0.dev0/areno/models/bailing/checkpoint.py +92 -0
- areno-0.0.0.dev0/areno/models/bailing/model.py +1304 -0
- areno-0.0.0.dev0/areno/models/base.py +70 -0
- areno-0.0.0.dev0/areno/models/gemma4/__init__.py +13 -0
- areno-0.0.0.dev0/areno/models/gemma4/checkpoint.py +295 -0
- areno-0.0.0.dev0/areno/models/gemma4/model.py +1005 -0
- areno-0.0.0.dev0/areno/models/llama/__init__.py +13 -0
- areno-0.0.0.dev0/areno/models/llama/checkpoint.py +68 -0
- areno-0.0.0.dev0/areno/models/llama/model.py +70 -0
- areno-0.0.0.dev0/areno/models/minicpmv46/__init__.py +11 -0
- areno-0.0.0.dev0/areno/models/minicpmv46/checkpoint.py +298 -0
- areno-0.0.0.dev0/areno/models/minicpmv46/model.py +686 -0
- areno-0.0.0.dev0/areno/models/qwen3/__init__.py +14 -0
- areno-0.0.0.dev0/areno/models/qwen3/checkpoint.py +131 -0
- areno-0.0.0.dev0/areno/models/qwen3/model.py +526 -0
- areno-0.0.0.dev0/areno/models/qwen3_5/__init__.py +5 -0
- areno-0.0.0.dev0/areno/models/qwen3_5/checkpoint.py +609 -0
- areno-0.0.0.dev0/areno/models/qwen3_5/model.py +891 -0
- areno-0.0.0.dev0/areno/models/registry.py +129 -0
- areno-0.0.0.dev0/areno.egg-info/PKG-INFO +19 -0
- areno-0.0.0.dev0/areno.egg-info/SOURCES.txt +144 -0
- areno-0.0.0.dev0/areno.egg-info/dependency_links.txt +1 -0
- areno-0.0.0.dev0/areno.egg-info/entry_points.txt +2 -0
- areno-0.0.0.dev0/areno.egg-info/requires.txt +14 -0
- areno-0.0.0.dev0/areno.egg-info/top_level.txt +1 -0
- areno-0.0.0.dev0/pyproject.toml +32 -0
- areno-0.0.0.dev0/setup.cfg +4 -0
- areno-0.0.0.dev0/setup.py +55 -0
- areno-0.0.0.dev0/tests/test_algorithms_cpu.py +75 -0
- areno-0.0.0.dev0/tests/test_cli_model_refs_cpu.py +59 -0
- areno-0.0.0.dev0/tests/test_config_data_cpu.py +123 -0
- areno-0.0.0.dev0/tests/test_logprobs_cpu.py +95 -0
- areno-0.0.0.dev0/tests/test_losses_rewards_cpu.py +131 -0
- areno-0.0.0.dev0/tests/test_metrics_cpu.py +66 -0
- areno-0.0.0.dev0/tests/test_more_losses_cpu.py +101 -0
- areno-0.0.0.dev0/tests/test_protocol_cpu.py +87 -0
- areno-0.0.0.dev0/tests/test_recompute_cpu.py +64 -0
- areno-0.0.0.dev0/tests/test_registry_cpu.py +174 -0
- areno-0.0.0.dev0/tests/test_runtime_utils_cpu.py +158 -0
- areno-0.0.0.dev0/tests/test_sampling_cpu.py +78 -0
- areno-0.0.0.dev0/tests/test_tokenizer_api_cpu.py +110 -0
- areno-0.0.0.dev0/tests/test_trainer_api_cpu.py +148 -0
- areno-0.0.0.dev0/tests/test_trainer_dataset_utils_cpu.py +114 -0
|
@@ -0,0 +1,19 @@
|
|
|
1
|
+
Metadata-Version: 2.4
|
|
2
|
+
Name: areno
|
|
3
|
+
Version: 0.0.0.dev0
|
|
4
|
+
Summary: Local LLM post-training stack with the areno engine and model plugins.
|
|
5
|
+
Requires-Python: >=3.10
|
|
6
|
+
Requires-Dist: torch>=2.6
|
|
7
|
+
Requires-Dist: flash-attn>=2.7
|
|
8
|
+
Requires-Dist: flash-linear-attention>=0.2
|
|
9
|
+
Requires-Dist: safetensors>=0.4
|
|
10
|
+
Requires-Dist: transformers>=4.56
|
|
11
|
+
Requires-Dist: huggingface-hub>=0.25
|
|
12
|
+
Requires-Dist: datasets>=3.3.0
|
|
13
|
+
Requires-Dist: numpy
|
|
14
|
+
Requires-Dist: tensorboard
|
|
15
|
+
Requires-Dist: einops
|
|
16
|
+
Requires-Dist: fastapi>=0.110
|
|
17
|
+
Requires-Dist: click>=8.1
|
|
18
|
+
Requires-Dist: tqdm>=4.66
|
|
19
|
+
Requires-Dist: uvicorn>=0.27
|
|
@@ -0,0 +1,66 @@
|
|
|
1
|
+
"""Top-level areno package.

Sets process-wide knobs that must be in place before any CUDA/Triton kernel or
torch.compile call runs:

* ``CUDA_DEVICE_MAX_CONNECTIONS=1`` (only when the variable is not already set
  in the environment), which limits each device to a single hardware work
  queue so kernels and NCCL collectives are launched in issue order;
* raised TorchDynamo cache limits, so the engine can compile many specialized
  graphs (per shape bucket, prefill vs decode, train vs infer) without
  thrashing. This step is skipped when ``torch._dynamo`` cannot be imported.

Importing the package also applies areno's default logging configuration.

Exposes the user-facing surface: configuration dataclasses, the rollout
output container, and the `ArenoEngine` coordinator. These names are resolved
lazily on first attribute access, so ``import areno`` does not pull in the
kernel-heavy engine modules.
"""
|
|
11
|
+
|
|
12
|
+
from __future__ import annotations
|
|
13
|
+
|
|
14
|
+
import os
|
|
15
|
+
|
|
16
|
+
# Restricting the device to one hardware work queue (connection) makes kernels
# from every stream launch in the order they were issued. This keeps NCCL
# collectives ordered with compute, which areno's TP/DP all-reduce + all-gather
# patterns assume. A value already present in the environment wins.
if "CUDA_DEVICE_MAX_CONNECTIONS" not in os.environ:
    os.environ["CUDA_DEVICE_MAX_CONNECTIONS"] = "1"
|
|
19
|
+
|
|
20
|
+
try:
    import torch._dynamo as _dynamo
except ModuleNotFoundError:
    # No torch / no dynamo available: there is no compiler cache to tune.
    _dynamo = None

if _dynamo is not None:
    # Train, prefill, decode, scoring and multiple shape buckets all produce
    # distinct compiled artifacts. Raise (never lower) the cache limits so
    # recompilation does not evict graphs that will be replayed across RL steps.
    _dynamo.config.cache_size_limit = max(64, _dynamo.config.cache_size_limit)
    # Tolerate torch builds whose dynamo config has no accumulated limit knob.
    if hasattr(_dynamo.config, "accumulated_cache_size_limit"):
        _dynamo.config.accumulated_cache_size_limit = max(256, _dynamo.config.accumulated_cache_size_limit)
|
|
34
|
+
|
|
35
|
+
from areno.engine.log import configure_default_logging

# Importing ``areno`` applies the package's default logging configuration as a
# side effect; it runs after the CUDA / TorchDynamo knobs above have been set.
configure_default_logging()
|
|
38
|
+
|
|
39
|
+
|
|
40
|
+
def __getattr__(name: str):
    """Resolve the public engine symbols on first access (PEP 562 module hook).

    Keeps ``import areno`` cheap: the engine class, its configuration
    dataclasses and its data containers are imported only when one of their
    names is actually requested.

    Raises:
        AttributeError: for any name this package does not expose.
    """
    if name == "ArenoEngine":
        from areno.engine import ArenoEngine as engine_cls

        return engine_cls

    if name in ("EngineConfig", "ModelConfig", "OptimizerConfig", "RuntimeConfig"):
        from areno.engine import config as provider
    elif name in ("RolloutOutput", "SamplingParams", "TrainStats"):
        from areno.engine import data as provider
    else:
        raise AttributeError(f"module {__name__!r} has no attribute {name!r}")
    return getattr(provider, name)
|
|
56
|
+
|
|
57
|
+
# Public surface of ``import areno``; every name is materialised lazily by the
# module-level ``__getattr__`` hook.
__all__ = [
    "ArenoEngine",
    "EngineConfig", "ModelConfig", "OptimizerConfig",
    "RolloutOutput",
    "RuntimeConfig",
    "SamplingParams",
    "TrainStats",
]
|
|
@@ -0,0 +1,42 @@
|
|
|
1
|
+
"""Public entry point for the ARENO acceleration shims.

Re-exports thin Python wrappers around the ``areno.accel._areno_accel`` C++/CUDA
extension. Each submodule defines a ``torch.autograd.Function`` (where backward
is needed) and a ``@torch._dynamo.disable``-decorated user-facing function that
performs argument validation, dispatches to the fused CUDA kernel and exposes a
PyTorch-friendly signature. The kernels themselves live in ``csrc/`` and are
built by ``setup.py``.

The extension is loaded lazily through ``areno.accel._extension.extension()``,
so importing this package does not itself require the compiled module. If the
extension was not built, the first call into a shim raises
``ModuleNotFoundError`` when that shim tries to import it.
"""
|
|
11
|
+
from areno.accel.activations import areno_gelu_tanh_and_mul, areno_sigmoid, areno_silu, areno_silu_and_mul, areno_softplus
|
|
12
|
+
from areno.accel.conv import areno_depthwise_causal_conv1d_silu, areno_depthwise_causal_conv1d_silu_decode, areno_packed_depthwise_causal_conv1d_silu
|
|
13
|
+
from areno.accel.embedding import areno_vocab_embedding
|
|
14
|
+
from areno.accel.linear import areno_grouped_linear, areno_linear
|
|
15
|
+
from areno.accel.moe import areno_moe_permute, areno_moe_topk_permute, areno_moe_unpermute
|
|
16
|
+
from areno.accel.normalization import areno_optional_scale_rmsnorm, areno_rmsnorm, areno_rmsnorm_silu_gate
|
|
17
|
+
from areno.accel.router import areno_grouped_topk_router
|
|
18
|
+
from areno.accel.routing import areno_moe_align
|
|
19
|
+
from areno.accel.topk import areno_topk_softmax
|
|
20
|
+
|
|
21
|
+
# Every shim re-exported by the imports above, kept in its original order.
__all__ = [
    "areno_depthwise_causal_conv1d_silu", "areno_depthwise_causal_conv1d_silu_decode",
    "areno_packed_depthwise_causal_conv1d_silu",
    "areno_gelu_tanh_and_mul",
    "areno_grouped_topk_router", "areno_grouped_linear", "areno_linear",
    "areno_moe_align", "areno_moe_permute", "areno_moe_topk_permute", "areno_moe_unpermute",
    "areno_optional_scale_rmsnorm", "areno_rmsnorm", "areno_rmsnorm_silu_gate",
    "areno_sigmoid", "areno_silu", "areno_silu_and_mul", "areno_softplus",
    "areno_topk_softmax",
    "areno_vocab_embedding",
]
|
|
@@ -0,0 +1,24 @@
|
|
|
1
|
+
"""Lazy loader for the compiled ``areno.accel._areno_accel`` C++/CUDA extension.

The extension module is imported on first use rather than at package import
time, so that ``import areno.accel`` succeeds in environments where only the
Python shims are needed (e.g. for type checking). Each shim calls
``extension()`` to obtain the compiled module and dispatch into the fused
kernel.

There is no pure-Python fallback: if the extension was not built, the
``importlib.import_module`` call below raises ``ModuleNotFoundError``. A failed
import is not cached, so every later ``extension()`` call retries it.
"""
|
|
10
|
+
from __future__ import annotations
|
|
11
|
+
|
|
12
|
+
import importlib
|
|
13
|
+
from types import ModuleType
|
|
14
|
+
|
|
15
|
+
# Cached reference to the compiled extension; populated on first call.
|
|
16
|
+
_EXT: ModuleType | None = None
|
|
17
|
+
|
|
18
|
+
|
|
19
|
+
def extension() -> ModuleType:
|
|
20
|
+
"""Return the compiled C++/CUDA extension module, importing it lazily."""
|
|
21
|
+
global _EXT
|
|
22
|
+
if _EXT is None:
|
|
23
|
+
_EXT = importlib.import_module("areno.accel._areno_accel")
|
|
24
|
+
return _EXT
|
|
@@ -0,0 +1,172 @@
|
|
|
1
|
+
"""Fused activation kernels (SiLU, GELU-tanh, sigmoid, softplus, gated variants).

Each wrapper dispatches to the ARENO CUDA kernel via the compiled extension. It
preserves autograd by routing through ``torch.autograd.Function`` whenever grad
mode is enabled and the input requires gradients.

The ``*_and_mul`` variants implement the common "gated MLP" pattern: the last
input dimension is split in two, and the first half is activated then
element-wise multiplied with the second half. The output therefore has half the
last-dimension size. When those variants are given an explicit ``out`` buffer,
they always take the non-autograd path.

Every wrapper rejects non-CUDA inputs with ``RuntimeError``; there is no CPU
fallback.
"""
|
|
10
|
+
import torch
|
|
11
|
+
|
|
12
|
+
from areno.accel._extension import extension as _extension
|
|
13
|
+
|
|
14
|
+
|
|
15
|
+
def _activation_out(x: torch.Tensor, out: torch.Tensor | None) -> torch.Tensor:
|
|
16
|
+
"""Allocate or validate the half-width output tensor for ``*_and_mul`` ops."""
|
|
17
|
+
if x.shape[-1] % 2 != 0:
|
|
18
|
+
raise ValueError(f"activation input last dimension must be even, got {x.shape[-1]}")
|
|
19
|
+
hidden = x.shape[-1] // 2
|
|
20
|
+
expected_shape = (*x.shape[:-1], hidden)
|
|
21
|
+
if out is None:
|
|
22
|
+
return torch.empty(expected_shape, device=x.device, dtype=x.dtype)
|
|
23
|
+
if tuple(out.shape) != expected_shape:
|
|
24
|
+
raise ValueError(f"activation output shape must be {expected_shape}, got {tuple(out.shape)}")
|
|
25
|
+
return out
|
|
26
|
+
|
|
27
|
+
|
|
28
|
+
def _can_use_cuda_extension(x: torch.Tensor) -> bool:
    """Return True for CUDA tensors; raise for anything else (no CPU kernels exist)."""
    if x.is_cuda:
        return True
    raise RuntimeError("ARENO activation kernels require CUDA tensors")
|
|
33
|
+
|
|
34
|
+
|
|
35
|
+
class _SiluMul(torch.autograd.Function):
    """Autograd glue for fused SiLU(x[..., :H]) * x[..., H:].

    Forward writes into a freshly allocated contiguous ``(..., H)`` tensor and
    saves the original input. Backward produces the ``(..., 2H)`` input
    gradient with ``areno_d_silu_and_mul``.
    """

    @staticmethod
    def forward(ctx, x: torch.Tensor) -> torch.Tensor:
        out = _activation_out(x, None)
        _extension().areno_silu_and_mul(out, x.contiguous())
        ctx.save_for_backward(x)
        return out

    @staticmethod
    def backward(ctx, grad_output: torch.Tensor) -> tuple[torch.Tensor]:
        (x,) = ctx.saved_tensors
        # The kernel is fed contiguous copies of grad_output and x, so the
        # destination must be contiguous as well. Plain ``torch.empty_like(x)``
        # preserves x's strides when x is dense but non-contiguous (e.g. a
        # transposed view). That layout would not match the kernel's other
        # operands (assumed to be indexed identically — extension source not
        # visible here).
        grad_input = torch.empty_like(x, memory_format=torch.contiguous_format)
        _extension().areno_d_silu_and_mul(grad_input, grad_output.contiguous(), x.contiguous())
        return (grad_input,)
|
|
51
|
+
|
|
52
|
+
|
|
53
|
+
class _Silu(torch.autograd.Function):
    """Element-wise SiLU through the ARENO kernel with a custom backward.

    The raw input is saved and handed to ``areno_d_silu`` in backward.
    """

    @staticmethod
    def forward(ctx, x: torch.Tensor) -> torch.Tensor:
        activated = _extension().areno_silu(x.contiguous())
        ctx.save_for_backward(x)
        return activated

    @staticmethod
    def backward(ctx, grad_output: torch.Tensor) -> tuple[torch.Tensor]:
        (saved_input,) = ctx.saved_tensors
        ext = _extension()
        grad_x = ext.areno_d_silu(grad_output.contiguous(), saved_input.contiguous())
        return (grad_x,)
|
|
66
|
+
|
|
67
|
+
|
|
68
|
+
class _Sigmoid(torch.autograd.Function):
    """Element-wise sigmoid through the ARENO kernel with a custom backward.

    Only the forward *output* is kept; ``areno_d_sigmoid`` receives it (not
    the input) in backward.
    """

    @staticmethod
    def forward(ctx, x: torch.Tensor) -> torch.Tensor:
        probs = _extension().areno_sigmoid(x.contiguous())
        ctx.save_for_backward(probs)
        return probs

    @staticmethod
    def backward(ctx, grad_output: torch.Tensor) -> tuple[torch.Tensor]:
        (probs,) = ctx.saved_tensors
        ext = _extension()
        grad_x = ext.areno_d_sigmoid(grad_output.contiguous(), probs.contiguous())
        return (grad_x,)
|
|
81
|
+
|
|
82
|
+
|
|
83
|
+
class _Softplus(torch.autograd.Function):
    """Element-wise softplus through the ARENO kernel with a custom backward.

    The raw input is saved because ``areno_d_softplus`` is evaluated at ``x``.
    """

    @staticmethod
    def forward(ctx, x: torch.Tensor) -> torch.Tensor:
        smoothed = _extension().areno_softplus(x.contiguous())
        ctx.save_for_backward(x)
        return smoothed

    @staticmethod
    def backward(ctx, grad_output: torch.Tensor) -> tuple[torch.Tensor]:
        (saved_input,) = ctx.saved_tensors
        ext = _extension()
        grad_x = ext.areno_d_softplus(grad_output.contiguous(), saved_input.contiguous())
        return (grad_x,)
|
|
96
|
+
|
|
97
|
+
|
|
98
|
+
class _GeluTanhMul(torch.autograd.Function):
    """Autograd glue for fused tanh-approx GELU(x[..., :H]) * x[..., H:].

    Forward writes into a freshly allocated contiguous ``(..., H)`` tensor and
    saves the original input. Backward produces the ``(..., 2H)`` input
    gradient with ``areno_d_gelu_tanh_and_mul``.
    """

    @staticmethod
    def forward(ctx, x: torch.Tensor) -> torch.Tensor:
        out = _activation_out(x, None)
        _extension().areno_gelu_tanh_and_mul(out, x.contiguous())
        ctx.save_for_backward(x)
        return out

    @staticmethod
    def backward(ctx, grad_output: torch.Tensor) -> tuple[torch.Tensor]:
        (x,) = ctx.saved_tensors
        # The kernel is fed contiguous copies of grad_output and x, so the
        # destination must be contiguous as well. Plain ``torch.empty_like(x)``
        # preserves x's strides when x is dense but non-contiguous (e.g. a
        # transposed view). That layout would not match the kernel's other
        # operands (assumed to be indexed identically — extension source not
        # visible here).
        grad_input = torch.empty_like(x, memory_format=torch.contiguous_format)
        _extension().areno_d_gelu_tanh_and_mul(grad_input, grad_output.contiguous(), x.contiguous())
        return (grad_input,)
|
|
114
|
+
|
|
115
|
+
|
|
116
|
+
@torch._dynamo.disable
def areno_silu_and_mul(x: torch.Tensor, out: torch.Tensor | None = None) -> torch.Tensor:
    """Compute ``silu(x[..., :H]) * x[..., H:]`` with the fused ARENO kernel.

    Args:
        x: CUDA tensor of shape ``(..., 2H)``.
        out: optional pre-allocated ``(..., H)`` destination. Supplying it
            forces the non-autograd path, and the kernel writes straight into it.

    Returns:
        The ``(..., H)`` result (``out`` itself when it was given).

    Raises:
        ValueError: odd last dimension, or ``out`` with the wrong shape.
        RuntimeError: ``x`` is not a CUDA tensor.
    """
    dest = _activation_out(x, out)
    _can_use_cuda_extension(x)
    use_autograd = out is None and torch.is_grad_enabled() and x.requires_grad
    if use_autograd:
        # _SiluMul allocates its own output; ``dest`` is simply not used here.
        return _SiluMul.apply(x)
    _extension().areno_silu_and_mul(dest, x.contiguous())
    return dest
|
|
131
|
+
|
|
132
|
+
|
|
133
|
+
@torch._dynamo.disable
def areno_silu(x: torch.Tensor) -> torch.Tensor:
    """Element-wise SiLU on a CUDA tensor via the ARENO kernel.

    Goes through ``_Silu`` (custom backward) only when grad mode is on and
    ``x`` requires grad; otherwise it calls the forward kernel directly.
    """
    _can_use_cuda_extension(x)
    if not (torch.is_grad_enabled() and x.requires_grad):
        return _extension().areno_silu(x.contiguous())
    return _Silu.apply(x)
|
|
140
|
+
|
|
141
|
+
|
|
142
|
+
@torch._dynamo.disable
def areno_sigmoid(x: torch.Tensor) -> torch.Tensor:
    """Element-wise sigmoid on a CUDA tensor via the ARENO kernel.

    Goes through ``_Sigmoid`` (custom backward) only when grad mode is on and
    ``x`` requires grad; otherwise it calls the forward kernel directly.
    """
    _can_use_cuda_extension(x)
    if not (torch.is_grad_enabled() and x.requires_grad):
        return _extension().areno_sigmoid(x.contiguous())
    return _Sigmoid.apply(x)
|
|
149
|
+
|
|
150
|
+
|
|
151
|
+
@torch._dynamo.disable
def areno_softplus(x: torch.Tensor) -> torch.Tensor:
    """Element-wise softplus(beta=1, threshold=20) on a CUDA tensor via the ARENO kernel.

    Goes through ``_Softplus`` (custom backward) only when grad mode is on and
    ``x`` requires grad; otherwise it calls the forward kernel directly.
    """
    _can_use_cuda_extension(x)
    if not (torch.is_grad_enabled() and x.requires_grad):
        return _extension().areno_softplus(x.contiguous())
    return _Softplus.apply(x)
|
|
158
|
+
|
|
159
|
+
|
|
160
|
+
@torch._dynamo.disable
def areno_gelu_tanh_and_mul(x: torch.Tensor, out: torch.Tensor | None = None) -> torch.Tensor:
    """Compute ``gelu_tanh(x[..., :H]) * x[..., H:]`` (GeGLU) with the fused ARENO kernel.

    Matches the GeGLU formulation used by Gemma-family MLPs.

    Args:
        x: CUDA tensor of shape ``(..., 2H)``.
        out: optional pre-allocated ``(..., H)`` destination. Supplying it
            forces the non-autograd path, and the kernel writes straight into it.

    Returns:
        The ``(..., H)`` result (``out`` itself when it was given).

    Raises:
        ValueError: odd last dimension, or ``out`` with the wrong shape.
        RuntimeError: ``x`` is not a CUDA tensor.
    """
    dest = _activation_out(x, out)
    _can_use_cuda_extension(x)
    use_autograd = out is None and torch.is_grad_enabled() and x.requires_grad
    if use_autograd:
        # _GeluTanhMul allocates its own output; ``dest`` is simply not used here.
        return _GeluTanhMul.apply(x)
    _extension().areno_gelu_tanh_and_mul(dest, x.contiguous())
    return dest
|
|
@@ -0,0 +1,141 @@
|
|
|
1
|
+
"""Fused depthwise causal Conv1d + SiLU kernels for short-range token mixing.

Used by the linear-attention / Mamba-style layers in areno to mix recent
tokens within each channel. The kernel performs the convolution with explicit
causal padding (no future leak) and applies SiLU in the same pass. Weights must
be laid out as ``(channels, 1, kernel)`` and are always coerced to ``float32``
before reaching the CUDA kernel.

Three entry points are provided:

* a padded ``(batch, seqlen, channels)`` variant;
* a packed variable-length ``(1, tokens, channels)`` variant driven by
  ``cu_seqlens``;
* the ``*_decode`` variant for the single-token autoregressive case, which
  reads a separately maintained history cache.
"""
|
|
10
|
+
import torch
|
|
11
|
+
|
|
12
|
+
from areno.accel._extension import extension as _extension
|
|
13
|
+
|
|
14
|
+
|
|
15
|
+
def _check_weight_shape(weight: torch.Tensor) -> None:
    """Raise ``ValueError`` unless ``weight`` is laid out as (channels, 1, kernel_size)."""
    is_depthwise = weight.dim() == 3 and weight.shape[1] == 1
    if not is_depthwise:
        raise ValueError(f"weight must have shape (channels, 1, kernel), got {tuple(weight.shape)}")
|
|
19
|
+
|
|
20
|
+
|
|
21
|
+
def _kernel_weight(weight: torch.Tensor) -> torch.Tensor:
    """Validate the depthwise layout and return the weight as float32.

    The tensor is returned as-is when it is already float32. Otherwise a
    ``.float()`` copy is made, which stays connected to the original in the
    autograd graph.
    """
    _check_weight_shape(weight)
    if weight.dtype != torch.float32:
        return weight.float()
    return weight
|
|
25
|
+
|
|
26
|
+
|
|
27
|
+
class _DepthwiseCausalConv1dSilu(torch.autograd.Function):
    """Custom autograd wrapper around the fused depthwise causal conv1d + SiLU kernel.

    The forward kernel also returns the pre-activation (conv output before
    SiLU). It is saved next to ``x`` and ``weight`` and passed to the backward
    kernel.
    """

    @staticmethod
    def forward(ctx, x: torch.Tensor, weight: torch.Tensor) -> torch.Tensor:
        ext = _extension()
        activated, pre_activation = ext.areno_depthwise_causal_conv1d_silu_forward(x.contiguous(), weight.contiguous())
        ctx.save_for_backward(x, weight, pre_activation)
        return activated

    @staticmethod
    def backward(ctx, grad_output: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]:
        x, weight, pre_activation = ctx.saved_tensors
        ext = _extension()
        dx, dweight = ext.areno_depthwise_causal_conv1d_silu_backward(
            grad_output.contiguous(), x.contiguous(), weight.contiguous(), pre_activation
        )
        return dx, dweight
|
|
50
|
+
|
|
51
|
+
|
|
52
|
+
class _PackedDepthwiseCausalConv1dSilu(torch.autograd.Function):
    """Custom autograd wrapper for the packed (varlen) conv1d + SiLU kernel.

    ``cu_seqlens`` is converted once in forward to a contiguous int32 tensor
    on ``x``'s device, reused by backward, and receives no gradient.
    """

    @staticmethod
    def forward(ctx, x: torch.Tensor, weight: torch.Tensor, cu_seqlens: torch.Tensor) -> torch.Tensor:
        offsets = cu_seqlens.to(device=x.device, dtype=torch.int32).contiguous()
        ext = _extension()
        activated, pre_activation = ext.areno_packed_depthwise_causal_conv1d_silu_forward(
            x.contiguous(), weight.contiguous(), offsets
        )
        ctx.save_for_backward(x, weight, offsets, pre_activation)
        return activated

    @staticmethod
    def backward(ctx, grad_output: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor, None]:
        x, weight, offsets, pre_activation = ctx.saved_tensors
        ext = _extension()
        dx, dweight = ext.areno_packed_depthwise_causal_conv1d_silu_backward(
            grad_output.contiguous(), x.contiguous(), weight.contiguous(), offsets, pre_activation
        )
        return dx, dweight, None
|
|
77
|
+
|
|
78
|
+
|
|
79
|
+
@torch._dynamo.disable
|
|
80
|
+
def areno_depthwise_causal_conv1d_silu(x: torch.Tensor, weight: torch.Tensor) -> torch.Tensor:
|
|
81
|
+
"""Apply depthwise causal conv1d followed by SiLU for (batch, seqlen, channels) tensors.
|
|
82
|
+
|
|
83
|
+
The kernel handles arbitrary ``seqlen`` and uses left-padding to keep the
|
|
84
|
+
convolution causal. Falls through to the inference path when autograd is
|
|
85
|
+
disabled to avoid stashing the pre-activation tensor.
|
|
86
|
+
"""
|
|
87
|
+
if not x.is_cuda or not weight.is_cuda:
|
|
88
|
+
raise RuntimeError("areno_depthwise_causal_conv1d_silu requires CUDA input and weight")
|
|
89
|
+
if x.dim() != 3:
|
|
90
|
+
raise ValueError(f"input must have shape (batch, seqlen, channels), got {tuple(x.shape)}")
|
|
91
|
+
weight = _kernel_weight(weight)
|
|
92
|
+
if x.shape[-1] != weight.shape[0]:
|
|
93
|
+
raise ValueError(f"channel mismatch: input={x.shape[-1]} weight={weight.shape[0]}")
|
|
94
|
+
if torch.is_grad_enabled() and (x.requires_grad or weight.requires_grad):
|
|
95
|
+
return _DepthwiseCausalConv1dSilu.apply(x, weight)
|
|
96
|
+
out, _ = _extension().areno_depthwise_causal_conv1d_silu_forward(x.contiguous(), weight.contiguous())
|
|
97
|
+
return out
|
|
98
|
+
|
|
99
|
+
|
|
100
|
+
@torch._dynamo.disable
|
|
101
|
+
def areno_packed_depthwise_causal_conv1d_silu(x: torch.Tensor, weight: torch.Tensor, cu_seqlens: torch.Tensor) -> torch.Tensor:
|
|
102
|
+
"""Apply depthwise causal conv1d followed by SiLU to packed (1, tokens, channels) tensors."""
|
|
103
|
+
if not x.is_cuda or not weight.is_cuda or not cu_seqlens.is_cuda:
|
|
104
|
+
raise RuntimeError("areno_packed_depthwise_causal_conv1d_silu requires CUDA tensors")
|
|
105
|
+
if x.dim() != 3 or x.shape[0] != 1:
|
|
106
|
+
raise ValueError(f"input must have shape (1, tokens, channels), got {tuple(x.shape)}")
|
|
107
|
+
weight = _kernel_weight(weight)
|
|
108
|
+
if x.shape[-1] != weight.shape[0]:
|
|
109
|
+
raise ValueError(f"channel mismatch: input={x.shape[-1]} weight={weight.shape[0]}")
|
|
110
|
+
if torch.is_grad_enabled() and (x.requires_grad or weight.requires_grad):
|
|
111
|
+
return _PackedDepthwiseCausalConv1dSilu.apply(x, weight, cu_seqlens)
|
|
112
|
+
out, _ = _extension().areno_packed_depthwise_causal_conv1d_silu_forward(
|
|
113
|
+
x.contiguous(),
|
|
114
|
+
weight.contiguous(),
|
|
115
|
+
cu_seqlens.to(device=x.device, dtype=torch.int32).contiguous(),
|
|
116
|
+
)
|
|
117
|
+
return out
|
|
118
|
+
|
|
119
|
+
|
|
120
|
+
@torch._dynamo.disable
|
|
121
|
+
def areno_depthwise_causal_conv1d_silu_decode(current: torch.Tensor, history: torch.Tensor, weight: torch.Tensor) -> torch.Tensor:
|
|
122
|
+
"""Apply one-token depthwise causal conv1d followed by SiLU for decode.
|
|
123
|
+
|
|
124
|
+
``current`` is the new token activations with shape ``(rows, channels)``
|
|
125
|
+
and ``history`` provides the previous ``kernel_size - 1`` tokens shaped
|
|
126
|
+
``(rows, channels, kernel - 1)``. Returns the activated single-step output;
|
|
127
|
+
callers are responsible for shifting ``history``.
|
|
128
|
+
"""
|
|
129
|
+
if not current.is_cuda or not history.is_cuda or not weight.is_cuda:
|
|
130
|
+
raise RuntimeError("areno_depthwise_causal_conv1d_silu_decode requires CUDA tensors")
|
|
131
|
+
if current.dim() != 2:
|
|
132
|
+
raise ValueError(f"current must have shape (rows, channels), got {tuple(current.shape)}")
|
|
133
|
+
if history.dim() != 3:
|
|
134
|
+
raise ValueError(f"history must have shape (rows, channels, kernel - 1), got {tuple(history.shape)}")
|
|
135
|
+
weight = _kernel_weight(weight)
|
|
136
|
+
out, _ = _extension().areno_depthwise_causal_conv1d_silu_decode(
|
|
137
|
+
current.contiguous(),
|
|
138
|
+
history.contiguous(),
|
|
139
|
+
weight.contiguous(),
|
|
140
|
+
)
|
|
141
|
+
return out
|
|
@@ -0,0 +1,49 @@
|
|
|
1
|
+
"""Tensor-parallel vocab embedding gather backed by an ARENO CUDA kernel.

In TP-sharded embedding tables, each rank only holds a contiguous slice of the
vocabulary ``[vocab_start, vocab_end)``. The kernel gathers rows for ids that
fall inside that range and writes zeros for out-of-range ids, so the subsequent
all-reduce reconstructs the full embedding. Gradients flow only to the local
weight shard; the token ids and the range bounds receive none.
"""
|
|
8
|
+
import torch
|
|
9
|
+
|
|
10
|
+
from areno.accel._extension import extension as _extension
|
|
11
|
+
|
|
12
|
+
|
|
13
|
+
class _VocabEmbedding(torch.autograd.Function):
    """Custom autograd wrapper for the vocab-parallel embedding gather/scatter.

    The shard bounds are stored on ``ctx`` as Python ints. Backward
    scatter-adds into a gradient for ``weight`` only.
    """

    @staticmethod
    def forward(ctx, input_ids: torch.Tensor, weight: torch.Tensor, vocab_start: int, vocab_end: int) -> torch.Tensor:
        lo, hi = int(vocab_start), int(vocab_end)
        embedded = _extension().areno_vocab_embedding_forward(input_ids.contiguous(), weight.contiguous(), lo, hi)
        ctx.save_for_backward(input_ids, weight)
        ctx.vocab_start, ctx.vocab_end = lo, hi
        return embedded

    @staticmethod
    def backward(ctx, grad_output: torch.Tensor) -> tuple[None, torch.Tensor, None, None]:
        input_ids, weight = ctx.saved_tensors
        ext = _extension()
        grad_weight = ext.areno_vocab_embedding_backward(
            grad_output.contiguous(), input_ids.contiguous(), weight, ctx.vocab_start, ctx.vocab_end
        )
        # Token ids and the integer bounds are not differentiable.
        return None, grad_weight, None, None
|
|
35
|
+
|
|
36
|
+
|
|
37
|
+
@torch._dynamo.disable
|
|
38
|
+
def areno_vocab_embedding(input_ids: torch.Tensor, weight: torch.Tensor, vocab_start: int, vocab_end: int) -> torch.Tensor:
|
|
39
|
+
"""Gather rank-local vocab embeddings and zero out non-local token ids.
|
|
40
|
+
|
|
41
|
+
``input_ids`` must be int64 on CUDA. ``weight`` is the local shard with
|
|
42
|
+
shape ``(vocab_end - vocab_start, hidden)``. Returns embeddings with shape
|
|
43
|
+
``(*input_ids.shape, hidden)`` ready for tensor-parallel reduction.
|
|
44
|
+
"""
|
|
45
|
+
if not input_ids.is_cuda or not weight.is_cuda:
|
|
46
|
+
raise RuntimeError("areno_vocab_embedding requires CUDA input_ids and weight")
|
|
47
|
+
if input_ids.dtype != torch.long:
|
|
48
|
+
raise TypeError("areno_vocab_embedding input_ids must be int64")
|
|
49
|
+
return _VocabEmbedding.apply(input_ids, weight, int(vocab_start), int(vocab_end))
|
|
@@ -0,0 +1 @@
|
|
|
1
|
+
"""Triton kernels that belong to the ``areno.accel`` acceleration surface.

This ``__init__`` holds nothing but this docstring; the kernels are imported
from their own submodules (e.g. ``areno.accel.kernels.fused_moe``).
"""
|