sglang 0.4.6.post4__py3-none-any.whl → 0.4.6.post5__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- sglang/bench_offline_throughput.py +6 -6
- sglang/bench_one_batch.py +5 -4
- sglang/bench_one_batch_server.py +23 -15
- sglang/bench_serving.py +133 -57
- sglang/compile_deep_gemm.py +4 -4
- sglang/srt/configs/model_config.py +39 -28
- sglang/srt/conversation.py +1 -1
- sglang/srt/disaggregation/decode.py +122 -133
- sglang/srt/disaggregation/decode_schedule_batch_mixin.py +142 -0
- sglang/srt/disaggregation/fake/conn.py +3 -13
- sglang/srt/disaggregation/kv_events.py +357 -0
- sglang/srt/disaggregation/mini_lb.py +57 -24
- sglang/srt/disaggregation/mooncake/conn.py +11 -2
- sglang/srt/disaggregation/mooncake/transfer_engine.py +2 -1
- sglang/srt/disaggregation/nixl/conn.py +9 -19
- sglang/srt/disaggregation/prefill.py +126 -44
- sglang/srt/disaggregation/utils.py +116 -5
- sglang/srt/distributed/utils.py +3 -3
- sglang/srt/entrypoints/EngineBase.py +5 -0
- sglang/srt/entrypoints/engine.py +28 -8
- sglang/srt/entrypoints/http_server.py +6 -4
- sglang/srt/entrypoints/http_server_engine.py +5 -2
- sglang/srt/function_call/base_format_detector.py +250 -0
- sglang/srt/function_call/core_types.py +34 -0
- sglang/srt/function_call/deepseekv3_detector.py +157 -0
- sglang/srt/function_call/ebnf_composer.py +234 -0
- sglang/srt/function_call/function_call_parser.py +175 -0
- sglang/srt/function_call/llama32_detector.py +74 -0
- sglang/srt/function_call/mistral_detector.py +84 -0
- sglang/srt/function_call/pythonic_detector.py +163 -0
- sglang/srt/function_call/qwen25_detector.py +67 -0
- sglang/srt/function_call/utils.py +35 -0
- sglang/srt/hf_transformers_utils.py +46 -7
- sglang/srt/layers/attention/aiter_backend.py +513 -0
- sglang/srt/layers/attention/flashattention_backend.py +63 -17
- sglang/srt/layers/attention/flashinfer_mla_backend.py +8 -4
- sglang/srt/layers/attention/flashmla_backend.py +340 -78
- sglang/srt/layers/attention/triton_backend.py +3 -0
- sglang/srt/layers/attention/utils.py +2 -2
- sglang/srt/layers/attention/vision.py +1 -1
- sglang/srt/layers/communicator.py +451 -0
- sglang/srt/layers/dp_attention.py +0 -10
- sglang/srt/layers/moe/cutlass_moe.py +207 -0
- sglang/srt/layers/moe/ep_moe/kernels.py +33 -11
- sglang/srt/layers/moe/ep_moe/layer.py +104 -50
- sglang/srt/layers/moe/ep_moe/token_dispatcher.py +82 -7
- sglang/srt/layers/moe/fused_moe_triton/layer.py +14 -0
- sglang/srt/layers/moe/topk.py +66 -9
- sglang/srt/layers/multimodal.py +70 -0
- sglang/srt/layers/quantization/__init__.py +7 -2
- sglang/srt/layers/quantization/deep_gemm.py +5 -3
- sglang/srt/layers/quantization/fp8.py +90 -0
- sglang/srt/layers/quantization/fp8_utils.py +6 -0
- sglang/srt/layers/quantization/gptq.py +298 -6
- sglang/srt/layers/quantization/int8_kernel.py +18 -5
- sglang/srt/layers/quantization/qoq.py +244 -0
- sglang/srt/lora/lora_manager.py +1 -3
- sglang/srt/managers/deepseek_eplb.py +278 -0
- sglang/srt/managers/eplb_manager.py +55 -0
- sglang/srt/managers/expert_distribution.py +704 -56
- sglang/srt/managers/expert_location.py +394 -0
- sglang/srt/managers/expert_location_dispatch.py +91 -0
- sglang/srt/managers/io_struct.py +16 -3
- sglang/srt/managers/mm_utils.py +293 -139
- sglang/srt/managers/multimodal_processors/base_processor.py +127 -42
- sglang/srt/managers/multimodal_processors/deepseek_vl_v2.py +6 -1
- sglang/srt/managers/multimodal_processors/gemma3.py +31 -6
- sglang/srt/managers/multimodal_processors/internvl.py +14 -5
- sglang/srt/managers/multimodal_processors/janus_pro.py +7 -1
- sglang/srt/managers/multimodal_processors/kimi_vl.py +7 -6
- sglang/srt/managers/multimodal_processors/llava.py +3 -3
- sglang/srt/managers/multimodal_processors/minicpm.py +25 -31
- sglang/srt/managers/multimodal_processors/mllama4.py +6 -0
- sglang/srt/managers/multimodal_processors/pixtral.py +9 -9
- sglang/srt/managers/multimodal_processors/qwen_vl.py +58 -16
- sglang/srt/managers/schedule_batch.py +49 -21
- sglang/srt/managers/schedule_policy.py +4 -5
- sglang/srt/managers/scheduler.py +92 -50
- sglang/srt/managers/session_controller.py +1 -1
- sglang/srt/managers/tokenizer_manager.py +99 -24
- sglang/srt/mem_cache/base_prefix_cache.py +3 -0
- sglang/srt/mem_cache/chunk_cache.py +3 -1
- sglang/srt/mem_cache/hiradix_cache.py +4 -4
- sglang/srt/mem_cache/memory_pool.py +74 -52
- sglang/srt/mem_cache/multimodal_cache.py +45 -0
- sglang/srt/mem_cache/radix_cache.py +58 -5
- sglang/srt/metrics/collector.py +2 -2
- sglang/srt/mm_utils.py +10 -0
- sglang/srt/model_executor/cuda_graph_runner.py +20 -9
- sglang/srt/model_executor/expert_location_updater.py +422 -0
- sglang/srt/model_executor/forward_batch_info.py +4 -0
- sglang/srt/model_executor/model_runner.py +144 -54
- sglang/srt/model_loader/loader.py +10 -6
- sglang/srt/models/clip.py +5 -1
- sglang/srt/models/deepseek_v2.py +297 -343
- sglang/srt/models/exaone.py +8 -3
- sglang/srt/models/gemma3_mm.py +70 -33
- sglang/srt/models/llama4.py +10 -2
- sglang/srt/models/llava.py +26 -18
- sglang/srt/models/mimo_mtp.py +220 -0
- sglang/srt/models/minicpmo.py +5 -12
- sglang/srt/models/mistral.py +71 -1
- sglang/srt/models/mllama.py +3 -3
- sglang/srt/models/qwen2.py +95 -26
- sglang/srt/models/qwen2_5_vl.py +8 -0
- sglang/srt/models/qwen2_moe.py +330 -60
- sglang/srt/models/qwen2_vl.py +6 -0
- sglang/srt/models/qwen3.py +52 -10
- sglang/srt/models/qwen3_moe.py +411 -48
- sglang/srt/models/siglip.py +294 -0
- sglang/srt/openai_api/adapter.py +28 -16
- sglang/srt/openai_api/protocol.py +6 -0
- sglang/srt/operations.py +154 -0
- sglang/srt/operations_strategy.py +31 -0
- sglang/srt/server_args.py +134 -24
- sglang/srt/speculative/eagle_utils.py +131 -0
- sglang/srt/speculative/eagle_worker.py +47 -2
- sglang/srt/utils.py +68 -12
- sglang/test/test_cutlass_moe.py +278 -0
- sglang/test/test_utils.py +2 -36
- sglang/utils.py +2 -2
- sglang/version.py +1 -1
- {sglang-0.4.6.post4.dist-info → sglang-0.4.6.post5.dist-info}/METADATA +20 -11
- {sglang-0.4.6.post4.dist-info → sglang-0.4.6.post5.dist-info}/RECORD +128 -102
- {sglang-0.4.6.post4.dist-info → sglang-0.4.6.post5.dist-info}/WHEEL +1 -1
- sglang/srt/function_call_parser.py +0 -858
- sglang/srt/platforms/interface.py +0 -371
- /sglang/srt/models/{xiaomi_mimo.py → mimo.py} +0 -0
- {sglang-0.4.6.post4.dist-info → sglang-0.4.6.post5.dist-info}/licenses/LICENSE +0 -0
- {sglang-0.4.6.post4.dist-info → sglang-0.4.6.post5.dist-info}/top_level.txt +0 -0
@@ -0,0 +1,163 @@
|
|
1
|
+
import ast
|
2
|
+
import json
|
3
|
+
import logging
|
4
|
+
import re
|
5
|
+
from typing import List, Optional
|
6
|
+
|
7
|
+
from sglang.srt.function_call.base_format_detector import BaseFormatDetector
|
8
|
+
from sglang.srt.function_call.core_types import (
|
9
|
+
StreamingParseResult,
|
10
|
+
StructureInfo,
|
11
|
+
ToolCallItem,
|
12
|
+
_GetInfoFunc,
|
13
|
+
)
|
14
|
+
from sglang.srt.function_call.ebnf_composer import EBNFComposer
|
15
|
+
from sglang.srt.openai_api.protocol import Tool
|
16
|
+
|
17
|
+
logger = logging.getLogger(__name__)
|
18
|
+
|
19
|
+
|
20
|
+
class PythonicDetector(BaseFormatDetector):
    """
    Detector for Llama-3.2 and Llama-4 models with pythonic tool call format.
    Assumes function call format:
      [tool1(arg1=val1, arg2=val2), tool2(arg1=val3)]
    Arguments are Python literals (not JSON).
    """

    def __init__(self):
        super().__init__()
        # Loose structural check for a bracketed list of keyword-style calls.
        # It is only used by `has_tool_call`; `detect_and_parse` does the real
        # validation via the `ast` module.
        self.tool_call_regex = re.compile(
            r"\[([a-zA-Z]+\w*\(([a-zA-Z]+\w*=.*,\s*)*([a-zA-Z]+\w*=.*\s)?\),\s*)*([a-zA-Z]+\w*\(([a-zA-Z]+\w*=.*,\s*)*([a-zA-Z]+\w*=.*\s*)?\)\s*)+\]",
            re.DOTALL,
        )

    def has_tool_call(self, text: str) -> bool:
        """Return True if the stripped text begins with a pythonic tool-call list."""
        return bool(self.tool_call_regex.match(text.strip()))

    def detect_and_parse(self, text: str, tools: List[Tool]) -> StreamingParseResult:
        """
        Parse ``text`` as a Python list of function calls.

        On success, returns the parsed calls with empty ``normal_text``. If the
        stripped text is not a bracketed list made entirely of calls, or any
        call/argument cannot be converted, the stripped text is returned as
        ``normal_text`` with no calls.

        :param text: Complete text to parse.
        :param tools: Available tools; a call's ``tool_index`` is its position
            in this list, or -1 if the name is unknown.
        """
        text = text.strip()
        if not (text.startswith("[") and text.endswith("]")):
            # Not a pythonic tool call format
            return StreamingParseResult(normal_text=text, calls=[])
        try:
            module = ast.parse(text)
            parsed = getattr(module.body[0], "value", None)
            if not (
                isinstance(parsed, ast.List)
                and all(isinstance(e, ast.Call) for e in parsed.elts)
            ):
                return StreamingParseResult(normal_text=text, calls=[])
            calls = []
            tool_indices = {
                tool.function.name: i
                for i, tool in enumerate(tools)
                if tool.function.name
            }
            for call in parsed.elts:
                # Only plain-name callees (e.g. `get_weather(...)`) are treated as
                # tool calls; attribute calls such as `a.b(...)` are skipped.
                if not isinstance(call.func, ast.Name):
                    continue
                function_name = call.func.id
                arguments = {}
                # Only keyword arguments are collected; positional arguments
                # (call.args) are not part of the expected format and are ignored.
                for keyword in call.keywords:
                    if keyword.arg is None:
                        # `**mapping` unpacking carries no argument name.
                        raise ValueError("Tool call arguments must be keyword literals")
                    arguments[keyword.arg] = self._get_parameter_value(keyword.value)
                calls.append(
                    ToolCallItem(
                        tool_index=tool_indices.get(function_name, -1),
                        name=function_name,
                        parameters=json.dumps(arguments, ensure_ascii=False),
                    )
                )
            return StreamingParseResult(normal_text="", calls=calls)
        except Exception:
            logger.exception("Error in pythonic tool call parsing.")
            return StreamingParseResult(normal_text=text, calls=[])

    def _find_matching_bracket(self, buffer: str, start: int) -> int:
        """
        Find the matching closing bracket for the opening bracket at start position.
        Properly handles nested brackets.

        Note: every '[' / ']' character is counted, including ones inside string
        literals in argument values.

        Args:
            buffer: The text buffer to search in
            start: Position of the opening bracket '['

        Returns:
            Position of the matching closing bracket ']', or -1 if not found
        """
        bracket_count = 0
        for i in range(start, len(buffer)):
            if buffer[i] == "[":
                bracket_count += 1
            elif buffer[i] == "]":
                bracket_count -= 1
                if bracket_count == 0:
                    return i
        return -1  # No matching bracket found

    def parse_streaming_increment(
        self, new_text: str, tools: List[Tool]
    ) -> StreamingParseResult:
        """
        Streaming incremental parsing for pythonic tool calls.
        Buffers input until a complete pythonic tool call (from [ to ]) is found,
        then parses and emits any detected calls.

        Text before the first '[' is emitted as normal text. After a complete
        bracketed span is parsed, anything following its ']' stays in the
        buffer for the next call.
        """
        self._buffer += new_text
        start = self._buffer.find("[")

        if start == -1:
            normal_text = self._buffer
            self._buffer = ""
            return StreamingParseResult(normal_text=normal_text)

        normal_text = self._buffer[:start] if start > 0 else ""

        end = self._find_matching_bracket(self._buffer, start)
        if end != -1:
            call_text = self._buffer[start : end + 1]
            result = self.detect_and_parse(call_text, tools)
            self._buffer = self._buffer[end + 1 :]

            # If we had normal text before the tool call, add it to the result
            if normal_text:
                result.normal_text = normal_text + (result.normal_text or "")

            return result

        # We have an opening bracket but no closing bracket yet
        if normal_text:
            self._buffer = self._buffer[start:]
            return StreamingParseResult(normal_text=normal_text)

        # Otherwise, we're still accumulating a potential tool call
        return StreamingParseResult(normal_text="")

    def _get_parameter_value(self, val):
        """
        Convert the AST node of a keyword-argument value into a plain Python value.

        Supported forms:
          * constants (strings, numbers, booleans, None)
          * signed numeric constants such as ``-1`` or ``+2.5``
          * dicts whose keys are constants
          * lists and tuples (tuples become lists, matching JSON arrays)

        Raises:
            ValueError: if the node is not one of the supported literal forms.
        """
        if isinstance(val, ast.Constant):
            return val.value
        if isinstance(val, ast.UnaryOp) and isinstance(val.op, (ast.USub, ast.UAdd)):
            # `-1` is parsed as UnaryOp(USub, Constant(1)), not as a Constant.
            operand = self._get_parameter_value(val.operand)
            if isinstance(operand, (int, float)) and not isinstance(operand, bool):
                return -operand if isinstance(val.op, ast.USub) else operand
            raise ValueError("Tool call arguments must be literals")
        if isinstance(val, ast.Dict):
            result = {}
            for k, v in zip(val.keys, val.values):
                # k is None for `**mapping` unpacking; non-constant keys are
                # not literals either.
                if not isinstance(k, ast.Constant):
                    raise ValueError("Tool call arguments must be literals")
                result[k.value] = self._get_parameter_value(v)
            return result
        if isinstance(val, (ast.List, ast.Tuple)):
            return [self._get_parameter_value(v) for v in val.elts]
        raise ValueError("Tool call arguments must be literals")

    def structure_info(self) -> _GetInfoFunc:
        """Return a factory producing begin/end/trigger markers for tool ``name``."""

        def info(name: str):
            return StructureInfo(begin=f"[{name}(", end=")]", trigger=f"[{name}(")

        return info

    def build_ebnf(self, tools: List[Tool]) -> Optional[str]:
        """Build an EBNF grammar for a bracketed, comma-separated pythonic call list."""
        return EBNFComposer.build_ebnf(
            tools,
            bot_token="[",
            eot_token="]",
            tool_call_separator=",",
            function_format="pythonic",
        )
