sglang 0.2.15__py3-none-any.whl → 0.3.1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (81)
  1. sglang/bench_latency.py +10 -6
  2. sglang/bench_serving.py +33 -38
  3. sglang/global_config.py +0 -4
  4. sglang/lang/backend/runtime_endpoint.py +13 -6
  5. sglang/lang/interpreter.py +1 -1
  6. sglang/launch_server.py +3 -6
  7. sglang/launch_server_llavavid.py +7 -8
  8. sglang/srt/{model_config.py → configs/model_config.py} +5 -0
  9. sglang/srt/constrained/__init__.py +2 -0
  10. sglang/srt/constrained/fsm_cache.py +29 -38
  11. sglang/srt/constrained/jump_forward.py +0 -1
  12. sglang/srt/conversation.py +4 -1
  13. sglang/srt/hf_transformers_utils.py +2 -4
  14. sglang/srt/layers/attention_backend.py +480 -0
  15. sglang/srt/layers/flashinfer_utils.py +235 -0
  16. sglang/srt/layers/logits_processor.py +64 -77
  17. sglang/srt/layers/radix_attention.py +11 -161
  18. sglang/srt/layers/sampler.py +40 -35
  19. sglang/srt/layers/torchao_utils.py +75 -0
  20. sglang/srt/layers/{decode_attention.py → triton_attention/decode_attention.py} +67 -63
  21. sglang/srt/layers/{extend_attention.py → triton_attention/extend_attention.py} +40 -132
  22. sglang/srt/layers/{prefill_attention.py → triton_attention/prefill_attention.py} +13 -7
  23. sglang/srt/lora/lora.py +403 -0
  24. sglang/srt/lora/lora_config.py +43 -0
  25. sglang/srt/lora/lora_manager.py +256 -0
  26. sglang/srt/managers/controller_multi.py +1 -5
  27. sglang/srt/managers/controller_single.py +0 -5
  28. sglang/srt/managers/io_struct.py +16 -1
  29. sglang/srt/managers/policy_scheduler.py +122 -5
  30. sglang/srt/managers/schedule_batch.py +110 -74
  31. sglang/srt/managers/tokenizer_manager.py +24 -15
  32. sglang/srt/managers/tp_worker.py +181 -115
  33. sglang/srt/model_executor/cuda_graph_runner.py +60 -133
  34. sglang/srt/model_executor/forward_batch_info.py +35 -312
  35. sglang/srt/model_executor/model_runner.py +118 -141
  36. sglang/srt/models/baichuan.py +416 -0
  37. sglang/srt/models/chatglm.py +6 -8
  38. sglang/srt/models/commandr.py +1 -5
  39. sglang/srt/models/dbrx.py +1 -5
  40. sglang/srt/models/deepseek.py +1 -5
  41. sglang/srt/models/deepseek_v2.py +1 -5
  42. sglang/srt/models/exaone.py +8 -43
  43. sglang/srt/models/gemma.py +1 -5
  44. sglang/srt/models/gemma2.py +1 -5
  45. sglang/srt/models/gpt_bigcode.py +1 -5
  46. sglang/srt/models/grok.py +1 -5
  47. sglang/srt/models/internlm2.py +1 -5
  48. sglang/srt/models/{llama2.py → llama.py} +48 -26
  49. sglang/srt/models/llama_classification.py +14 -40
  50. sglang/srt/models/llama_embedding.py +7 -6
  51. sglang/srt/models/llava.py +38 -16
  52. sglang/srt/models/llavavid.py +7 -8
  53. sglang/srt/models/minicpm.py +1 -5
  54. sglang/srt/models/minicpm3.py +665 -0
  55. sglang/srt/models/mistral.py +2 -3
  56. sglang/srt/models/mixtral.py +6 -5
  57. sglang/srt/models/mixtral_quant.py +1 -5
  58. sglang/srt/models/qwen.py +1 -5
  59. sglang/srt/models/qwen2.py +1 -5
  60. sglang/srt/models/qwen2_moe.py +6 -5
  61. sglang/srt/models/stablelm.py +1 -5
  62. sglang/srt/models/xverse.py +375 -0
  63. sglang/srt/models/xverse_moe.py +445 -0
  64. sglang/srt/openai_api/adapter.py +65 -46
  65. sglang/srt/openai_api/protocol.py +11 -3
  66. sglang/srt/sampling/sampling_batch_info.py +67 -58
  67. sglang/srt/server.py +24 -14
  68. sglang/srt/server_args.py +130 -28
  69. sglang/srt/utils.py +12 -0
  70. sglang/test/few_shot_gsm8k.py +132 -0
  71. sglang/test/runners.py +114 -22
  72. sglang/test/test_programs.py +70 -0
  73. sglang/test/test_utils.py +89 -1
  74. sglang/utils.py +38 -4
  75. sglang/version.py +1 -1
  76. {sglang-0.2.15.dist-info → sglang-0.3.1.dist-info}/METADATA +31 -18
  77. sglang-0.3.1.dist-info/RECORD +129 -0
  78. {sglang-0.2.15.dist-info → sglang-0.3.1.dist-info}/WHEEL +1 -1
  79. sglang-0.2.15.dist-info/RECORD +0 -118
  80. {sglang-0.2.15.dist-info → sglang-0.3.1.dist-info}/LICENSE +0 -0
  81. {sglang-0.2.15.dist-info → sglang-0.3.1.dist-info}/top_level.txt +0 -0
sglang/bench_latency.py CHANGED
@@ -57,10 +57,9 @@ import pandas as pd
 import torch
 import torch.distributed as dist
 
+from sglang.srt.configs.model_config import ModelConfig
 from sglang.srt.hf_transformers_utils import get_tokenizer
 from sglang.srt.managers.schedule_batch import Req, ScheduleBatch
-from sglang.srt.model_config import ModelConfig
-from sglang.srt.model_executor.forward_batch_info import ForwardMode
 from sglang.srt.model_executor.model_runner import ModelRunner
 from sglang.srt.sampling.sampling_params import SamplingParams
 from sglang.srt.server_args import ServerArgs
@@ -165,6 +164,7 @@ def prepare_inputs_for_correctness_test(bench_args, tokenizer):
         req.prefix_indices = []
         req.sampling_params = sampling_params
         req.fill_ids = req.origin_input_ids
+        req.extend_input_len = len(req.fill_ids) - len(req.prefix_indices)
         reqs.append(req)
 
     return input_ids, reqs
@@ -179,6 +179,7 @@ def prepare_extend_inputs_for_correctness_test(
         req.prefix_indices = model_runner.req_to_token_pool.req_to_token[
             i, : bench_args.cut_len
         ]
+        req.extend_input_len = len(req.fill_ids) - len(req.prefix_indices)
     return reqs
 
 
@@ -195,6 +196,7 @@ def prepare_synthetic_inputs_for_latency_test(batch_size, input_len):
         req.prefix_indices = []
         req.sampling_params = sampling_params
         req.fill_ids = req.origin_input_ids
+        req.extend_input_len = len(req.fill_ids) - len(req.prefix_indices)
         reqs.append(req)
 
     return reqs
@@ -208,15 +210,15 @@ def extend(reqs, model_runner):
         tree_cache=None,
     )
     batch.prepare_for_extend(model_runner.model_config.vocab_size)
-    sample_output, logits_output = model_runner.forward(batch, ForwardMode.EXTEND)
-    next_token_ids = sample_output.batch_next_token_ids.tolist()
+    logits_output = model_runner.forward(batch)
+    next_token_ids = model_runner.sample(logits_output, batch).tolist()
     return next_token_ids, logits_output.next_token_logits, batch
 
 
 def decode(input_token_ids, batch, model_runner):
     batch.prepare_for_decode(input_token_ids)
-    sample_output, logits_output = model_runner.forward(batch, ForwardMode.DECODE)
-    next_token_ids = sample_output.batch_next_token_ids.tolist()
+    logits_output = model_runner.forward(batch)
+    next_token_ids = model_runner.sample(logits_output, batch).tolist()
     return next_token_ids, logits_output.next_token_logits
 
 
@@ -480,6 +482,8 @@ def main(server_args, bench_args):
 
 
 if __name__ == "__main__":
+    multiprocessing.set_start_method("spawn", force=True)
+
     parser = argparse.ArgumentParser()
     ServerArgs.add_cli_args(parser)
     BenchArgs.add_cli_args(parser)
sglang/bench_serving.py CHANGED
@@ -298,34 +298,41 @@ class BenchmarkMetrics:
     median_e2e_latency_ms: float
 
 
-default_sharegpt_path = "ShareGPT_V3_unfiltered_cleaned_split.json"
+SHAREGPT_URL = "https://huggingface.co/datasets/anon8231489123/ShareGPT_Vicuna_unfiltered/resolve/main/ShareGPT_V3_unfiltered_cleaned_split.json"
 
 
-def download_sharegpt_dataset(path):
-    url = "https://huggingface.co/datasets/anon8231489123/ShareGPT_Vicuna_unfiltered/resolve/main/ShareGPT_V3_unfiltered_cleaned_split.json"
+def download_and_cache_file(url: str, filename: Optional[str] = None):
+    """Read and cache a file from a url."""
+    if filename is None:
+        filename = os.path.join("/tmp", url.split("/")[-1])
 
-    print(f"Downloading dataset from {url}")
-    try:
-        response = requests.get(url, stream=True)
-        response.raise_for_status()
+    # Check if the cache file already exists
+    if os.path.exists(filename):
+        return filename
+
+    print(f"Downloading from {url} to {filename}")
 
-        total_size = int(response.headers.get("content-length", 0))
-        block_size = 8192
+    # Stream the response to show the progress bar
+    response = requests.get(url, stream=True)
+    response.raise_for_status()  # Check for request errors
 
-        with open(path, "wb") as f, tqdm(
-            desc="Downloading",
-            total=total_size,
-            unit="iB",
-            unit_scale=True,
-            unit_divisor=1024,
-        ) as progress_bar:
-            for data in response.iter_content(block_size):
-                size = f.write(data)
-                progress_bar.update(size)
+    # Total size of the file in bytes
+    total_size = int(response.headers.get("content-length", 0))
+    chunk_size = 1024  # Download in chunks of 1KB
 
-        print(f"Dataset downloaded and saved to {path}")
-    except requests.RequestException as e:
-        raise Exception(f"Failed to download dataset: {e}")
+    # Use tqdm to display the progress bar
+    with open(filename, "wb") as f, tqdm(
+        desc=filename,
+        total=total_size,
+        unit="B",
+        unit_scale=True,
+        unit_divisor=1024,
+    ) as bar:
+        for chunk in response.iter_content(chunk_size=chunk_size):
+            f.write(chunk)
+            bar.update(len(chunk))
+
+    return filename
 
 
 def sample_sharegpt_requests(
@@ -338,13 +345,8 @@ def sample_sharegpt_requests(
         raise ValueError("output_len too small")
 
     # Download sharegpt if necessary
-    if not os.path.isfile(dataset_path) and not os.path.isfile(default_sharegpt_path):
-        download_sharegpt_dataset(default_sharegpt_path)
-        dataset_path = default_sharegpt_path
-    else:
-        dataset_path = (
-            dataset_path if os.path.isfile(dataset_path) else default_sharegpt_path
-        )
+    if not os.path.isfile(dataset_path):
+        dataset_path = download_and_cache_file(SHAREGPT_URL)
 
     # Load the dataset.
     with open(dataset_path) as f:
@@ -412,15 +414,8 @@ def sample_random_requests(
     # Sample token ids from ShareGPT and repeat/truncate them to satisfy the input_lens
 
     # Download sharegpt if necessary
-    if not os.path.isfile(dataset_path) and not os.path.isfile(
-        default_sharegpt_path
-    ):
-        download_sharegpt_dataset(default_sharegpt_path)
-        dataset_path = default_sharegpt_path
-    else:
-        dataset_path = (
-            dataset_path if os.path.isfile(dataset_path) else default_sharegpt_path
-        )
+    if not os.path.isfile(dataset_path):
+        dataset_path = download_and_cache_file(SHAREGPT_URL)
 
     # Load the dataset.
     with open(dataset_path) as f:
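Note: the new download_and_cache_file helper shown above is a plain module-level function, so it can also be exercised directly. A minimal sketch, assuming sglang 0.3.1 is installed and /tmp is writable (the second call is served from the cache):

    from sglang.bench_serving import SHAREGPT_URL, download_and_cache_file

    # First call downloads ShareGPT to /tmp/ShareGPT_V3_unfiltered_cleaned_split.json
    path = download_and_cache_file(SHAREGPT_URL)

    # Second call finds the cached file and returns the same path without re-downloading
    assert download_and_cache_file(SHAREGPT_URL) == path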
sglang/global_config.py CHANGED
@@ -11,10 +11,6 @@ class GlobalConfig:
         # Default backend of the language
         self.default_backend = None
 
-        # Runtime constants: Request dependency time due to network delay
-        self.request_dependency_delay = 0.02
-        self.wait_for_new_request_delay = 0.0006
-
         # Runtime constants: New generation token ratio estimation
         self.init_new_token_ratio = 0.7
         self.base_min_new_token_ratio = 0.1
sglang/lang/backend/runtime_endpoint.py CHANGED
@@ -4,7 +4,7 @@ from typing import List, Optional
 
 from sglang.global_config import global_config
 from sglang.lang.backend.base_backend import BaseBackend
-from sglang.lang.chat_template import get_chat_template_by_model_path
+from sglang.lang.chat_template import get_chat_template, get_chat_template_by_model_path
 from sglang.lang.choices import ChoicesDecision, ChoicesSamplingMethod
 from sglang.lang.interpreter import StreamExecutor
 from sglang.lang.ir import (
@@ -23,6 +23,7 @@ class RuntimeEndpoint(BaseBackend):
         base_url: str,
         api_key: Optional[str] = None,
         verify: Optional[str] = None,
+        chat_template_name: Optional[str] = None,
     ):
         super().__init__()
         self.support_concate_and_append = True
@@ -39,9 +40,12 @@ class RuntimeEndpoint(BaseBackend):
         self._assert_success(res)
         self.model_info = res.json()
 
-        self.chat_template = get_chat_template_by_model_path(
-            self.model_info["model_path"]
-        )
+        if chat_template_name:
+            self.chat_template = get_chat_template(chat_template_name)
+        else:
+            self.chat_template = get_chat_template_by_model_path(
+                self.model_info["model_path"]
+            )
 
     def get_model_name(self):
         return self.model_info["model_path"]
@@ -235,9 +239,12 @@ class RuntimeEndpoint(BaseBackend):
         # Compute logprob
         data = {
             "text": [s.text_ + c for c in choices],
-            "sampling_params": {"max_new_tokens": 0},
+            "sampling_params": {
+                "max_new_tokens": 0,
+                "temperature": 0,
+            },
             "return_logprob": True,
-            "logprob_start_len": max(prompt_len - 2, 0),
+            "logprob_start_len": max(prompt_len - 2, 0),  # for token healing
         }
         obj = self._generate_http_request(s, data)
 
sglang/lang/interpreter.py CHANGED
@@ -9,7 +9,7 @@ import uuid
 import warnings
 from concurrent.futures import ThreadPoolExecutor
 from contextlib import contextmanager
-from typing import Any, Callable, Dict, List, Optional, Union
+from typing import Any, Callable, Dict, List, Optional
 
 import tqdm
 
sglang/launch_server.py CHANGED
@@ -1,17 +1,14 @@
 """Launch the inference server."""
 
-import argparse
 import os
+import sys
 
 from sglang.srt.server import launch_server
-from sglang.srt.server_args import ServerArgs
+from sglang.srt.server_args import prepare_server_args
 from sglang.srt.utils import kill_child_process
 
 if __name__ == "__main__":
-    parser = argparse.ArgumentParser()
-    ServerArgs.add_cli_args(parser)
-    args = parser.parse_args()
-    server_args = ServerArgs.from_cli_args(args)
+    server_args = prepare_server_args(sys.argv[1:])
 
     try:
         launch_server(server_args)
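The launcher change above swaps the inline argparse setup for prepare_server_args, which makes a custom entry point a two-liner. A minimal sketch mirroring the new launch_server.py (the CLI flags forwarded via sys.argv are whatever ServerArgs accepts, e.g. --model-path and --port; the exact flag set is not shown in this diff):

    import sys

    from sglang.srt.server import launch_server
    from sglang.srt.server_args import prepare_server_args

    if __name__ == "__main__":
        # e.g. python my_launcher.py --model-path <hf-model-id> --port 30000
        server_args = prepare_server_args(sys.argv[1:])
        launch_server(server_args)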
sglang/launch_server_llavavid.py CHANGED
@@ -1,14 +1,12 @@
 """Launch the inference server for Llava-video model."""
 
-import argparse
+import json
+import sys
 
-from sglang.srt.server import ServerArgs, launch_server
+from sglang.srt.server import launch_server, prepare_server_args
 
 if __name__ == "__main__":
-    parser = argparse.ArgumentParser()
-    ServerArgs.add_cli_args(parser)
-    args = parser.parse_args()
-    server_args = ServerArgs.from_cli_args(args)
+    server_args = prepare_server_args(sys.argv[1:])
 
     model_override_args = {}
     model_override_args["mm_spatial_pool_stride"] = 2
@@ -20,7 +18,8 @@ if __name__ == "__main__":
     model_override_args["max_sequence_length"] = 4096 * 2
     model_override_args["tokenizer_model_max_length"] = 4096 * 2
     model_override_args["model_max_length"] = 4096 * 2
-    if "34b" in args.model_path.lower():
+    if "34b" in server_args.model_path.lower():
         model_override_args["image_token_index"] = 64002
+    server_args.json_model_override_args = json.dumps(model_override_args)
 
-    launch_server(server_args, model_override_args, None)
+    launch_server(server_args)
sglang/srt/{model_config.py → configs/model_config.py} CHANGED
@@ -64,6 +64,11 @@ class ModelConfig:
             self.attention_arch = AttentionArch.MLA
             self.kv_lora_rank = self.hf_config.kv_lora_rank
             self.qk_rope_head_dim = self.hf_config.qk_rope_head_dim
+        elif "MiniCPM3ForCausalLM" in self.hf_config.architectures:
+            self.head_dim = 128
+            self.attention_arch = AttentionArch.MLA
+            self.kv_lora_rank = self.hf_config.kv_lora_rank
+            self.qk_rope_head_dim = self.hf_config.qk_rope_head_dim
         else:
             self.attention_arch = AttentionArch.MHA
 
sglang/srt/constrained/__init__.py CHANGED
@@ -13,6 +13,8 @@ See the License for the specific language governing permissions and
 limitations under the License.
 """
 
+"""For constrained decoding."""
+
 import json
 from typing import Dict, Optional, Union
 
sglang/srt/constrained/fsm_cache.py CHANGED
@@ -16,6 +16,7 @@ limitations under the License.
 """Cache for the compressed finite state machine."""
 
 from outlines.fsm.json_schema import build_regex_from_schema
+from transformers import AutoTokenizer
 
 from sglang.srt.constrained import RegexGuide, TransformerTokenizer
 from sglang.srt.constrained.base_tool_cache import BaseToolCache
@@ -28,12 +29,9 @@ class FSMCache(BaseToolCache):
         tokenizer_args_dict,
         enable=True,
         skip_tokenizer_init=False,
-        json_schema_mode=False,
     ):
         super().__init__(enable=enable)
 
-        self.json_schema_mode = json_schema_mode
-
         if (
             skip_tokenizer_init
             or tokenizer_path.endswith(".json")
@@ -42,44 +40,37 @@ class FSMCache(BaseToolCache):
             # Do not support TiktokenTokenizer or SentencePieceTokenizer
             return
 
-        from importlib.metadata import version
+        tokenizer_args_dict.setdefault("padding_side", "left")
+        tokenizer = AutoTokenizer.from_pretrained(tokenizer_path, **tokenizer_args_dict)
+        try:
+            self.outlines_tokenizer = TransformerTokenizer(tokenizer)
+        except AttributeError:
+            # FIXME: tmp fix for chatglm2 & chatglm3 (pad_token_id=0)
+            origin_pad_token_id = tokenizer.pad_token_id
 
-        if version("outlines") >= "0.0.35":
-            from transformers import AutoTokenizer
+            def fset(self, value):
+                self._value = value
 
-            tokenizer_args_dict.setdefault("padding_side", "left")
-            tokenizer = AutoTokenizer.from_pretrained(
-                tokenizer_path, **tokenizer_args_dict
+            type(tokenizer).pad_token_id = property(
+                fget=type(tokenizer).pad_token_id.fget, fset=fset
             )
-            try:
-                self.outlines_tokenizer = TransformerTokenizer(tokenizer)
-            except AttributeError:
-                # FIXME: tmp fix for chatglm2 & chatglm3 (pad_token_id=0)
-                origin_pad_token_id = tokenizer.pad_token_id
-
-                def fset(self, value):
-                    self._value = value
-
-                type(tokenizer).pad_token_id = property(
-                    fget=type(tokenizer).pad_token_id.fget, fset=fset
-                )
-                self.outlines_tokenizer = TransformerTokenizer(tokenizer)
-                self.outlines_tokenizer.tokenizer.pad_token_id = origin_pad_token_id
-                self.outlines_tokenizer.pad_token_id = origin_pad_token_id
-                self.outlines_tokenizer.pad_token = (
-                    self.outlines_tokenizer.tokenizer.pad_token
-                )
-                self.outlines_tokenizer.vocabulary = (
-                    self.outlines_tokenizer.tokenizer.get_vocab()
-                )
-        else:
-            self.outlines_tokenizer = TransformerTokenizer(
-                tokenizer_path, **tokenizer_args_dict
+            self.outlines_tokenizer = TransformerTokenizer(tokenizer)
+            self.outlines_tokenizer.tokenizer.pad_token_id = origin_pad_token_id
+            self.outlines_tokenizer.pad_token_id = origin_pad_token_id
+            self.outlines_tokenizer.pad_token = (
+                self.outlines_tokenizer.tokenizer.pad_token
+            )
+            self.outlines_tokenizer.vocabulary = (
+                self.outlines_tokenizer.tokenizer.get_vocab()
             )
 
-    def init_value(self, value):
-        if self.json_schema_mode:
-            regex = build_regex_from_schema(value, whitespace_pattern=r"[\n\t ]*")
-            return RegexGuide(regex, self.outlines_tokenizer), regex
+    def init_value(self, key):
+        key_type, key_string = key
+        if key_type == "json":
+            regex = build_regex_from_schema(key_string, whitespace_pattern=r"[\n\t ]*")
+        elif key_type == "regex":
+            regex = key_string
         else:
-            return RegexGuide(value, self.outlines_tokenizer)
+            raise ValueError(f"Invalid key_type: {key_type}")
+
+        return RegexGuide(regex, self.outlines_tokenizer), regex
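The rewritten init_value above changes the FSM cache key from a bare regex string to a (key_type, key_string) tuple, so regex and JSON-schema constraints share one compilation path. A rough sketch of the key format (hypothetical values; "gpt2" merely stands in for whatever tokenizer the server loads, and init_value is normally reached through the cache's lookup machinery rather than called directly):

    from sglang.srt.constrained.fsm_cache import FSMCache

    cache = FSMCache("gpt2", tokenizer_args_dict={}, enable=True)

    # Regex-constrained decoding: key_type == "regex"
    guide, regex = cache.init_value(("regex", r"(yes|no)"))

    # JSON-schema-constrained decoding: key_type == "json"
    schema = '{"type": "object", "properties": {"answer": {"type": "string"}}, "required": ["answer"]}'
    guide, regex = cache.init_value(("json", schema))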
sglang/srt/constrained/jump_forward.py CHANGED
@@ -23,7 +23,6 @@ from collections import defaultdict
 
 import interegular
 import outlines.caching
-from outlines.fsm.json_schema import build_regex_from_schema
 
 from sglang.srt.constrained import (
     FSMInfo,
sglang/srt/conversation.py CHANGED
@@ -13,7 +13,7 @@ See the License for the specific language governing permissions and
 limitations under the License.
 """
 
-"""Conversation templates."""
+"""Conversation chat templates."""
 
 # Adapted from
 # https://github.com/lm-sys/FastChat/blob/main/fastchat/conversation.py
@@ -71,6 +71,7 @@ class Conversation:
     # Stop criteria (the default one is EOS token)
     stop_str: Union[str, List[str]] = None
     image_data: Optional[List[str]] = None
+    modalities: Optional[List[str]] = None
 
     def get_prompt(self) -> str:
         """Get the prompt for generation."""
@@ -379,6 +380,7 @@ def generate_chat_conv(
         sep2=conv.sep2,
         stop_str=conv.stop_str,
         image_data=[],
+        modalities=[],
     )
 
     if isinstance(request.messages, str):
@@ -408,6 +410,7 @@ def generate_chat_conv(
             for content in message.content:
                 if content.type == "image_url":
                     num_image_url += 1
+                    conv.modalities.append(content.modalities)
             if num_image_url > 1:
                 image_token = "<image>"
             else:
sglang/srt/hf_transformers_utils.py CHANGED
@@ -16,11 +16,9 @@ limitations under the License.
 """Utilities for Huggingface Transformers."""
 
 import contextlib
-import functools
-import json
 import os
 import warnings
-from typing import AbstractSet, Collection, Dict, List, Literal, Optional, Type, Union
+from typing import Dict, Optional, Type, Union
 
 from huggingface_hub import snapshot_download
 from transformers import (
@@ -92,7 +90,7 @@ def get_context_length(config):
     """Get the context length of a model from a huggingface model configs."""
     rope_scaling = getattr(config, "rope_scaling", None)
     if rope_scaling:
-        rope_scaling_factor = config.rope_scaling["factor"]
+        rope_scaling_factor = config.rope_scaling.get("factor", 1)
         if "original_max_position_embeddings" in rope_scaling:
             rope_scaling_factor = 1
         if config.rope_scaling.get("rope_type", None) == "llama3":