llama-stack 0.4.4__py3-none-any.whl → 0.5.0rc1__py3-none-any.whl

This diff shows the changes between two publicly available package versions released to one of the supported registries. It is provided for informational purposes only and reflects the packages as they appear in their respective public registries.
Files changed (155)
  1. llama_stack/cli/stack/_list_deps.py +11 -7
  2. llama_stack/cli/stack/run.py +3 -25
  3. llama_stack/core/access_control/datatypes.py +78 -0
  4. llama_stack/core/configure.py +2 -2
  5. llama_stack/{distributions/meta-reference-gpu → core/connectors}/__init__.py +3 -1
  6. llama_stack/core/connectors/connectors.py +162 -0
  7. llama_stack/core/conversations/conversations.py +61 -58
  8. llama_stack/core/datatypes.py +54 -8
  9. llama_stack/core/library_client.py +60 -13
  10. llama_stack/core/prompts/prompts.py +43 -42
  11. llama_stack/core/routers/datasets.py +20 -17
  12. llama_stack/core/routers/eval_scoring.py +143 -53
  13. llama_stack/core/routers/inference.py +20 -9
  14. llama_stack/core/routers/safety.py +30 -42
  15. llama_stack/core/routers/vector_io.py +15 -7
  16. llama_stack/core/routing_tables/models.py +42 -3
  17. llama_stack/core/routing_tables/scoring_functions.py +19 -19
  18. llama_stack/core/routing_tables/shields.py +20 -17
  19. llama_stack/core/routing_tables/vector_stores.py +8 -5
  20. llama_stack/core/server/auth.py +192 -17
  21. llama_stack/core/server/fastapi_router_registry.py +40 -5
  22. llama_stack/core/server/server.py +24 -5
  23. llama_stack/core/stack.py +54 -10
  24. llama_stack/core/storage/datatypes.py +9 -0
  25. llama_stack/core/store/registry.py +1 -1
  26. llama_stack/core/utils/exec.py +2 -2
  27. llama_stack/core/utils/type_inspection.py +16 -2
  28. llama_stack/distributions/dell/config.yaml +4 -1
  29. llama_stack/distributions/dell/run-with-safety.yaml +4 -1
  30. llama_stack/distributions/nvidia/config.yaml +4 -1
  31. llama_stack/distributions/nvidia/run-with-safety.yaml +4 -1
  32. llama_stack/distributions/oci/config.yaml +4 -1
  33. llama_stack/distributions/open-benchmark/config.yaml +9 -1
  34. llama_stack/distributions/postgres-demo/config.yaml +1 -1
  35. llama_stack/distributions/starter/build.yaml +62 -0
  36. llama_stack/distributions/starter/config.yaml +22 -3
  37. llama_stack/distributions/starter/run-with-postgres-store.yaml +22 -3
  38. llama_stack/distributions/starter/starter.py +13 -1
  39. llama_stack/distributions/starter-gpu/build.yaml +62 -0
  40. llama_stack/distributions/starter-gpu/config.yaml +22 -3
  41. llama_stack/distributions/starter-gpu/run-with-postgres-store.yaml +22 -3
  42. llama_stack/distributions/template.py +10 -2
  43. llama_stack/distributions/watsonx/config.yaml +4 -1
  44. llama_stack/log.py +1 -0
  45. llama_stack/providers/inline/agents/meta_reference/__init__.py +1 -0
  46. llama_stack/providers/inline/agents/meta_reference/agents.py +57 -61
  47. llama_stack/providers/inline/agents/meta_reference/responses/openai_responses.py +49 -51
  48. llama_stack/providers/inline/agents/meta_reference/responses/streaming.py +94 -22
  49. llama_stack/providers/inline/agents/meta_reference/responses/types.py +2 -1
  50. llama_stack/providers/inline/agents/meta_reference/responses/utils.py +4 -1
  51. llama_stack/providers/inline/agents/meta_reference/safety.py +2 -2
  52. llama_stack/providers/inline/batches/reference/batches.py +2 -1
  53. llama_stack/providers/inline/eval/meta_reference/eval.py +40 -32
  54. llama_stack/providers/inline/post_training/huggingface/post_training.py +33 -38
  55. llama_stack/providers/inline/post_training/huggingface/utils.py +2 -5
  56. llama_stack/providers/inline/post_training/torchtune/post_training.py +28 -33
  57. llama_stack/providers/inline/post_training/torchtune/recipes/lora_finetuning_single_device.py +2 -4
  58. llama_stack/providers/inline/safety/code_scanner/code_scanner.py +12 -15
  59. llama_stack/providers/inline/safety/llama_guard/llama_guard.py +15 -18
  60. llama_stack/providers/inline/safety/prompt_guard/prompt_guard.py +11 -17
  61. llama_stack/providers/inline/scoring/basic/scoring.py +13 -17
  62. llama_stack/providers/inline/scoring/braintrust/braintrust.py +15 -15
  63. llama_stack/providers/inline/scoring/llm_as_judge/scoring.py +13 -17
  64. llama_stack/providers/inline/vector_io/sqlite_vec/sqlite_vec.py +1 -1
  65. llama_stack/providers/registry/agents.py +1 -0
  66. llama_stack/providers/registry/inference.py +1 -9
  67. llama_stack/providers/registry/vector_io.py +136 -16
  68. llama_stack/providers/remote/eval/nvidia/eval.py +22 -21
  69. llama_stack/providers/remote/files/s3/config.py +5 -3
  70. llama_stack/providers/remote/files/s3/files.py +2 -2
  71. llama_stack/providers/remote/inference/gemini/gemini.py +4 -0
  72. llama_stack/providers/remote/inference/openai/openai.py +2 -0
  73. llama_stack/providers/remote/inference/together/together.py +4 -0
  74. llama_stack/providers/remote/inference/vertexai/config.py +3 -3
  75. llama_stack/providers/remote/inference/vertexai/vertexai.py +5 -2
  76. llama_stack/providers/remote/inference/vllm/config.py +37 -18
  77. llama_stack/providers/remote/inference/vllm/vllm.py +0 -3
  78. llama_stack/providers/remote/inference/watsonx/watsonx.py +4 -0
  79. llama_stack/providers/remote/post_training/nvidia/post_training.py +31 -33
  80. llama_stack/providers/remote/safety/bedrock/bedrock.py +10 -27
  81. llama_stack/providers/remote/safety/nvidia/nvidia.py +9 -25
  82. llama_stack/providers/remote/safety/sambanova/sambanova.py +13 -11
  83. llama_stack/providers/remote/vector_io/elasticsearch/__init__.py +17 -0
  84. llama_stack/providers/remote/vector_io/elasticsearch/config.py +32 -0
  85. llama_stack/providers/remote/vector_io/elasticsearch/elasticsearch.py +463 -0
  86. llama_stack/providers/remote/vector_io/oci/__init__.py +22 -0
  87. llama_stack/providers/remote/vector_io/oci/config.py +41 -0
  88. llama_stack/providers/remote/vector_io/oci/oci26ai.py +595 -0
  89. llama_stack/providers/remote/vector_io/pgvector/config.py +69 -2
  90. llama_stack/providers/remote/vector_io/pgvector/pgvector.py +255 -6
  91. llama_stack/providers/remote/vector_io/qdrant/qdrant.py +62 -38
  92. llama_stack/providers/utils/bedrock/client.py +3 -3
  93. llama_stack/providers/utils/bedrock/config.py +7 -7
  94. llama_stack/providers/utils/inference/embedding_mixin.py +4 -0
  95. llama_stack/providers/utils/inference/http_client.py +239 -0
  96. llama_stack/providers/utils/inference/litellm_openai_mixin.py +5 -0
  97. llama_stack/providers/utils/inference/model_registry.py +148 -2
  98. llama_stack/providers/utils/inference/openai_compat.py +2 -1
  99. llama_stack/providers/utils/inference/openai_mixin.py +41 -2
  100. llama_stack/providers/utils/memory/openai_vector_store_mixin.py +92 -5
  101. llama_stack/providers/utils/memory/vector_store.py +46 -19
  102. llama_stack/providers/utils/responses/responses_store.py +7 -7
  103. llama_stack/providers/utils/safety.py +114 -0
  104. llama_stack/providers/utils/tools/mcp.py +44 -3
  105. llama_stack/testing/api_recorder.py +9 -3
  106. {llama_stack-0.4.4.dist-info → llama_stack-0.5.0rc1.dist-info}/METADATA +14 -2
  107. {llama_stack-0.4.4.dist-info → llama_stack-0.5.0rc1.dist-info}/RECORD +111 -144
  108. llama_stack/distributions/meta-reference-gpu/config.yaml +0 -140
  109. llama_stack/distributions/meta-reference-gpu/doc_template.md +0 -119
  110. llama_stack/distributions/meta-reference-gpu/meta_reference.py +0 -163
  111. llama_stack/distributions/meta-reference-gpu/run-with-safety.yaml +0 -155
  112. llama_stack/models/llama/hadamard_utils.py +0 -88
  113. llama_stack/models/llama/llama3/args.py +0 -74
  114. llama_stack/models/llama/llama3/dog.jpg +0 -0
  115. llama_stack/models/llama/llama3/generation.py +0 -378
  116. llama_stack/models/llama/llama3/model.py +0 -304
  117. llama_stack/models/llama/llama3/multimodal/__init__.py +0 -12
  118. llama_stack/models/llama/llama3/multimodal/encoder_utils.py +0 -180
  119. llama_stack/models/llama/llama3/multimodal/image_transform.py +0 -409
  120. llama_stack/models/llama/llama3/multimodal/model.py +0 -1430
  121. llama_stack/models/llama/llama3/multimodal/utils.py +0 -26
  122. llama_stack/models/llama/llama3/pasta.jpeg +0 -0
  123. llama_stack/models/llama/llama3/quantization/__init__.py +0 -5
  124. llama_stack/models/llama/llama3/quantization/loader.py +0 -316
  125. llama_stack/models/llama/llama3_1/__init__.py +0 -12
  126. llama_stack/models/llama/llama3_1/prompt_format.md +0 -358
  127. llama_stack/models/llama/llama3_1/prompts.py +0 -258
  128. llama_stack/models/llama/llama3_2/__init__.py +0 -5
  129. llama_stack/models/llama/llama3_2/prompts_text.py +0 -229
  130. llama_stack/models/llama/llama3_2/prompts_vision.py +0 -126
  131. llama_stack/models/llama/llama3_2/text_prompt_format.md +0 -286
  132. llama_stack/models/llama/llama3_2/vision_prompt_format.md +0 -141
  133. llama_stack/models/llama/llama3_3/__init__.py +0 -5
  134. llama_stack/models/llama/llama3_3/prompts.py +0 -259
  135. llama_stack/models/llama/llama4/args.py +0 -107
  136. llama_stack/models/llama/llama4/ffn.py +0 -58
  137. llama_stack/models/llama/llama4/moe.py +0 -214
  138. llama_stack/models/llama/llama4/preprocess.py +0 -435
  139. llama_stack/models/llama/llama4/quantization/__init__.py +0 -5
  140. llama_stack/models/llama/llama4/quantization/loader.py +0 -226
  141. llama_stack/models/llama/llama4/vision/__init__.py +0 -5
  142. llama_stack/models/llama/llama4/vision/embedding.py +0 -210
  143. llama_stack/models/llama/llama4/vision/encoder.py +0 -412
  144. llama_stack/models/llama/quantize_impls.py +0 -316
  145. llama_stack/providers/inline/inference/meta_reference/__init__.py +0 -20
  146. llama_stack/providers/inline/inference/meta_reference/common.py +0 -24
  147. llama_stack/providers/inline/inference/meta_reference/config.py +0 -68
  148. llama_stack/providers/inline/inference/meta_reference/generators.py +0 -201
  149. llama_stack/providers/inline/inference/meta_reference/inference.py +0 -542
  150. llama_stack/providers/inline/inference/meta_reference/model_parallel.py +0 -77
  151. llama_stack/providers/inline/inference/meta_reference/parallel_utils.py +0 -353
  152. {llama_stack-0.4.4.dist-info → llama_stack-0.5.0rc1.dist-info}/WHEEL +0 -0
  153. {llama_stack-0.4.4.dist-info → llama_stack-0.5.0rc1.dist-info}/entry_points.txt +0 -0
  154. {llama_stack-0.4.4.dist-info → llama_stack-0.5.0rc1.dist-info}/licenses/LICENSE +0 -0
  155. {llama_stack-0.4.4.dist-info → llama_stack-0.5.0rc1.dist-info}/top_level.txt +0 -0
llama_stack/models/llama/llama3/args.py
@@ -1,74 +0,0 @@
- # Copyright (c) Meta Platforms, Inc. and affiliates.
- # All rights reserved.
- #
- # This source code is licensed under the terms described in the LICENSE file in
- # the root directory of this source tree.
-
- from dataclasses import dataclass
- from enum import Enum
-
-
- class QuantizationScheme(Enum):
-     int4_weight_int8_dynamic_activation = "int4_weight_int8_dynamic_activation"
-
-
- @dataclass
- class QuantizationArgs:
-     scheme: QuantizationScheme | None = None
-     group_size: int | None = None
-     spinquant: bool = False
-
-     def __init__(self, **kwargs):
-         for k, v in kwargs.items():
-             if k == "scheme":
-                 setattr(self, k, QuantizationScheme(v))
-             else:
-                 if hasattr(self, k):
-                     setattr(self, k, v)
-
-
- @dataclass
- class LoRAArgs:
-     rank: int
-     scale: float
-
-
- @dataclass
- class ModelArgs:
-     dim: int = 4096
-     n_layers: int = 32
-     n_heads: int = 32
-     n_kv_heads: int | None = None
-     vocab_size: int = -1
-     multiple_of: int = 256  # make SwiGLU hidden layer size multiple of large power of 2
-     ffn_dim_multiplier: float | None = None
-     norm_eps: float = 1e-5
-     rope_theta: float = 500000
-     use_scaled_rope: bool = False
-
-     max_batch_size: int = 32
-     max_seq_len: int = 2048
-
-     # vision model params
-     vision_chunk_size: int = -1  # image resolution for image models
-     vision_max_num_chunks: int = 4
-     vision_num_cross_attention_layers: int = -1
-
-     quantization_args: QuantizationArgs | None = None
-     lora_args: LoRAArgs | None = None
-
-     def __init__(self, **kwargs):
-         for k, v in kwargs.items():
-             if k == "lora_args":
-                 setattr(self, k, LoRAArgs(**v))
-             elif k == "quantization_args":
-                 setattr(self, k, QuantizationArgs(**v))
-             else:
-                 if hasattr(self, k):
-                     setattr(self, k, v)
-
-         if self.n_kv_heads is None:
-             self.n_kv_heads = self.n_heads
-         assert self.n_kv_heads <= self.n_heads
-         assert self.n_heads % self.n_kv_heads == 0
-         assert self.dim % self.n_heads == 0
llama_stack/models/llama/llama3/dog.jpg
Binary file
llama_stack/models/llama/llama3/generation.py
@@ -1,378 +0,0 @@
- # Copyright (c) Meta Platforms, Inc. and affiliates.
- # All rights reserved.
- #
- # This source code is licensed under the terms described in the LICENSE file in
- # the root directory of this source tree.
-
- # Copyright (c) Meta Platforms, Inc. and affiliates.
- # All rights reserved.
- #
- # This source code is licensed under the terms described in the LICENSE file in
- # top-level folder for each specific model found within the models/ directory at
- # the top-level of this source tree.
-
- import json
- import os
- import sys
- import time
- from collections.abc import Callable, Generator
- from pathlib import Path
-
- import torch
- import torch.nn.functional as F
- from fairscale.nn.model_parallel.initialize import (
-     initialize_model_parallel,
-     model_parallel_is_initialized,
- )
- from termcolor import cprint
-
- from llama_stack.models.llama.datatypes import ToolPromptFormat
-
- from ..checkpoint import maybe_reshard_state_dict
- from ..datatypes import GenerationResult, QuantizationMode, RawContent, RawMessage
- from .args import ModelArgs
- from .chat_format import ChatFormat, LLMInput
- from .model import Transformer
- from .multimodal.model import CrossAttentionTransformer
- from .tokenizer import Tokenizer
-
-
- class Llama3:
-     @staticmethod
-     def build(
-         ckpt_dir: str,
-         max_seq_len: int,
-         max_batch_size: int,
-         world_size: int | None = None,
-         quantization_mode: QuantizationMode | None = None,
-         seed: int = 1,
-         device: str = "cuda",
-     ):
-         device = torch.device(device)
-         if (
-             device.type == "cuda"
-             and not torch.cuda.is_available()
-             or device.type == "xpu"
-             and not torch.xpu.is_available()
-         ):
-             raise RuntimeError(f"PyTorch backend for {device.type} device type is not available")
-
-         if not torch.distributed.is_initialized():
-             if device.type == "cuda":
-                 torch.distributed.init_process_group("nccl")
-             else:
-                 torch.distributed.init_process_group("gloo")
-
-         if not model_parallel_is_initialized():
-             if world_size is None:
-                 world_size = int(os.environ.get("WORLD_SIZE", 1))
-             initialize_model_parallel(world_size)
-
-         local_rank = int(os.environ.get("LOCAL_RANK", 0))
-         if device.type == "cuda":
-             torch.cuda.set_device(local_rank)
-         elif device.type == "xpu":
-             torch.xpu.set_device(local_rank)
-
-         torch.manual_seed(seed)
-
-         if local_rank > 0:
-             sys.stdout = open(os.devnull, "w")
-
-         start_time = time.time()
-
-         ckpt_paths = sorted(Path(ckpt_dir).glob("*.pth"))
-         assert len(ckpt_paths) > 0, f"no checkpoint files found in {ckpt_dir}"
-         print(f"Loading a checkpoint (shards={len(ckpt_paths)}, current-mp-size={world_size})")
-         with open(Path(ckpt_dir) / "params.json") as f:
-             params = json.loads(f.read())
-
-         model_args: ModelArgs = ModelArgs(
-             max_seq_len=max_seq_len,
-             max_batch_size=max_batch_size,
-             **params,
-         )
-         tokenizer = Tokenizer.get_instance()
-
-         state_dict = maybe_reshard_state_dict(
-             ckpt_paths,
-             n_kv_heads=model_args.n_kv_heads if model_args.n_kv_heads else model_args.n_heads,
-         )
-
-         assert model_args.vocab_size == tokenizer.n_words
-
-         def build_model():
-             if model_args.vision_chunk_size > 0:
-                 model = CrossAttentionTransformer(model_args)
-                 model.setup_cache(model_args.max_batch_size, device=device, dtype=torch.get_default_dtype())
-             else:
-                 model = Transformer(model_args)
-             return model
-
-         if quantization_mode == QuantizationMode.fp8_mixed or quantization_mode == QuantizationMode.int4_mixed:
-             from .quantization.loader import convert_to_quantized_model
-
-             torch.set_default_tensor_type(torch.BFloat16Tensor)
-             model = build_model()
-             print("Loading state dict...")
-             model.load_state_dict(state_dict, strict=False)
-             print("Done...")
-             model = convert_to_quantized_model(model, ckpt_dir, quantization_mode, device=device)
-             torch.set_default_device(device)
-         else:
-             print(f"Setting default device to {device}")
-             if device.type == "cuda":
-                 if torch.cuda.is_bf16_supported():
-                     torch.set_default_tensor_type(torch.cuda.BFloat16Tensor)
-                 else:
-                     torch.set_default_tensor_type(torch.cuda.Float16Tensor)
-             elif device.type == "xpu":
-                 if torch.xpu.is_bf16_supported():
-                     torch.set_default_tensor_type(torch.xpu.BFloat16Tensor)
-                 else:
-                     torch.set_default_tensor_type(torch.xpu.Float16Tensor)
-
-             model = build_model()
-             print("Loading state dict...")
-             model.load_state_dict(state_dict, strict=True)
-             model.to(device)
-             print("Done...")
-
-         print(f"Loaded in {time.time() - start_time:.2f} seconds")
-
-         return Llama3(model, tokenizer, model_args)
-
-     def __init__(
-         self,
-         model: Transformer | CrossAttentionTransformer,
-         tokenizer: Tokenizer,
-         args: ModelArgs,
-     ):
-         self.args = args
-         self.model = model
-         self.tokenizer = tokenizer
-         self.formatter = ChatFormat(tokenizer)
-
-     @torch.inference_mode()
-     def generate(
-         self,
-         llm_inputs: list[LLMInput],
-         temperature: float = 0.6,
-         top_p: float = 0.9,
-         max_gen_len: int | None = None,
-         logprobs: bool = False,
-         echo: bool = False,
-         print_model_input: bool = False,
-         logits_processor: Callable[[torch.Tensor, torch.Tensor], torch.Tensor] | None = None,
-     ) -> Generator[list[GenerationResult], None, None]:
-         if max_gen_len is None or max_gen_len == 0 or max_gen_len >= self.args.max_seq_len:
-             max_gen_len = self.args.max_seq_len - 1
-         params = self.model.params
-
-         print_model_input = print_model_input or os.environ.get("LLAMA_MODELS_DEBUG", "0") == "1"
-         if print_model_input:
-             for inp in llm_inputs:
-                 tokens_to_print = [self.formatter.vision_token if t == 128256 else t for t in inp.tokens]
-                 cprint(
-                     "Input to model:\n" + self.tokenizer.decode(tokens_to_print) + "\n",
-                     "red",
-                     file=sys.stderr,
-                 )
-         prompt_tokens = [inp.tokens for inp in llm_inputs]
-
-         bsz = len(llm_inputs)
-         assert bsz <= params.max_batch_size, (bsz, params.max_batch_size)
-
-         min_prompt_len = min(len(t) for t in prompt_tokens)
-         max_prompt_len = max(len(t) for t in prompt_tokens)
-
-         if max_prompt_len >= params.max_seq_len:
-             cprint(
-                 f"Out of token budget {max_prompt_len} vs {params.max_seq_len}",
-                 color="red",
-                 file=sys.stderr,
-             )
-             return
-
-         total_len = min(max_gen_len + max_prompt_len, params.max_seq_len)
-
-         pad_id = self.tokenizer.pad_id
-         tokens = torch.full((bsz, total_len), pad_id, dtype=torch.long)
-         for k, t in enumerate(prompt_tokens):
-             tokens[k, : len(t)] = torch.tensor(t, dtype=torch.long)
-         if logprobs:
-             token_logprobs = torch.zeros_like(tokens, dtype=torch.float)
-
-         is_vision = not isinstance(self.model, Transformer)
-         if is_vision:
-             images = [inp.vision.images if inp.vision is not None else [] for inp in llm_inputs]
-             mask = [inp.vision.mask if inp.vision is not None else [] for inp in llm_inputs]
-
-             xattn_caches, cross_attention_masks, full_text_row_masked_out_mask = self.model.compute_vision_tokens_masks(
-                 batch_images=images,
-                 batch_masks=mask,
-                 total_len=total_len,
-                 device=tokens.device,
-             )
-
-         eos_reached = torch.tensor([False] * bsz)
-         input_text_mask = tokens != pad_id
-
-         if echo:
-             for i in range(max_prompt_len):
-                 results = []
-                 for j, t in enumerate(tokens[:, i]):
-                     results.append(
-                         GenerationResult(
-                             token=t.item(),
-                             text=self.tokenizer.decode([t.item()]),
-                             source="input",
-                             logprobs=(token_logprobs[j, i : i + 1].tolist() if logprobs else None),
-                             batch_idx=j,
-                             finished=False,
-                             ignore_token=t.item() == pad_id,
-                         )
-                     )
-                 yield results
-
-         stop_tokens = torch.tensor(self.tokenizer.stop_tokens)
-
-         prev_pos = 0
-         for cur_pos in range(min_prompt_len, total_len):
-             if is_vision:
-                 position_ids = torch.arange(prev_pos, cur_pos, dtype=torch.long)
-                 text_only_inference = all(inp.vision is None for inp in llm_inputs)
-                 logits = self.model.forward(
-                     position_ids,
-                     tokens,
-                     cross_attention_masks,
-                     full_text_row_masked_out_mask,
-                     xattn_caches,
-                     text_only_inference,
-                 )
-             else:
-                 logits = self.model.forward(tokens[:, prev_pos:cur_pos], prev_pos)
-
-             if logits_processor is not None:
-                 logits = logits_processor(tokens[:, :cur_pos], logits)
-
-             if temperature > 0:
-                 probs = torch.softmax(logits[:, -1] / temperature, dim=-1)
-                 next_token = sample_top_p(probs, top_p)
-             else:
-                 next_token = torch.argmax(logits[:, -1], dim=-1)
-
-             next_token = next_token.reshape(-1)
-             # only replace token if prompt has already been generated
-             next_token = torch.where(input_text_mask[:, cur_pos], tokens[:, cur_pos], next_token)
-             tokens[:, cur_pos] = next_token
-
-             target = tokens[:, prev_pos + 1 : cur_pos + 1]
-             if is_vision:
-                 # the logits space (num_classes) is designed to never contain a media_token
-                 # however our input token stream does contain them. we need to nuke them here
-                 # or else the CUDA kernels will crash with an illegal memory access
-                 vision_tokens = [self.tokenizer.special_tokens["<|image|>"], 128256]
-                 masks = [target.eq(t) for t in vision_tokens]
-                 if len(masks) > 1:
-                     mask = torch.logical_or(*masks)
-                 else:
-                     mask = masks[0]
-                 target[mask] = 0
-
-             if logprobs:
-                 token_logprobs[:, prev_pos + 1 : cur_pos + 1] = -F.cross_entropy(
-                     input=logits.transpose(1, 2),
-                     target=target,
-                     reduction="none",
-                     ignore_index=pad_id,
-                 )
-             eos_reached |= (~input_text_mask[:, cur_pos]) & (torch.isin(next_token, stop_tokens))
-             results = []
-             for idx, t in enumerate(next_token):
-                 results.append(
-                     GenerationResult(
-                         token=t.item(),
-                         text=self.tokenizer.decode([t.item()]),
-                         source="output",
-                         logprobs=(token_logprobs[idx, cur_pos : cur_pos + 1].tolist() if logprobs else None),
-                         batch_idx=idx,
-                         finished=eos_reached[idx].item(),
-                         ignore_token=cur_pos < len(prompt_tokens[idx]),
-                     )
-                 )
-             yield results
-
-             prev_pos = cur_pos
-             if all(eos_reached):
-                 break
-
-     def completion(
-         self,
-         contents: list[RawContent],
-         temperature: float = 0.6,
-         top_p: float = 0.9,
-         max_gen_len: int | None = None,
-         logprobs: bool = False,
-         echo: bool = False,
-     ) -> Generator[list[GenerationResult], None, None]:
-         model_inputs = [self.formatter.encode_content(c) for c in contents]
-         for result in self.generate(
-             model_inputs=model_inputs,
-             temperature=temperature,
-             top_p=top_p,
-             max_gen_len=max_gen_len,
-             logprobs=logprobs,
-             echo=echo,
-         ):
-             yield result
-             if all(r.finished for r in result):
-                 break
-
-     def chat_completion(
-         self,
-         messages_batch: list[list[RawMessage]],
-         temperature: float = 0.6,
-         top_p: float = 0.9,
-         max_gen_len: int | None = None,
-         logprobs: bool = False,
-         tool_prompt_format: ToolPromptFormat = ToolPromptFormat.json,
-         echo: bool = False,
-     ) -> Generator[list[GenerationResult], None, None]:
-         model_inputs = [self.formatter.encode_dialog_prompt(messages) for messages in messages_batch]
-         for result in self.generate(
-             model_inputs=model_inputs,
-             temperature=temperature,
-             top_p=top_p,
-             max_gen_len=max_gen_len,
-             logprobs=logprobs,
-             echo=echo,
-         ):
-             yield result
-             if all(r.finished for r in result):
-                 break
-
-
- def sample_top_p(probs, p):
-     """
-     Perform top-p (nucleus) sampling on a probability distribution.
-
-     Args:
-         probs (torch.Tensor): Probability distribution tensor.
-         p (float): Probability threshold for top-p sampling.
-
-     Returns:
-         torch.Tensor: Sampled token indices.
-
-     Note:
-         Top-p sampling selects the smallest set of tokens whose cumulative probability mass
-         exceeds the threshold p. The distribution is renormalized based on the selected tokens.
-     """
-     probs_sort, probs_idx = torch.sort(probs, dim=-1, descending=True)
-     probs_sum = torch.cumsum(probs_sort, dim=-1)
-     mask = probs_sum - probs_sort > p
-     probs_sort[mask] = 0.0
-     probs_sort.div_(probs_sort.sum(dim=-1, keepdim=True))
-     next_token = torch.multinomial(probs_sort, num_samples=1)
-     next_token = torch.gather(probs_idx, -1, next_token)
-     return next_token
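
For reference, the nucleus-sampling logic removed with sample_top_p above is self-contained and easy to reproduce outside the deleted module. The following standalone sketch restates the same technique in plain PyTorch; the function name nucleus_sample and the toy usage at the end are illustrative only and are not part of the llama-stack package.

import torch


def nucleus_sample(probs: torch.Tensor, p: float) -> torch.Tensor:
    # Sort probabilities in descending order and take the running cumulative mass.
    probs_sort, probs_idx = torch.sort(probs, dim=-1, descending=True)
    probs_sum = torch.cumsum(probs_sort, dim=-1)
    # Zero out every token whose preceding cumulative mass already exceeds p,
    # then renormalize the surviving probabilities.
    probs_sort[probs_sum - probs_sort > p] = 0.0
    probs_sort.div_(probs_sort.sum(dim=-1, keepdim=True))
    # Sample in the sorted space and map back to the original vocabulary indices.
    choice = torch.multinomial(probs_sort, num_samples=1)
    return torch.gather(probs_idx, -1, choice)


# Toy usage mirroring the call site in the removed generate() loop
# (temperature 0.6, top_p 0.9); batch of 2 over a vocabulary of 8.
logits = torch.randn(2, 8)
probs = torch.softmax(logits / 0.6, dim=-1)
next_token = nucleus_sample(probs, p=0.9)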