|
@@ -0,0 +1,67 @@
|
|
1
|
+
import json
|
2
|
+
import re
|
3
|
+
from typing import List
|
4
|
+
|
5
|
+
from sglang.srt.function_call.base_format_detector import BaseFormatDetector
|
6
|
+
from sglang.srt.function_call.core_types import (
|
7
|
+
StreamingParseResult,
|
8
|
+
StructureInfo,
|
9
|
+
_GetInfoFunc,
|
10
|
+
)
|
11
|
+
from sglang.srt.function_call.ebnf_composer import EBNFComposer
|
12
|
+
from sglang.srt.openai_api.protocol import Tool
|
13
|
+
|
14
|
+
|
15
|
+
class Qwen25Detector(BaseFormatDetector):
    """
    Detector for Qwen 2.5 models.
    Assumes function call format:
      <tool_call>{"name":"xxx", "arguments":{...}}</tool_call>
    """

    def __init__(self):
        """
        Initializes the detector with necessary state variables.
        """
        super().__init__()
        self.bot_token = "<tool_call>"
        self.eot_token = "</tool_call>"

    def has_tool_call(self, text: str) -> bool:
        """Check if the text contains a Qwen 2.5 format tool call."""
        return self.bot_token in text

    def detect_and_parse(self, text: str, tools: List[Tool]) -> StreamingParseResult:
        """
        One-time parsing: Detects and parses tool calls in the provided text.

        Text before the first ``<tool_call>`` is returned (stripped) as
        ``normal_text``. Each ``<tool_call>...</tool_call>`` payload is decoded as
        JSON and handed to ``parse_base_json``. A payload that is not valid JSON
        is logged and skipped, so one bad call does not discard the others.

        :param text: The complete text to parse.
        :param tools: List of available tools.
        :return: StreamingParseResult with the leading normal text and parsed calls.
        """
        import logging

        idx = text.find(self.bot_token)
        if idx == -1:
            return StreamingParseResult(normal_text=text, calls=[])
        normal_text = text[:idx].strip()

        # Escape the tokens so the pattern stays correct should they ever
        # contain regex metacharacters.
        pattern = rf"{re.escape(self.bot_token)}(.*?){re.escape(self.eot_token)}"
        calls = []
        for match_result in re.findall(pattern, text, re.DOTALL):
            try:
                parsed_call = json.loads(match_result)
            except json.JSONDecodeError:
                logging.getLogger(__name__).warning(
                    "Failed to parse Qwen 2.5 tool call as JSON: %r", match_result
                )
                continue
            calls.extend(self.parse_base_json(parsed_call, tools))
        return StreamingParseResult(normal_text=normal_text, calls=calls)

    def structure_info(self) -> _GetInfoFunc:
        """Return a factory producing begin/end/trigger markers for tool ``name``."""
        return lambda name: StructureInfo(
            begin='<tool_call>{"name":"' + name + '", "arguments":',
            end="}</tool_call>",
            trigger="<tool_call>",
        )

    def build_ebnf(self, tools: List[Tool]):
        """Build an EBNF grammar for JSON tool calls wrapped in <tool_call> tags."""
        return EBNFComposer.build_ebnf(
            tools,
            bot_token=self.bot_token,
            eot_token=self.eot_token,
            function_format="json",
        )
