sglang 0.2.15__py3-none-any.whl → 0.3.1__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- sglang/bench_latency.py +10 -6
- sglang/bench_serving.py +33 -38
- sglang/global_config.py +0 -4
- sglang/lang/backend/runtime_endpoint.py +13 -6
- sglang/lang/interpreter.py +1 -1
- sglang/launch_server.py +3 -6
- sglang/launch_server_llavavid.py +7 -8
- sglang/srt/{model_config.py → configs/model_config.py} +5 -0
- sglang/srt/constrained/__init__.py +2 -0
- sglang/srt/constrained/fsm_cache.py +29 -38
- sglang/srt/constrained/jump_forward.py +0 -1
- sglang/srt/conversation.py +4 -1
- sglang/srt/hf_transformers_utils.py +2 -4
- sglang/srt/layers/attention_backend.py +480 -0
- sglang/srt/layers/flashinfer_utils.py +235 -0
- sglang/srt/layers/logits_processor.py +64 -77
- sglang/srt/layers/radix_attention.py +11 -161
- sglang/srt/layers/sampler.py +40 -35
- sglang/srt/layers/torchao_utils.py +75 -0
- sglang/srt/layers/{decode_attention.py → triton_attention/decode_attention.py} +67 -63
- sglang/srt/layers/{extend_attention.py → triton_attention/extend_attention.py} +40 -132
- sglang/srt/layers/{prefill_attention.py → triton_attention/prefill_attention.py} +13 -7
- sglang/srt/lora/lora.py +403 -0
- sglang/srt/lora/lora_config.py +43 -0
- sglang/srt/lora/lora_manager.py +256 -0
- sglang/srt/managers/controller_multi.py +1 -5
- sglang/srt/managers/controller_single.py +0 -5
- sglang/srt/managers/io_struct.py +16 -1
- sglang/srt/managers/policy_scheduler.py +122 -5
- sglang/srt/managers/schedule_batch.py +110 -74
- sglang/srt/managers/tokenizer_manager.py +24 -15
- sglang/srt/managers/tp_worker.py +181 -115
- sglang/srt/model_executor/cuda_graph_runner.py +60 -133
- sglang/srt/model_executor/forward_batch_info.py +35 -312
- sglang/srt/model_executor/model_runner.py +118 -141
- sglang/srt/models/baichuan.py +416 -0
- sglang/srt/models/chatglm.py +6 -8
- sglang/srt/models/commandr.py +1 -5
- sglang/srt/models/dbrx.py +1 -5
- sglang/srt/models/deepseek.py +1 -5
- sglang/srt/models/deepseek_v2.py +1 -5
- sglang/srt/models/exaone.py +8 -43
- sglang/srt/models/gemma.py +1 -5
- sglang/srt/models/gemma2.py +1 -5
- sglang/srt/models/gpt_bigcode.py +1 -5
- sglang/srt/models/grok.py +1 -5
- sglang/srt/models/internlm2.py +1 -5
- sglang/srt/models/{llama2.py → llama.py} +48 -26
- sglang/srt/models/llama_classification.py +14 -40
- sglang/srt/models/llama_embedding.py +7 -6
- sglang/srt/models/llava.py +38 -16
- sglang/srt/models/llavavid.py +7 -8
- sglang/srt/models/minicpm.py +1 -5
- sglang/srt/models/minicpm3.py +665 -0
- sglang/srt/models/mistral.py +2 -3
- sglang/srt/models/mixtral.py +6 -5
- sglang/srt/models/mixtral_quant.py +1 -5
- sglang/srt/models/qwen.py +1 -5
- sglang/srt/models/qwen2.py +1 -5
- sglang/srt/models/qwen2_moe.py +6 -5
- sglang/srt/models/stablelm.py +1 -5
- sglang/srt/models/xverse.py +375 -0
- sglang/srt/models/xverse_moe.py +445 -0
- sglang/srt/openai_api/adapter.py +65 -46
- sglang/srt/openai_api/protocol.py +11 -3
- sglang/srt/sampling/sampling_batch_info.py +67 -58
- sglang/srt/server.py +24 -14
- sglang/srt/server_args.py +130 -28
- sglang/srt/utils.py +12 -0
- sglang/test/few_shot_gsm8k.py +132 -0
- sglang/test/runners.py +114 -22
- sglang/test/test_programs.py +70 -0
- sglang/test/test_utils.py +89 -1
- sglang/utils.py +38 -4
- sglang/version.py +1 -1
- {sglang-0.2.15.dist-info → sglang-0.3.1.dist-info}/METADATA +31 -18
- sglang-0.3.1.dist-info/RECORD +129 -0
- {sglang-0.2.15.dist-info → sglang-0.3.1.dist-info}/WHEEL +1 -1
- sglang-0.2.15.dist-info/RECORD +0 -118
- {sglang-0.2.15.dist-info → sglang-0.3.1.dist-info}/LICENSE +0 -0
- {sglang-0.2.15.dist-info → sglang-0.3.1.dist-info}/top_level.txt +0 -0
sglang/bench_latency.py
CHANGED
@@ -57,10 +57,9 @@ import pandas as pd
 import torch
 import torch.distributed as dist
 
+from sglang.srt.configs.model_config import ModelConfig
 from sglang.srt.hf_transformers_utils import get_tokenizer
 from sglang.srt.managers.schedule_batch import Req, ScheduleBatch
-from sglang.srt.model_config import ModelConfig
-from sglang.srt.model_executor.forward_batch_info import ForwardMode
 from sglang.srt.model_executor.model_runner import ModelRunner
 from sglang.srt.sampling.sampling_params import SamplingParams
 from sglang.srt.server_args import ServerArgs
@@ -165,6 +164,7 @@ def prepare_inputs_for_correctness_test(bench_args, tokenizer):
         req.prefix_indices = []
         req.sampling_params = sampling_params
         req.fill_ids = req.origin_input_ids
+        req.extend_input_len = len(req.fill_ids) - len(req.prefix_indices)
         reqs.append(req)
 
     return input_ids, reqs
@@ -179,6 +179,7 @@ def prepare_extend_inputs_for_correctness_test(
         req.prefix_indices = model_runner.req_to_token_pool.req_to_token[
             i, : bench_args.cut_len
         ]
+        req.extend_input_len = len(req.fill_ids) - len(req.prefix_indices)
     return reqs
 
 
@@ -195,6 +196,7 @@ def prepare_synthetic_inputs_for_latency_test(batch_size, input_len):
         req.prefix_indices = []
         req.sampling_params = sampling_params
         req.fill_ids = req.origin_input_ids
+        req.extend_input_len = len(req.fill_ids) - len(req.prefix_indices)
         reqs.append(req)
 
     return reqs
@@ -208,15 +210,15 @@ def extend(reqs, model_runner):
         tree_cache=None,
     )
     batch.prepare_for_extend(model_runner.model_config.vocab_size)
-
-    next_token_ids =
+    logits_output = model_runner.forward(batch)
+    next_token_ids = model_runner.sample(logits_output, batch).tolist()
     return next_token_ids, logits_output.next_token_logits, batch
 
 
 def decode(input_token_ids, batch, model_runner):
     batch.prepare_for_decode(input_token_ids)
-
-    next_token_ids =
+    logits_output = model_runner.forward(batch)
+    next_token_ids = model_runner.sample(logits_output, batch).tolist()
    return next_token_ids, logits_output.next_token_logits
 
 
@@ -480,6 +482,8 @@ def main(server_args, bench_args):
 
 
 if __name__ == "__main__":
+    multiprocessing.set_start_method("spawn", force=True)
+
     parser = argparse.ArgumentParser()
     ServerArgs.add_cli_args(parser)
     BenchArgs.add_cli_args(parser)
sglang/bench_serving.py
CHANGED
@@ -298,34 +298,41 @@ class BenchmarkMetrics:
     median_e2e_latency_ms: float
 
 
-
+SHAREGPT_URL = "https://huggingface.co/datasets/anon8231489123/ShareGPT_Vicuna_unfiltered/resolve/main/ShareGPT_V3_unfiltered_cleaned_split.json"
 
 
-def
-
+def download_and_cache_file(url: str, filename: Optional[str] = None):
+    """Read and cache a file from a url."""
+    if filename is None:
+        filename = os.path.join("/tmp", url.split("/")[-1])
 
-
-
-
-
+    # Check if the cache file already exists
+    if os.path.exists(filename):
+        return filename
+
+    print(f"Downloading from {url} to {filename}")
 
-
-
+    # Stream the response to show the progress bar
+    response = requests.get(url, stream=True)
+    response.raise_for_status()  # Check for request errors
 
-
-
-
-        unit="iB",
-        unit_scale=True,
-        unit_divisor=1024,
-    ) as progress_bar:
-        for data in response.iter_content(block_size):
-            size = f.write(data)
-            progress_bar.update(size)
+    # Total size of the file in bytes
+    total_size = int(response.headers.get("content-length", 0))
+    chunk_size = 1024  # Download in chunks of 1KB
 
-
-
-
+    # Use tqdm to display the progress bar
+    with open(filename, "wb") as f, tqdm(
+        desc=filename,
+        total=total_size,
+        unit="B",
+        unit_scale=True,
+        unit_divisor=1024,
+    ) as bar:
+        for chunk in response.iter_content(chunk_size=chunk_size):
+            f.write(chunk)
+            bar.update(len(chunk))
+
+    return filename
 
 
 def sample_sharegpt_requests(
@@ -338,13 +345,8 @@ def sample_sharegpt_requests(
         raise ValueError("output_len too small")
 
     # Download sharegpt if necessary
-    if not os.path.isfile(dataset_path)
-
-        dataset_path = default_sharegpt_path
-    else:
-        dataset_path = (
-            dataset_path if os.path.isfile(dataset_path) else default_sharegpt_path
-        )
+    if not os.path.isfile(dataset_path):
+        dataset_path = download_and_cache_file(SHAREGPT_URL)
 
     # Load the dataset.
     with open(dataset_path) as f:
@@ -412,15 +414,8 @@ def sample_random_requests(
     # Sample token ids from ShareGPT and repeat/truncate them to satisfy the input_lens
 
     # Download sharegpt if necessary
-    if not os.path.isfile(dataset_path)
-
-    ):
-        download_sharegpt_dataset(default_sharegpt_path)
-        dataset_path = default_sharegpt_path
-    else:
-        dataset_path = (
-            dataset_path if os.path.isfile(dataset_path) else default_sharegpt_path
-        )
+    if not os.path.isfile(dataset_path):
+        dataset_path = download_and_cache_file(SHAREGPT_URL)
 
     # Load the dataset.
     with open(dataset_path) as f:
sglang/global_config.py
CHANGED
@@ -11,10 +11,6 @@ class GlobalConfig:
         # Default backend of the language
         self.default_backend = None
 
-        # Runtime constants: Request dependency time due to network delay
-        self.request_dependency_delay = 0.02
-        self.wait_for_new_request_delay = 0.0006
-
         # Runtime constants: New generation token ratio estimation
         self.init_new_token_ratio = 0.7
         self.base_min_new_token_ratio = 0.1
sglang/lang/backend/runtime_endpoint.py
CHANGED
@@ -4,7 +4,7 @@ from typing import List, Optional
 
 from sglang.global_config import global_config
 from sglang.lang.backend.base_backend import BaseBackend
-from sglang.lang.chat_template import get_chat_template_by_model_path
+from sglang.lang.chat_template import get_chat_template, get_chat_template_by_model_path
 from sglang.lang.choices import ChoicesDecision, ChoicesSamplingMethod
 from sglang.lang.interpreter import StreamExecutor
 from sglang.lang.ir import (
@@ -23,6 +23,7 @@ class RuntimeEndpoint(BaseBackend):
         base_url: str,
         api_key: Optional[str] = None,
         verify: Optional[str] = None,
+        chat_template_name: Optional[str] = None,
     ):
         super().__init__()
         self.support_concate_and_append = True
@@ -39,9 +40,12 @@ class RuntimeEndpoint(BaseBackend):
         self._assert_success(res)
         self.model_info = res.json()
 
-
-        self.
-
+        if chat_template_name:
+            self.chat_template = get_chat_template(chat_template_name)
+        else:
+            self.chat_template = get_chat_template_by_model_path(
+                self.model_info["model_path"]
+            )
 
     def get_model_name(self):
         return self.model_info["model_path"]
@@ -235,9 +239,12 @@ class RuntimeEndpoint(BaseBackend):
         # Compute logprob
         data = {
             "text": [s.text_ + c for c in choices],
-            "sampling_params": {
+            "sampling_params": {
+                "max_new_tokens": 0,
+                "temperature": 0,
+            },
             "return_logprob": True,
-            "logprob_start_len": max(prompt_len - 2, 0),
+            "logprob_start_len": max(prompt_len - 2, 0),  # for token healing
         }
         obj = self._generate_http_request(s, data)
 
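The new `chat_template_name` parameter lets a client force a chat template instead of deriving one from the served model path. A hedged usage sketch; the endpoint URL and the template name "chatml" are illustrative, while `RuntimeEndpoint` and `set_default_backend` are the existing top-level sglang APIs:

```python
import sglang as sgl
from sglang import RuntimeEndpoint

# Force the "chatml" template rather than inferring it from model_info["model_path"].
backend = RuntimeEndpoint("http://localhost:30000", chat_template_name="chatml")
sgl.set_default_backend(backend)
```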
sglang/lang/interpreter.py
CHANGED
@@ -9,7 +9,7 @@ import uuid
 import warnings
 from concurrent.futures import ThreadPoolExecutor
 from contextlib import contextmanager
-from typing import Any, Callable, Dict, List, Optional
+from typing import Any, Callable, Dict, List, Optional
 
 import tqdm
 
sglang/launch_server.py
CHANGED
@@ -1,17 +1,14 @@
 """Launch the inference server."""
 
-import argparse
 import os
+import sys
 
 from sglang.srt.server import launch_server
-from sglang.srt.server_args import
+from sglang.srt.server_args import prepare_server_args
 from sglang.srt.utils import kill_child_process
 
 if __name__ == "__main__":
-
-    ServerArgs.add_cli_args(parser)
-    args = parser.parse_args()
-    server_args = ServerArgs.from_cli_args(args)
+    server_args = prepare_server_args(sys.argv[1:])
 
     try:
         launch_server(server_args)
sglang/launch_server_llavavid.py
CHANGED
@@ -1,14 +1,12 @@
 """Launch the inference server for Llava-video model."""
 
-import
+import json
+import sys
 
-from sglang.srt.server import
+from sglang.srt.server import launch_server, prepare_server_args
 
 if __name__ == "__main__":
-
-    ServerArgs.add_cli_args(parser)
-    args = parser.parse_args()
-    server_args = ServerArgs.from_cli_args(args)
+    server_args = prepare_server_args(sys.argv[1:])
 
     model_override_args = {}
     model_override_args["mm_spatial_pool_stride"] = 2
@@ -20,7 +18,8 @@ if __name__ == "__main__":
     model_override_args["max_sequence_length"] = 4096 * 2
     model_override_args["tokenizer_model_max_length"] = 4096 * 2
     model_override_args["model_max_length"] = 4096 * 2
-    if "34b" in
+    if "34b" in server_args.model_path.lower():
         model_override_args["image_token_index"] = 64002
+    server_args.json_model_override_args = json.dumps(model_override_args)
 
-    launch_server(server_args
+    launch_server(server_args)
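The video launcher now passes its overrides through `ServerArgs.json_model_override_args` as a JSON string. A minimal sketch of the same mechanism from a custom launch script, assuming the model path is supplied on the command line (the override key is taken from the launcher above):

```python
import json
import sys

from sglang.srt.server import launch_server
from sglang.srt.server_args import prepare_server_args

if __name__ == "__main__":
    # e.g. invoked with: python my_launcher.py --model-path <your-model> --port 30000
    server_args = prepare_server_args(sys.argv[1:])
    # Same override channel as launch_server_llavavid.py: a JSON-encoded dict on ServerArgs.
    server_args.json_model_override_args = json.dumps({"mm_spatial_pool_stride": 2})
    launch_server(server_args)
```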
sglang/srt/{model_config.py → configs/model_config.py}
RENAMED
@@ -64,6 +64,11 @@ class ModelConfig:
             self.attention_arch = AttentionArch.MLA
             self.kv_lora_rank = self.hf_config.kv_lora_rank
             self.qk_rope_head_dim = self.hf_config.qk_rope_head_dim
+        elif "MiniCPM3ForCausalLM" in self.hf_config.architectures:
+            self.head_dim = 128
+            self.attention_arch = AttentionArch.MLA
+            self.kv_lora_rank = self.hf_config.kv_lora_rank
+            self.qk_rope_head_dim = self.hf_config.qk_rope_head_dim
         else:
             self.attention_arch = AttentionArch.MHA
 
sglang/srt/constrained/fsm_cache.py
CHANGED
@@ -16,6 +16,7 @@ limitations under the License.
 """Cache for the compressed finite state machine."""
 
 from outlines.fsm.json_schema import build_regex_from_schema
+from transformers import AutoTokenizer
 
 from sglang.srt.constrained import RegexGuide, TransformerTokenizer
 from sglang.srt.constrained.base_tool_cache import BaseToolCache
@@ -28,12 +29,9 @@ class FSMCache(BaseToolCache):
         tokenizer_args_dict,
         enable=True,
         skip_tokenizer_init=False,
-        json_schema_mode=False,
     ):
         super().__init__(enable=enable)
 
-        self.json_schema_mode = json_schema_mode
-
         if (
             skip_tokenizer_init
             or tokenizer_path.endswith(".json")
@@ -42,44 +40,37 @@ class FSMCache(BaseToolCache):
             # Do not support TiktokenTokenizer or SentencePieceTokenizer
             return
 
-
+        tokenizer_args_dict.setdefault("padding_side", "left")
+        tokenizer = AutoTokenizer.from_pretrained(tokenizer_path, **tokenizer_args_dict)
+        try:
+            self.outlines_tokenizer = TransformerTokenizer(tokenizer)
+        except AttributeError:
+            # FIXME: tmp fix for chatglm2 & chatglm3 (pad_token_id=0)
+            origin_pad_token_id = tokenizer.pad_token_id
 
-
-
+            def fset(self, value):
+                self._value = value
 
-
-
-                tokenizer_path, **tokenizer_args_dict
+            type(tokenizer).pad_token_id = property(
+                fget=type(tokenizer).pad_token_id.fget, fset=fset
             )
-
-
-
-
-
-
-
-
-
-            type(tokenizer).pad_token_id = property(
-                fget=type(tokenizer).pad_token_id.fget, fset=fset
-            )
-            self.outlines_tokenizer = TransformerTokenizer(tokenizer)
-            self.outlines_tokenizer.tokenizer.pad_token_id = origin_pad_token_id
-            self.outlines_tokenizer.pad_token_id = origin_pad_token_id
-            self.outlines_tokenizer.pad_token = (
-                self.outlines_tokenizer.tokenizer.pad_token
-            )
-            self.outlines_tokenizer.vocabulary = (
-                self.outlines_tokenizer.tokenizer.get_vocab()
-            )
-        else:
-            self.outlines_tokenizer = TransformerTokenizer(
-                tokenizer_path, **tokenizer_args_dict
+            self.outlines_tokenizer = TransformerTokenizer(tokenizer)
+            self.outlines_tokenizer.tokenizer.pad_token_id = origin_pad_token_id
+            self.outlines_tokenizer.pad_token_id = origin_pad_token_id
+            self.outlines_tokenizer.pad_token = (
+                self.outlines_tokenizer.tokenizer.pad_token
+            )
+            self.outlines_tokenizer.vocabulary = (
+                self.outlines_tokenizer.tokenizer.get_vocab()
             )
 
-    def init_value(self,
-
-
-
+    def init_value(self, key):
+        key_type, key_string = key
+        if key_type == "json":
+            regex = build_regex_from_schema(key_string, whitespace_pattern=r"[\n\t ]*")
+        elif key_type == "regex":
+            regex = key_string
         else:
-
+            raise ValueError(f"Invalid key_type: {key_type}")
+
+        return RegexGuide(regex, self.outlines_tokenizer), regex
sglang/srt/conversation.py
CHANGED
@@ -13,7 +13,7 @@ See the License for the specific language governing permissions and
 limitations under the License.
 """
 
-"""Conversation templates."""
+"""Conversation chat templates."""
 
 # Adapted from
 # https://github.com/lm-sys/FastChat/blob/main/fastchat/conversation.py
@@ -71,6 +71,7 @@ class Conversation:
     # Stop criteria (the default one is EOS token)
     stop_str: Union[str, List[str]] = None
     image_data: Optional[List[str]] = None
+    modalities: Optional[List[str]] = None
 
     def get_prompt(self) -> str:
         """Get the prompt for generation."""
@@ -379,6 +380,7 @@ def generate_chat_conv(
         sep2=conv.sep2,
         stop_str=conv.stop_str,
         image_data=[],
+        modalities=[],
     )
 
     if isinstance(request.messages, str):
@@ -408,6 +410,7 @@ def generate_chat_conv(
             for content in message.content:
                 if content.type == "image_url":
                     num_image_url += 1
+                    conv.modalities.append(content.modalities)
             if num_image_url > 1:
                 image_token = "<image>"
             else:
sglang/srt/hf_transformers_utils.py
CHANGED
@@ -16,11 +16,9 @@ limitations under the License.
 """Utilities for Huggingface Transformers."""
 
 import contextlib
-import functools
-import json
 import os
 import warnings
-from typing import
+from typing import Dict, Optional, Type, Union
 
 from huggingface_hub import snapshot_download
 from transformers import (
@@ -92,7 +90,7 @@ def get_context_length(config):
     """Get the context length of a model from a huggingface model configs."""
     rope_scaling = getattr(config, "rope_scaling", None)
     if rope_scaling:
-        rope_scaling_factor = config.rope_scaling
+        rope_scaling_factor = config.rope_scaling.get("factor", 1)
         if "original_max_position_embeddings" in rope_scaling:
             rope_scaling_factor = 1
         if config.rope_scaling.get("rope_type", None) == "llama3":