llama-stack 0.4.4__py3-none-any.whl → 0.5.0rc1__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- llama_stack/cli/stack/_list_deps.py +11 -7
- llama_stack/cli/stack/run.py +3 -25
- llama_stack/core/access_control/datatypes.py +78 -0
- llama_stack/core/configure.py +2 -2
- llama_stack/{distributions/meta-reference-gpu → core/connectors}/__init__.py +3 -1
- llama_stack/core/connectors/connectors.py +162 -0
- llama_stack/core/conversations/conversations.py +61 -58
- llama_stack/core/datatypes.py +54 -8
- llama_stack/core/library_client.py +60 -13
- llama_stack/core/prompts/prompts.py +43 -42
- llama_stack/core/routers/datasets.py +20 -17
- llama_stack/core/routers/eval_scoring.py +143 -53
- llama_stack/core/routers/inference.py +20 -9
- llama_stack/core/routers/safety.py +30 -42
- llama_stack/core/routers/vector_io.py +15 -7
- llama_stack/core/routing_tables/models.py +42 -3
- llama_stack/core/routing_tables/scoring_functions.py +19 -19
- llama_stack/core/routing_tables/shields.py +20 -17
- llama_stack/core/routing_tables/vector_stores.py +8 -5
- llama_stack/core/server/auth.py +192 -17
- llama_stack/core/server/fastapi_router_registry.py +40 -5
- llama_stack/core/server/server.py +24 -5
- llama_stack/core/stack.py +54 -10
- llama_stack/core/storage/datatypes.py +9 -0
- llama_stack/core/store/registry.py +1 -1
- llama_stack/core/utils/exec.py +2 -2
- llama_stack/core/utils/type_inspection.py +16 -2
- llama_stack/distributions/dell/config.yaml +4 -1
- llama_stack/distributions/dell/run-with-safety.yaml +4 -1
- llama_stack/distributions/nvidia/config.yaml +4 -1
- llama_stack/distributions/nvidia/run-with-safety.yaml +4 -1
- llama_stack/distributions/oci/config.yaml +4 -1
- llama_stack/distributions/open-benchmark/config.yaml +9 -1
- llama_stack/distributions/postgres-demo/config.yaml +1 -1
- llama_stack/distributions/starter/build.yaml +62 -0
- llama_stack/distributions/starter/config.yaml +22 -3
- llama_stack/distributions/starter/run-with-postgres-store.yaml +22 -3
- llama_stack/distributions/starter/starter.py +13 -1
- llama_stack/distributions/starter-gpu/build.yaml +62 -0
- llama_stack/distributions/starter-gpu/config.yaml +22 -3
- llama_stack/distributions/starter-gpu/run-with-postgres-store.yaml +22 -3
- llama_stack/distributions/template.py +10 -2
- llama_stack/distributions/watsonx/config.yaml +4 -1
- llama_stack/log.py +1 -0
- llama_stack/providers/inline/agents/meta_reference/__init__.py +1 -0
- llama_stack/providers/inline/agents/meta_reference/agents.py +57 -61
- llama_stack/providers/inline/agents/meta_reference/responses/openai_responses.py +49 -51
- llama_stack/providers/inline/agents/meta_reference/responses/streaming.py +94 -22
- llama_stack/providers/inline/agents/meta_reference/responses/types.py +2 -1
- llama_stack/providers/inline/agents/meta_reference/responses/utils.py +4 -1
- llama_stack/providers/inline/agents/meta_reference/safety.py +2 -2
- llama_stack/providers/inline/batches/reference/batches.py +2 -1
- llama_stack/providers/inline/eval/meta_reference/eval.py +40 -32
- llama_stack/providers/inline/post_training/huggingface/post_training.py +33 -38
- llama_stack/providers/inline/post_training/huggingface/utils.py +2 -5
- llama_stack/providers/inline/post_training/torchtune/post_training.py +28 -33
- llama_stack/providers/inline/post_training/torchtune/recipes/lora_finetuning_single_device.py +2 -4
- llama_stack/providers/inline/safety/code_scanner/code_scanner.py +12 -15
- llama_stack/providers/inline/safety/llama_guard/llama_guard.py +15 -18
- llama_stack/providers/inline/safety/prompt_guard/prompt_guard.py +11 -17
- llama_stack/providers/inline/scoring/basic/scoring.py +13 -17
- llama_stack/providers/inline/scoring/braintrust/braintrust.py +15 -15
- llama_stack/providers/inline/scoring/llm_as_judge/scoring.py +13 -17
- llama_stack/providers/inline/vector_io/sqlite_vec/sqlite_vec.py +1 -1
- llama_stack/providers/registry/agents.py +1 -0
- llama_stack/providers/registry/inference.py +1 -9
- llama_stack/providers/registry/vector_io.py +136 -16
- llama_stack/providers/remote/eval/nvidia/eval.py +22 -21
- llama_stack/providers/remote/files/s3/config.py +5 -3
- llama_stack/providers/remote/files/s3/files.py +2 -2
- llama_stack/providers/remote/inference/gemini/gemini.py +4 -0
- llama_stack/providers/remote/inference/openai/openai.py +2 -0
- llama_stack/providers/remote/inference/together/together.py +4 -0
- llama_stack/providers/remote/inference/vertexai/config.py +3 -3
- llama_stack/providers/remote/inference/vertexai/vertexai.py +5 -2
- llama_stack/providers/remote/inference/vllm/config.py +37 -18
- llama_stack/providers/remote/inference/vllm/vllm.py +0 -3
- llama_stack/providers/remote/inference/watsonx/watsonx.py +4 -0
- llama_stack/providers/remote/post_training/nvidia/post_training.py +31 -33
- llama_stack/providers/remote/safety/bedrock/bedrock.py +10 -27
- llama_stack/providers/remote/safety/nvidia/nvidia.py +9 -25
- llama_stack/providers/remote/safety/sambanova/sambanova.py +13 -11
- llama_stack/providers/remote/vector_io/elasticsearch/__init__.py +17 -0
- llama_stack/providers/remote/vector_io/elasticsearch/config.py +32 -0
- llama_stack/providers/remote/vector_io/elasticsearch/elasticsearch.py +463 -0
- llama_stack/providers/remote/vector_io/oci/__init__.py +22 -0
- llama_stack/providers/remote/vector_io/oci/config.py +41 -0
- llama_stack/providers/remote/vector_io/oci/oci26ai.py +595 -0
- llama_stack/providers/remote/vector_io/pgvector/config.py +69 -2
- llama_stack/providers/remote/vector_io/pgvector/pgvector.py +255 -6
- llama_stack/providers/remote/vector_io/qdrant/qdrant.py +62 -38
- llama_stack/providers/utils/bedrock/client.py +3 -3
- llama_stack/providers/utils/bedrock/config.py +7 -7
- llama_stack/providers/utils/inference/embedding_mixin.py +4 -0
- llama_stack/providers/utils/inference/http_client.py +239 -0
- llama_stack/providers/utils/inference/litellm_openai_mixin.py +5 -0
- llama_stack/providers/utils/inference/model_registry.py +148 -2
- llama_stack/providers/utils/inference/openai_compat.py +2 -1
- llama_stack/providers/utils/inference/openai_mixin.py +41 -2
- llama_stack/providers/utils/memory/openai_vector_store_mixin.py +92 -5
- llama_stack/providers/utils/memory/vector_store.py +46 -19
- llama_stack/providers/utils/responses/responses_store.py +7 -7
- llama_stack/providers/utils/safety.py +114 -0
- llama_stack/providers/utils/tools/mcp.py +44 -3
- llama_stack/testing/api_recorder.py +9 -3
- {llama_stack-0.4.4.dist-info → llama_stack-0.5.0rc1.dist-info}/METADATA +14 -2
- {llama_stack-0.4.4.dist-info → llama_stack-0.5.0rc1.dist-info}/RECORD +111 -144
- llama_stack/distributions/meta-reference-gpu/config.yaml +0 -140
- llama_stack/distributions/meta-reference-gpu/doc_template.md +0 -119
- llama_stack/distributions/meta-reference-gpu/meta_reference.py +0 -163
- llama_stack/distributions/meta-reference-gpu/run-with-safety.yaml +0 -155
- llama_stack/models/llama/hadamard_utils.py +0 -88
- llama_stack/models/llama/llama3/args.py +0 -74
- llama_stack/models/llama/llama3/dog.jpg +0 -0
- llama_stack/models/llama/llama3/generation.py +0 -378
- llama_stack/models/llama/llama3/model.py +0 -304
- llama_stack/models/llama/llama3/multimodal/__init__.py +0 -12
- llama_stack/models/llama/llama3/multimodal/encoder_utils.py +0 -180
- llama_stack/models/llama/llama3/multimodal/image_transform.py +0 -409
- llama_stack/models/llama/llama3/multimodal/model.py +0 -1430
- llama_stack/models/llama/llama3/multimodal/utils.py +0 -26
- llama_stack/models/llama/llama3/pasta.jpeg +0 -0
- llama_stack/models/llama/llama3/quantization/__init__.py +0 -5
- llama_stack/models/llama/llama3/quantization/loader.py +0 -316
- llama_stack/models/llama/llama3_1/__init__.py +0 -12
- llama_stack/models/llama/llama3_1/prompt_format.md +0 -358
- llama_stack/models/llama/llama3_1/prompts.py +0 -258
- llama_stack/models/llama/llama3_2/__init__.py +0 -5
- llama_stack/models/llama/llama3_2/prompts_text.py +0 -229
- llama_stack/models/llama/llama3_2/prompts_vision.py +0 -126
- llama_stack/models/llama/llama3_2/text_prompt_format.md +0 -286
- llama_stack/models/llama/llama3_2/vision_prompt_format.md +0 -141
- llama_stack/models/llama/llama3_3/__init__.py +0 -5
- llama_stack/models/llama/llama3_3/prompts.py +0 -259
- llama_stack/models/llama/llama4/args.py +0 -107
- llama_stack/models/llama/llama4/ffn.py +0 -58
- llama_stack/models/llama/llama4/moe.py +0 -214
- llama_stack/models/llama/llama4/preprocess.py +0 -435
- llama_stack/models/llama/llama4/quantization/__init__.py +0 -5
- llama_stack/models/llama/llama4/quantization/loader.py +0 -226
- llama_stack/models/llama/llama4/vision/__init__.py +0 -5
- llama_stack/models/llama/llama4/vision/embedding.py +0 -210
- llama_stack/models/llama/llama4/vision/encoder.py +0 -412
- llama_stack/models/llama/quantize_impls.py +0 -316
- llama_stack/providers/inline/inference/meta_reference/__init__.py +0 -20
- llama_stack/providers/inline/inference/meta_reference/common.py +0 -24
- llama_stack/providers/inline/inference/meta_reference/config.py +0 -68
- llama_stack/providers/inline/inference/meta_reference/generators.py +0 -201
- llama_stack/providers/inline/inference/meta_reference/inference.py +0 -542
- llama_stack/providers/inline/inference/meta_reference/model_parallel.py +0 -77
- llama_stack/providers/inline/inference/meta_reference/parallel_utils.py +0 -353
- {llama_stack-0.4.4.dist-info → llama_stack-0.5.0rc1.dist-info}/WHEEL +0 -0
- {llama_stack-0.4.4.dist-info → llama_stack-0.5.0rc1.dist-info}/entry_points.txt +0 -0
- {llama_stack-0.4.4.dist-info → llama_stack-0.5.0rc1.dist-info}/licenses/LICENSE +0 -0
- {llama_stack-0.4.4.dist-info → llama_stack-0.5.0rc1.dist-info}/top_level.txt +0 -0
|
@@ -1,304 +0,0 @@
|
|
|
1
|
-
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
|
2
|
-
# All rights reserved.
|
|
3
|
-
#
|
|
4
|
-
# This source code is licensed under the terms described in the LICENSE file in
|
|
5
|
-
# the root directory of this source tree.
|
|
6
|
-
|
|
7
|
-
import math
|
|
8
|
-
|
|
9
|
-
import fairscale.nn.model_parallel.initialize as fs_init
|
|
10
|
-
import torch
|
|
11
|
-
import torch.nn.functional as F
|
|
12
|
-
from fairscale.nn.model_parallel.layers import (
|
|
13
|
-
ColumnParallelLinear,
|
|
14
|
-
RowParallelLinear,
|
|
15
|
-
VocabParallelEmbedding,
|
|
16
|
-
)
|
|
17
|
-
from torch import nn
|
|
18
|
-
|
|
19
|
-
from .args import ModelArgs
|
|
20
|
-
|
|
21
|
-
# **NOTE**: This code is not runnable without installing `torch` and `fairscale`
|
|
22
|
-
# dependencies. These dependencies are not part of the default dependencies
|
|
23
|
-
# (requirements.txt) of the `llama-models` package.
|
|
24
|
-
|
|
25
|
-
|
|
26
|
-
class RMSNorm(torch.nn.Module):
    """Root-mean-square normalization over the last dimension with a learned gain.

    Args:
        dim: length of the learned gain vector (must match the input's last dimension).
        eps: small constant added to the mean of squares for numerical stability.
    """

    def __init__(self, dim: int, eps: float = 1e-6):
        super().__init__()
        self.eps = eps
        # Gain starts at ones, so the layer is initially a pure normalization.
        self.weight = nn.Parameter(torch.ones(dim))

    def _norm(self, x):
        """Divide ``x`` by the (eps-stabilized) RMS of its last dimension."""
        mean_square = x.pow(2).mean(-1, keepdim=True)
        inv_rms = torch.rsqrt(mean_square + self.eps)
        return x * inv_rms

    def forward(self, x):
        """Normalize in float32, cast back to ``x``'s dtype, then apply the gain."""
        normalized = self._norm(x.float()).type_as(x)
        return normalized * self.weight