|
@@ -0,0 +1,35 @@
|
|
1
|
+
import json
|
2
|
+
from json import JSONDecodeError, JSONDecoder
|
3
|
+
from typing import Any, Tuple
|
4
|
+
|
5
|
+
import partial_json_parser
|
6
|
+
from partial_json_parser.core.options import Allow
|
7
|
+
|
8
|
+
|
9
|
+
def _find_common_prefix(s1: str, s2: str) -> str:
|
10
|
+
prefix = ""
|
11
|
+
min_length = min(len(s1), len(s2))
|
12
|
+
for i in range(0, min_length):
|
13
|
+
if s1[i] == s2[i]:
|
14
|
+
prefix += s1[i]
|
15
|
+
else:
|
16
|
+
break
|
17
|
+
return prefix
|
18
|
+
|
19
|
+
|
20
|
+
def _partial_json_loads(input_str: str, flags: Allow) -> Tuple[Any, int]:
    """Leniently decode ``input_str`` as (possibly incomplete) JSON.

    Returns ``(value, consumed_chars)``. Normally the whole string is consumed.
    If the partial parser reports trailing "Extra data", the leading JSON value
    is decoded with ``JSONDecoder.raw_decode``, and the returned count is where
    that value ends. Any other ``JSONDecodeError`` is re-raised.
    """
    try:
        value = partial_json_parser.loads(input_str, flags)
    except JSONDecodeError as err:
        if "Extra data" not in err.msg:
            raise
        return JSONDecoder().raw_decode(input_str)
    return value, len(input_str)