|
|
38
|
-
|
|
39
|
-
|
|
40
|
-
def apply_scaling(
    freqs: torch.Tensor,
    scale_factor: float = 8,
    low_freq_factor: float = 1,
    high_freq_factor: float = 4,
    old_context_len: int = 8192,
) -> torch.Tensor:
    """Rescale rotary-embedding frequencies to stretch the usable context window.

    Each frequency is classified by its wavelength ``2 * pi / freq``:

    * wavelength < ``old_context_len / high_freq_factor``: left unchanged;
    * wavelength > ``old_context_len / low_freq_factor``: divided by ``scale_factor``;
    * in between: linearly blended between ``freq / scale_factor`` and ``freq``,
      weighted by how many wavelengths fit into ``old_context_len``.

    The defaults are the original hard-coded values ("obtained from grid
    search"; 8192 is the original llama3 context length), so ``apply_scaling(freqs)``
    behaves exactly as before. Callers that need a different scaling recipe can
    override them.

    Args:
        freqs: RoPE inverse frequencies; the transform is applied elementwise.
        scale_factor: divisor applied to the long-wavelength (low-frequency) components.
        low_freq_factor: ``old_context_len / low_freq_factor`` is the wavelength above
            which frequencies are fully scaled.
        high_freq_factor: ``old_context_len / high_freq_factor`` is the wavelength below
            which frequencies are untouched. Must be greater than ``low_freq_factor``.
        old_context_len: context length used to derive the two wavelength thresholds.

    Returns:
        Tensor with the same shape as ``freqs`` holding the rescaled frequencies.

    Raises:
        ValueError: if ``high_freq_factor <= low_freq_factor`` (the blend weight
            would divide by zero or invert the band).
    """
    if high_freq_factor <= low_freq_factor:
        raise ValueError(
            f"high_freq_factor ({high_freq_factor}) must be greater than low_freq_factor ({low_freq_factor})"
        )

    low_freq_wavelen = old_context_len / low_freq_factor
    high_freq_wavelen = old_context_len / high_freq_factor

    wavelen = 2 * torch.pi / freqs
    # Fully scale the long-wavelength components; keep everything else for now.
    new_freqs = torch.where(wavelen > low_freq_wavelen, freqs / scale_factor, freqs)
    # Blend weight: 0 at the low-frequency edge of the band, 1 at the high-frequency edge.
    smooth = (old_context_len / wavelen - low_freq_factor) / (high_freq_factor - low_freq_factor)
    return torch.where(
        (wavelen >= high_freq_wavelen) & (wavelen <= low_freq_wavelen),
        (1 - smooth) * new_freqs / scale_factor + smooth * new_freqs,
        new_freqs,
    )
|
|
58
|
-
|
|
59
|
-
|
|
60
|
-
def precompute_freqs_cis(dim: int, end: int, theta: float = 10000.0, use_scaled: bool = False):
    """Precompute rotary position embeddings as unit-magnitude complex numbers.

    Args:
        dim: per-head dimension; ``dim // 2`` rotation frequencies are produced.
        end: number of positions (``0 .. end - 1``) to precompute.
        theta: RoPE base; frequency ``k`` is ``theta ** (-2k / dim)``.
        use_scaled: when true, the frequencies are first passed through ``apply_scaling``.

    Returns:
        complex64 tensor of shape ``(end, dim // 2)`` whose entry ``[t, k]`` is
        ``exp(i * t * freq_k)``.
    """
    even_idx = torch.arange(0, dim, 2)[: (dim // 2)]
    inv_freq = 1.0 / (theta ** (even_idx.float() / dim))
    positions = torch.arange(end, device=inv_freq.device, dtype=torch.float32)
    if use_scaled:
        inv_freq = apply_scaling(inv_freq)
    angles = torch.outer(positions, inv_freq)
    magnitudes = torch.ones_like(angles)
    return torch.polar(magnitudes, angles)  # complex64
|
|
68
|
-
|
|
69
|
-
|
|
70
|
-
def reshape_for_broadcast(freqs_cis: torch.Tensor, x: torch.Tensor):
    """Return a view of ``freqs_cis`` that broadcasts against ``x``.

    ``freqs_cis`` must have shape ``(x.shape[1], x.shape[-1])``. The returned view
    keeps those two sizes at axis 1 and the last axis, with singleton axes
    everywhere else.
    """
    ndim = x.ndim
    assert 0 <= 1 < ndim
    assert freqs_cis.shape == (x.shape[1], x.shape[-1])
    shape = [1] * ndim
    shape[1] = x.shape[1]
    shape[ndim - 1] = x.shape[ndim - 1]
    return freqs_cis.view(*shape)
|
|
76
|
-
|
|
77
|
-
|
|
78
|
-
def apply_rotary_emb(
    xq: torch.Tensor,
    xk: torch.Tensor,
    freqs_cis: torch.Tensor,
) -> tuple[torch.Tensor, torch.Tensor]:
    """Rotate query and key vectors by precomputed rotary phases.

    The last dimension of ``xq``/``xk`` is read as consecutive (real, imag) pairs
    and multiplied by ``freqs_cis`` (reshaped with ``reshape_for_broadcast``
    against the complex query). The math runs in float32; the results are cast
    back to the dtypes of ``xq`` and ``xk`` respectively.
    """

    def to_complex(t: torch.Tensor) -> torch.Tensor:
        # (..., d) real -> (..., d / 2) complex, pairing adjacent elements.
        return torch.view_as_complex(t.float().reshape(*t.shape[:-1], -1, 2))

    def rotate(z: torch.Tensor, phases: torch.Tensor) -> torch.Tensor:
        # Back to a real layout; flatten(3) merges the trailing (d / 2, 2) axes for 4-D inputs.
        return torch.view_as_real(z * phases).flatten(3)

    q_complex = to_complex(xq)
    k_complex = to_complex(xk)
    phases = reshape_for_broadcast(freqs_cis, q_complex)
    return rotate(q_complex, phases).type_as(xq), rotate(k_complex, phases).type_as(xk)
|
|
89
|
-
|
|
90
|
-
|
|
91
|
-
def repeat_kv(x: torch.Tensor, n_rep: int) -> torch.Tensor:
    """Repeat each key/value head ``n_rep`` times along axis 2.

    Equivalent to ``torch.repeat_interleave(x, dim=2, repeats=n_rep)``, built
    from ``expand`` + ``reshape``. ``x`` must be 4-D
    ``(bs, slen, n_kv_heads, head_dim)``; when ``n_rep == 1`` the input object is
    returned as-is.
    """
    bs, slen, n_kv_heads, head_dim = x.shape
    if n_rep == 1:
        return x
    with_rep_axis = x[:, :, :, None, :].expand(bs, slen, n_kv_heads, n_rep, head_dim)
    return with_rep_axis.reshape(bs, slen, n_kv_heads * n_rep, head_dim)
|
|
101
|
-
|
|
102
|
-
|
|
103
|
-
class Attention(nn.Module):
    """Multi-head self-attention with grouped key/value heads, rotary embeddings and a KV cache.

    The q/k/v/o projections are fairscale model-parallel layers: each rank holds
    ``n_heads // world_size`` query heads and ``n_kv_heads // world_size`` key/value
    heads, and key/value heads are repeated ``n_rep`` times to match the query heads.
    """

    def __init__(self, args: ModelArgs):
        super().__init__()
        # Without an explicit n_kv_heads, every query head gets its own key/value head.
        self.n_kv_heads = args.n_heads if args.n_kv_heads is None else args.n_kv_heads
        world_size = fs_init.get_model_parallel_world_size()
        # Per-rank head counts; assumes both head counts divide evenly by world_size — TODO confirm.
        self.n_local_heads = args.n_heads // world_size
        self.n_local_kv_heads = self.n_kv_heads // world_size
        # Number of query heads sharing each key/value head.
        self.n_rep = self.n_local_heads // self.n_local_kv_heads
        self.head_dim = args.dim // args.n_heads

        # init_method=lambda x: x is the identity, i.e. no initialization is applied here;
        # presumably weights are loaded from a checkpoint — TODO confirm.
        self.wq = ColumnParallelLinear(
            args.dim,
            args.n_heads * self.head_dim,
            bias=False,
            gather_output=False,
            init_method=lambda x: x,
        )
        self.wk = ColumnParallelLinear(
            args.dim,
            self.n_kv_heads * self.head_dim,
            bias=False,
            gather_output=False,
            init_method=lambda x: x,
        )
        self.wv = ColumnParallelLinear(
            args.dim,
            self.n_kv_heads * self.head_dim,
            bias=False,
            gather_output=False,
            init_method=lambda x: x,
        )
        self.wo = RowParallelLinear(
            args.n_heads * self.head_dim,
            args.dim,
            bias=False,
            input_is_parallel=True,
            init_method=lambda x: x,
        )

        # KV caches of shape (max_batch_size, max_seq_len, n_local_kv_heads, head_dim).
        # They are plain attributes (not registered buffers), so forward() moves them
        # to the activations' dtype/device with .to(xq) on every call.
        self.cache_k = torch.zeros(
            (
                args.max_batch_size,
                args.max_seq_len,
                self.n_local_kv_heads,
                self.head_dim,
            )
        )
        self.cache_v = torch.zeros(
            (
                args.max_batch_size,
                args.max_seq_len,
                self.n_local_kv_heads,
                self.head_dim,
            )
        )

    def forward(
        self,
        x: torch.Tensor,
        start_pos: int,
        freqs_cis: torch.Tensor,
        mask: torch.Tensor | None,
    ):
        """Attend from the ``seqlen`` new positions to all cached positions up to ``start_pos + seqlen``.

        Args:
            x: input activations of shape (bsz, seqlen, dim).
            start_pos: cache index at which the new keys/values are written. The same
                offset is used for every batch row.
            freqs_cis: rotary phases for positions ``start_pos .. start_pos + seqlen - 1``.
            mask: additive mask broadcastable to (bsz, n_local_heads, seqlen, start_pos + seqlen),
                or None.

        Returns:
            Output of the row-parallel ``wo`` projection.
        """
        bsz, seqlen, _ = x.shape
        xq, xk, xv = self.wq(x), self.wk(x), self.wv(x)

        xq = xq.view(bsz, seqlen, self.n_local_heads, self.head_dim)
        xk = xk.view(bsz, seqlen, self.n_local_kv_heads, self.head_dim)
        xv = xv.view(bsz, seqlen, self.n_local_kv_heads, self.head_dim)

        xq, xk = apply_rotary_emb(xq, xk, freqs_cis=freqs_cis)

        self.cache_k = self.cache_k.to(xq)
        self.cache_v = self.cache_v.to(xq)

        # Write the new keys/values in place, then read back everything cached so far.
        # Requires bsz <= max_batch_size and start_pos + seqlen <= max_seq_len.
        self.cache_k[:bsz, start_pos : start_pos + seqlen] = xk
        self.cache_v[:bsz, start_pos : start_pos + seqlen] = xv

        keys = self.cache_k[:bsz, : start_pos + seqlen]
        values = self.cache_v[:bsz, : start_pos + seqlen]

        # repeat k/v heads if n_kv_heads < n_heads
        keys = repeat_kv(keys, self.n_rep)  # (bs, cache_len + seqlen, n_local_heads, head_dim)
        values = repeat_kv(values, self.n_rep)  # (bs, cache_len + seqlen, n_local_heads, head_dim)

        xq = xq.transpose(1, 2)  # (bs, n_local_heads, seqlen, head_dim)
        keys = keys.transpose(1, 2)  # (bs, n_local_heads, cache_len + seqlen, head_dim)
        values = values.transpose(1, 2)  # (bs, n_local_heads, cache_len + seqlen, head_dim)
        # Scaled dot-product attention; softmax is computed in float32 for stability.
        scores = torch.matmul(xq, keys.transpose(2, 3)) / math.sqrt(self.head_dim)
        if mask is not None:
            scores = scores + mask  # (bs, n_local_heads, seqlen, cache_len + seqlen)
        scores = F.softmax(scores.float(), dim=-1).type_as(xq)
        output = torch.matmul(scores, values)  # (bs, n_local_heads, seqlen, head_dim)
        output = output.transpose(1, 2).contiguous().view(bsz, seqlen, -1)
        return self.wo(output)
|
|
198
|
-
|
|
199
|
-
|
|
200
|
-
class FeedForward(nn.Module):
    """SwiGLU feed-forward network built from fairscale model-parallel linear layers.

    The inner width is derived from ``hidden_dim``: scaled by 2/3, optionally
    multiplied by ``ffn_dim_multiplier``, then rounded up to a multiple of
    ``multiple_of``.
    """

    def __init__(
        self,
        dim: int,
        hidden_dim: int,
        multiple_of: int,
        ffn_dim_multiplier: float | None,
    ):
        super().__init__()
        inner = int(2 * hidden_dim / 3)
        # custom dim factor multiplier
        if ffn_dim_multiplier is not None:
            inner = int(ffn_dim_multiplier * inner)
        # Round up to the next multiple of `multiple_of`.
        inner = multiple_of * ((inner + multiple_of - 1) // multiple_of)

        def _identity(weight):
            # No explicit initialization: the weight is left as allocated.
            return weight

        self.w1 = ColumnParallelLinear(dim, inner, bias=False, gather_output=False, init_method=_identity)
        self.w2 = RowParallelLinear(inner, dim, bias=False, input_is_parallel=True, init_method=_identity)
        self.w3 = ColumnParallelLinear(dim, inner, bias=False, gather_output=False, init_method=_identity)

    def forward(self, x):
        """Return ``w2(silu(w1(x)) * w3(x))``."""
        gate = F.silu(self.w1(x))
        return self.w2(gate * self.w3(x))
|
|
221
|
-
|
|
222
|
-
|
|
223
|
-
class TransformerBlock(nn.Module):
    """One pre-norm decoder layer: attention and SwiGLU feed-forward, each with a residual.

    Args:
        layer_id: index of this block in the stack (stored; not otherwise used here).
        args: model hyper-parameters; reads ``n_heads``, ``dim``, ``multiple_of``,
            ``ffn_dim_multiplier`` and ``norm_eps`` (and passes ``args`` to ``Attention``).
    """

    def __init__(self, layer_id: int, args: ModelArgs):
        super().__init__()
        model_dim = args.dim
        self.n_heads = args.n_heads
        self.dim = model_dim
        self.head_dim = model_dim // args.n_heads
        self.attention = Attention(args)
        # Start from a 4x expansion; FeedForward scales it by 2/3 and rounds it.
        self.feed_forward = FeedForward(
            dim=model_dim,
            hidden_dim=4 * model_dim,
            multiple_of=args.multiple_of,
            ffn_dim_multiplier=args.ffn_dim_multiplier,
        )
        self.layer_id = layer_id
        self.attention_norm = RMSNorm(model_dim, eps=args.norm_eps)
        self.ffn_norm = RMSNorm(model_dim, eps=args.norm_eps)

    def forward(
        self,
        x: torch.Tensor,
        start_pos: int,
        freqs_cis: torch.Tensor,
        mask: torch.Tensor | None,
    ):
        """Apply ``x + attn(norm(x))`` followed by ``h + ffn(norm(h))``."""
        normed = self.attention_norm(x)
        h = x + self.attention(normed, start_pos, freqs_cis, mask)
        ffn_out = self.feed_forward(self.ffn_norm(h))
        return h + ffn_out
|
|
250
|
-
|
|
251
|
-
|
|
252
|
-
class Transformer(nn.Module):
    """Decoder-only transformer.

    Components: token embedding, a stack of ``TransformerBlock`` layers, a final
    ``RMSNorm`` and an output projection to the vocabulary. The embedding and the
    output projection are fairscale model-parallel layers.
    """

    def __init__(self, params: ModelArgs):
        super().__init__()
        self.params = params
        self.vocab_size = params.vocab_size
        self.n_layers = params.n_layers

        # init_method=lambda x: x is the identity (no initialization applied here);
        # presumably weights come from a checkpoint — TODO confirm.
        self.tok_embeddings = VocabParallelEmbedding(params.vocab_size, params.dim, init_method=lambda x: x)

        self.layers = torch.nn.ModuleList()
        for layer_id in range(params.n_layers):
            self.layers.append(TransformerBlock(layer_id, params))

        self.norm = RMSNorm(params.dim, eps=params.norm_eps)
        # gather_output is left at the fairscale default here, unlike the per-layer column-parallel projections.
        self.output = ColumnParallelLinear(params.dim, params.vocab_size, bias=False, init_method=lambda x: x)

        # Rotary phases for 2 * max_seq_len positions, computed with the per-head dimension.
        # Kept as a plain attribute (not a registered buffer); forward() moves it to the input device.
        self.freqs_cis = precompute_freqs_cis(
            params.dim // params.n_heads,
            params.max_seq_len * 2,
            params.rope_theta,
            params.use_scaled_rope,
        )

    @torch.inference_mode()
    def forward(self, tokens: torch.Tensor, start_pos: int):
        """Run the model on ``tokens`` whose first position is ``start_pos``.

        Args:
            tokens: 2-D tensor of token ids, (batch, seqlen).
            start_pos: position of the first token. Earlier positions are read from each
                layer's KV cache, which must already hold them.

        Returns:
            Logits from the output projection, cast to float32. Presumably
            (batch, seqlen, vocab_size) — TODO confirm with fairscale gather semantics.

        Runs under ``torch.inference_mode()``, so no autograd graph is recorded.
        """
        _bsz, seqlen = tokens.shape
        h = self.tok_embeddings(tokens)
        self.freqs_cis = self.freqs_cis.to(h.device)
        freqs_cis = self.freqs_cis[start_pos : start_pos + seqlen]

        # A single new token may attend to everything cached, so no mask is needed.
        mask = None
        if seqlen > 1:
            mask = torch.full((seqlen, seqlen), float("-inf"), device=tokens.device)

            mask = torch.triu(mask, diagonal=1)

            # https://github.com/pytorch/pytorch/issues/100005
            # torch.triu is buggy when the device is mps: filled values are
            # nan instead of 0.
            if mask.device.type == torch.device("mps").type:
                mask = torch.nan_to_num(mask, nan=0.0)

            # When performing key-value caching, we compute the attention scores
            # only for the new sequence. Thus, the matrix of scores is of size
            # (seqlen, cache_len + seqlen), and the only masked entries are (i, j) for
            # j > cache_len + i, since row i corresponds to token cache_len + i.
            mask = torch.hstack([torch.zeros((seqlen, start_pos), device=tokens.device), mask]).type_as(h)

        for layer in self.layers:
            h = layer(h, start_pos, freqs_cis, mask)
        h = self.norm(h)
        output = self.output(h).float()
        return output
|
|
@@ -1,12 +0,0 @@
|
|
|
1
|
-
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
|
2
|
-
# All rights reserved.
|
|
3
|
-
#
|
|
4
|
-
# This source code is licensed under the terms described in the LICENSE file in
|
|
5
|
-
# the root directory of this source tree.
|
|
6
|
-
|
|
7
|
-
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
|
8
|
-
# All rights reserved.
|
|
9
|
-
#
|
|
10
|
-
# This source code is licensed under the terms described in the LICENSE file in
|
|
11
|
-
# top-level folder for each specific model found within the models/ directory at
|
|
12
|
-
# the top-level of this source tree.
|
|
@@ -1,180 +0,0 @@
|
|
|
1
|
-
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
|
2
|
-
# All rights reserved.
|
|
3
|
-
#
|
|
4
|
-
# This source code is licensed under the terms described in the LICENSE file in
|
|
5
|
-
# the root directory of this source tree.
|
|
6
|
-
|
|
7
|
-
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
|
8
|
-
# All rights reserved.
|
|
9
|
-
#
|
|
10
|
-
# This source code is licensed under the terms described in the LICENSE file in
|
|
11
|
-
# top-level folder for each specific model found within the models/ directory at
|
|
12
|
-
# the top-level of this source tree.
|
|
13
|
-
|
|
14
|
-
# Copyright (c) Meta Platforms, Inc. and its affiliates.
|
|
15
|
-
import math
|
|
16
|
-
|
|
17
|
-
import torch
|
|
18
|
-
import torch.nn.functional as F
|
|
19
|
-
|
|
20
|
-
from llama_stack.log import get_logger
|
|
21
|
-
|
|
22
|
-
from .utils import get_negative_inf_value, to_2tuple
|
|
23
|
-
|
|
24
|
-
# Module-level logger (category "models::llama"); used e.g. when resizing position embeddings.
logger = get_logger(name=__name__, category="models::llama")
|
|
25
|
-
|
|
26
|
-
|
|
27
|
-
def resize_local_position_embedding(orig_pos_embed, grid_size):
    """
    Resize a vision-encoder position embedding to a new patch grid.

    ``orig_pos_embed`` holds one leading class-token row followed by a square
    ``n x n`` grid of patch rows, i.e. shape ``[n * n + 1, dim]``. The patch part
    is bilinearly interpolated (``align_corners=True``) to ``grid_size``; the
    class row is kept unchanged.

    Args:
        orig_pos_embed: tensor of shape ``[n * n + 1, dim]``.
        grid_size: target grid, normalized with ``to_2tuple``.

    Returns:
        Tensor of shape ``[grid_size[0] * grid_size[1] + 1, dim]``.
    """
    new_grid_size = to_2tuple(grid_size)
    orig_grid_size = to_2tuple(int(math.sqrt(len(orig_pos_embed) - 1)))

    cls_row = orig_pos_embed[:1]
    patch_rows = orig_pos_embed[1:]
    logger.info(f"resizing position embedding grid-size from {orig_grid_size} to {new_grid_size}")

    # [n*n, dim] -> [1, dim, n, n]: a single NCHW "image" for F.interpolate.
    patch_grid = patch_rows.reshape(1, orig_grid_size[0], orig_grid_size[1], -1)
    patch_grid = patch_grid.permute(0, 3, 1, 2)
    patch_grid = F.interpolate(
        patch_grid,
        size=new_grid_size,
        mode="bilinear",
        align_corners=True,
    )

    # [1, dim, rows, cols] -> [rows * cols, dim]
    num_patches = new_grid_size[0] * new_grid_size[1]
    patch_rows = patch_grid.permute(0, 2, 3, 1).reshape(1, num_patches, -1)[0]
    return torch.cat([cls_row, patch_rows], dim=0)
|
|
53
|
-
|
|
54
|
-
|
|
55
|
-
def initialize_global_position_embedding_from_local(pos_and_cls_embed, grid_size, x_scale, y_scale):
    """
    Build a per-tile (global) position embedding from a local one.

    The patch part of ``pos_and_cls_embed`` (shape ``[grid_size[0] * grid_size[1] + 1, dim]``,
    class token first) is bilinearly upsampled to cover an ``x_scale x y_scale``
    arrangement of tiles, then cut back into one grid per tile. The class token is
    copied to every tile. ``x_scale`` and ``y_scale`` are the number of tiles along
    the x- and y-axis.

    Returns:
        Tensor of shape ``[x_scale, y_scale, grid_size[0] * grid_size[1] + 1, dim]``.
    """
    patch_part = pos_and_cls_embed[1:]
    cls_part = pos_and_cls_embed[0].view(1, 1, 1, -1)
    grid_size = to_2tuple(grid_size)
    rows, cols = grid_size[0], grid_size[1]

    # [rows*cols, dim] -> [1, dim, rows, cols] for F.interpolate.
    canvas = patch_part.reshape(1, rows, cols, -1).permute(0, 3, 1, 2)
    canvas = F.interpolate(
        canvas,
        size=(x_scale * rows, y_scale * cols),
        mode="bilinear",
        align_corners=True,
    )

    # [1, dim, x_scale*rows, y_scale*cols] -> [x_scale, y_scale, rows*cols, dim]:
    # split each spatial axis into (tile index, position within tile).
    tiles = canvas.permute(0, 2, 3, 1)
    tiles = tiles.view(x_scale, rows, y_scale, cols, -1)
    tiles = tiles.permute(0, 2, 1, 3, 4).contiguous()
    tiles = tiles.reshape(x_scale, y_scale, rows * cols, -1)

    cls_tiles = cls_part.expand(x_scale, y_scale, -1, -1)
    return torch.cat([cls_tiles, tiles], dim=2)
|
|
81
|
-
|
|
82
|
-
|
|
83
|
-
def resize_global_position_embedding(pos_and_cls_embed, grid_size, x_scale, y_scale):
    """
    Takes a global position embedding for vision encoder and resizes it to new size.
    Input: global position embedding of shape [x_old, y_old, old_grid_size[0] * old_grid_size[1] + 1, dim]
    Returns: global position embedding of shape [x_scale, y_scale, grid_size[0] * grid_size[1] + 1, dim]
    Here x_scale and y_scale are the number of tiles along x-axis and y-axis respectively.

    Args:
        pos_and_cls_embed: global embedding with the class token at index 0 of axis 2;
            the remaining entries must form a square per-tile grid.
        grid_size: target per-tile grid, either ``(rows, cols)`` or a single int for a
            square grid.
        x_scale: number of tiles along the x-axis in the result.
        y_scale: number of tiles along the y-axis in the result.

    The patch embeddings of all old tiles are stitched into one canvas and
    bilinearly interpolated (``align_corners=True``) to the new size. The class
    tokens are interpolated separately across the tile axes.
    """
    # Accept a plain int for a square grid. The sibling helpers normalize grid_size
    # via to_2tuple; previously an int here failed on grid_size[0].
    if isinstance(grid_size, int):
        grid_size = (grid_size, grid_size)

    # first remove cls token
    pos_embed = pos_and_cls_embed[:, :, 1:]
    cls_embed = pos_and_cls_embed[:, :, 0].unsqueeze(2)

    xs_old, ys_old, ntok, dim = pos_embed.shape
    old_grid_size = int(math.sqrt(ntok))

    # move to correct form for interpolation: one [xs_old*g, ys_old*g, dim] canvas
    pos_embed = pos_embed.view(xs_old, ys_old, old_grid_size, old_grid_size, dim)
    pos_embed = pos_embed.permute(0, 2, 1, 3, 4).contiguous()
    pos_embed = pos_embed.view(xs_old * old_grid_size, ys_old * old_grid_size, dim)
    pos_embed = pos_embed.unsqueeze(0)

    # interpolate
    new_size = (grid_size[0] * x_scale, grid_size[1] * y_scale)
    pos_embed = pos_embed.permute(0, 3, 1, 2)
    pos_embed_resized = F.interpolate(
        pos_embed,
        size=new_size,
        mode="bilinear",
        align_corners=True,
    )
    pos_embed = pos_embed_resized.permute(0, 2, 3, 1)[0]

    # move it back in place: split the canvas into per-tile grids
    pos_embed = pos_embed.view(x_scale, grid_size[0], y_scale, grid_size[1], dim)
    pos_embed = pos_embed.permute(0, 2, 1, 3, 4).contiguous()
    pos_embed = pos_embed.view(x_scale, y_scale, grid_size[0] * grid_size[1], dim)

    # interpolate cls token across the tile axes
    cls_embed = cls_embed.permute(2, 3, 0, 1)
    cls_embed_resized = F.interpolate(
        cls_embed,
        size=(x_scale, y_scale),
        mode="bilinear",
        align_corners=True,
    )
    cls_embed = cls_embed_resized.permute(2, 3, 0, 1)
    # add cls token back in
    pos_and_cls_embed = torch.cat([cls_embed, pos_embed], dim=2)

    return pos_and_cls_embed
|
|
132
|
-
|
|
133
|
-
|
|
134
|
-
def build_encoder_attention_mask(
    x: torch.Tensor,
    ar: torch.Tensor,
    ntok: int,
    num_chunks: int,
    n_heads: int,
):
    """
    Build vision encoder attention mask that omits padding tokens.

    Args:
        x: encoder activations. Only ``x.shape[2]`` (token positions per chunk,
            possibly including padding), ``x.dtype`` and ``x.device`` are read.
            Presumably shaped (bsz, num_chunks, tokens, dim) — TODO confirm with caller.
        ar: iterable of per-image rows ``arx``; ``arx[0] * arx[1]`` is taken as the
            number of valid chunks for that image.
        ntok: number of real (non-padding) token positions per chunk.
        num_chunks: number of chunk slots per image.
        n_heads: number of heads the mask is expanded to.

    Returns:
        Tensor of shape ``(len(ar), n_heads, num_chunks * x.shape[2], num_chunks * x.shape[2])``
        on ``x.device``. The head axis is an ``expand`` view, not a copy. Entries are
        0 where attention is allowed and ``get_negative_inf_value(x.dtype)`` where it
        is masked (see the NOTE on finiteness below).
    """
    masks_list: list[torch.Tensor] = []
    for arx in ar:
        # 1 marks a position to mask: padding tokens of valid chunks, and every token of unused chunks.
        mask_i = torch.ones((num_chunks, x.shape[2], 1), dtype=x.dtype)
        mask_i[: arx[0] * arx[1], :ntok] = 0
        mask_i = mask_i.view(num_chunks * x.shape[2], -1)
        # Outer product of a 0/1 column: entry (i, j) is 1 only when BOTH position i and
        # position j are marked.
        # NOTE(review): real tokens are therefore not blocked from attending to padding keys,
        # and padding queries can still attend to real keys. Confirm this is intended before
        # changing it — pretrained weights may depend on it.
        # NOTE(review): unmasked entries are 0 * get_negative_inf_value(...). That is 0 only if
        # the value is finite (0 * -inf would be NaN); presumably it returns a large finite
        # negative — confirm.
        mask_i = mask_i @ mask_i.T * get_negative_inf_value(x.dtype)
        mask_i = mask_i.unsqueeze(0)
        masks_list.append(mask_i)
    # Masks are built on CPU per image, then moved to x's device and broadcast across heads.
    masks = torch.stack(masks_list).to(x.device).expand(-1, n_heads, -1, -1)
    return masks
|
|
154
|
-
|
|
155
|
-
|
|
156
|
-
def expand_num_tokens_to_mult8(x):
    """Zero-pad the token axis (dim -2) of ``x`` up to the next multiple of 8.

    The padding tensor is built as
    ``(x.shape[0], x.shape[1], num_pad_tokens, x.shape[-1])``, so ``x`` is
    expected to be 4-D.

    Returns:
        ``(padded, num_pad_tokens)``. When ``x.shape[-2]`` is already a multiple
        of 8, ``x`` itself is returned with ``num_pad_tokens == 0``. Undo the
        padding with ``contract_num_tokens_from_mult8(padded, num_pad_tokens)``.
    """
    # The outer % 8 maps already-aligned lengths to 0 padding. "8 - n % 8" alone
    # gives 8, which made the no-op branch unreachable and appended 8 zero tokens.
    num_pad_tokens = (8 - x.shape[-2] % 8) % 8
    if num_pad_tokens == 0:
        return x, 0
    padding = torch.zeros(
        (x.shape[0], x.shape[1], num_pad_tokens, x.shape[-1]),
        dtype=x.dtype,
        device=x.device,
    )
    return torch.cat([x, padding], dim=-2), num_pad_tokens
|
|
175
|
-
|
|
176
|
-
|
|
177
|
-
def contract_num_tokens_from_mult8(x, num_pad_tokens):
    """Drop the last ``num_pad_tokens`` positions along axis 2 (inverse of the mult-of-8 padding).

    With ``num_pad_tokens == 0`` the input is returned unchanged. A ``:-0`` slice
    would instead produce an empty tensor, hence the explicit check.
    """
    if num_pad_tokens != 0:
        return x[:, :, :-num_pad_tokens]
    return x
|