|
28
|
+
|
29
|
+
|
30
|
+
def _is_complete_json(input_str: str) -> bool:
|
31
|
+
try:
|
32
|
+
json.loads(input_str)
|
33
|
+
return True
|
34
|
+
except JSONDecodeError:
|
35
|
+
return False
|
@@ -19,7 +19,7 @@ import warnings
|
|
19
19
|
from pathlib import Path
|
20
20
|
from typing import Dict, Optional, Type, Union
|
21
21
|
|
22
|
-
import
|
22
|
+
import torch
|
23
23
|
from huggingface_hub import snapshot_download
|
24
24
|
from transformers import (
|
25
25
|
AutoConfig,
|
@@ -66,6 +66,43 @@ def download_from_hf(model_path: str):
|
|
66
66
|
return snapshot_download(model_path, allow_patterns=["*.json", "*.bin", "*.model"])
|
67
67
|
|
68
68
|
|
69
|
+
def get_hf_text_config(config: PretrainedConfig):
    """Get the "sub" config relevant to llm for multi modal models.
    No op for pure text models.

    Lookup order:
      1. ``Llava*ForCausalLM`` architectures: return ``config`` itself, after
         setting ``config.torch_dtype`` to ``torch.float16``.
      2. ``config.text_config``, which must define ``num_attention_heads``.
      3. ``config.language_config``.
      4. ``config.thinker_config`` (Qwen2.5-Omni): its ``text_config`` if it has
         one, with ``torch_dtype`` copied from the thinker config; otherwise the
         thinker config itself.
      5. Otherwise ``config`` unchanged.

    Note: steps 1 and 4 modify the passed-in config objects in place.
    """
    # `architectures` may be missing, None, or an empty list; only index into
    # it when it actually has an entry.
    architectures = getattr(config, "architectures", None)
    if architectures:
        class_name = architectures[0]
        if class_name.startswith("Llava") and class_name.endswith("ForCausalLM"):
            # We support non-hf version of llava models, so we do not want to
            # read the wrong values from the unused default text_config.
            # NOTE(HandH1998): We set `torch_dtype` of config to `torch.float16` for the weights, as
            # `torch.float16` is default used for image features in `python/sglang/srt/models/llava.py`.
            setattr(config, "torch_dtype", torch.float16)
            return config

    if hasattr(config, "text_config"):
        # The code operates under the assumption that text_config should have
        # `num_attention_heads` (among others). Assert here to fail early
        # if transformers config doesn't align with this assumption.
        assert hasattr(config.text_config, "num_attention_heads")
        return config.text_config
    if hasattr(config, "language_config"):
        return config.language_config
    if hasattr(config, "thinker_config"):
        # qwen2.5 omni
        thinker_config = config.thinker_config
        if hasattr(thinker_config, "text_config"):
            setattr(
                thinker_config.text_config,
                "torch_dtype",
                getattr(thinker_config, "torch_dtype", None),
            )
            return thinker_config.text_config
        return thinker_config
    return config
|
104
|
+
|
105
|
+
|
69
106
|
def get_config(
|
70
107
|
model: str,
|
71
108
|
trust_remote_code: bool,
|
@@ -81,13 +118,12 @@ def get_config(
|
|
81
118
|
config = AutoConfig.from_pretrained(
|
82
119
|
model, trust_remote_code=trust_remote_code, revision=revision, **kwargs
|
83
120
|
)
|
121
|
+
text_config = get_hf_text_config(config=config)
|
84
122
|
|
85
|
-
|
86
|
-
|
87
|
-
|
88
|
-
|
89
|
-
setattr(config, key, val)
|
90
|
-
setattr(config, "architectures", ["MultiModalityCausalLM"])
|
123
|
+
if isinstance(model, str) and text_config is not None:
|
124
|
+
for key, val in text_config.__dict__.items():
|
125
|
+
if not hasattr(config, key) and getattr(text_config, key, None) is not None:
|
126
|
+
setattr(config, key, val)
|
91
127
|
|
92
128
|
if config.model_type in _CONFIG_REGISTRY:
|
93
129
|
config_class = _CONFIG_REGISTRY[config.model_type]
|
@@ -100,6 +136,9 @@ def get_config(
|
|
100
136
|
if not hasattr(config, key):
|
101
137
|
setattr(config, key, val)
|
102
138
|
|
139
|
+
if config.model_type == "multi_modality":
|
140
|
+
config.update({"architectures": ["MultiModalityCausalLM"]})
|
141
|
+
|
103
142
|
if model_override_args:
|
104
143
|
config.update(model_override_args)
|
105
144
|
|