cortex-llm 1.0.0__py3-none-any.whl
This diff shows the contents of publicly released package versions as they appear in their respective public registries. It is provided for informational purposes only and reflects the changes between package versions as published.
- cortex/__init__.py +73 -0
- cortex/__main__.py +83 -0
- cortex/config.py +329 -0
- cortex/conversation_manager.py +468 -0
- cortex/fine_tuning/__init__.py +8 -0
- cortex/fine_tuning/dataset.py +332 -0
- cortex/fine_tuning/mlx_lora_trainer.py +502 -0
- cortex/fine_tuning/trainer.py +957 -0
- cortex/fine_tuning/wizard.py +707 -0
- cortex/gpu_validator.py +467 -0
- cortex/inference_engine.py +727 -0
- cortex/metal/__init__.py +275 -0
- cortex/metal/gpu_validator.py +177 -0
- cortex/metal/memory_pool.py +886 -0
- cortex/metal/mlx_accelerator.py +678 -0
- cortex/metal/mlx_converter.py +638 -0
- cortex/metal/mps_optimizer.py +417 -0
- cortex/metal/optimizer.py +665 -0
- cortex/metal/performance_profiler.py +364 -0
- cortex/model_downloader.py +130 -0
- cortex/model_manager.py +2187 -0
- cortex/quantization/__init__.py +5 -0
- cortex/quantization/dynamic_quantizer.py +736 -0
- cortex/template_registry/__init__.py +15 -0
- cortex/template_registry/auto_detector.py +144 -0
- cortex/template_registry/config_manager.py +234 -0
- cortex/template_registry/interactive.py +260 -0
- cortex/template_registry/registry.py +347 -0
- cortex/template_registry/template_profiles/__init__.py +5 -0
- cortex/template_registry/template_profiles/base.py +142 -0
- cortex/template_registry/template_profiles/complex/__init__.py +5 -0
- cortex/template_registry/template_profiles/complex/reasoning.py +263 -0
- cortex/template_registry/template_profiles/standard/__init__.py +9 -0
- cortex/template_registry/template_profiles/standard/alpaca.py +73 -0
- cortex/template_registry/template_profiles/standard/chatml.py +82 -0
- cortex/template_registry/template_profiles/standard/gemma.py +103 -0
- cortex/template_registry/template_profiles/standard/llama.py +87 -0
- cortex/template_registry/template_profiles/standard/simple.py +65 -0
- cortex/ui/__init__.py +120 -0
- cortex/ui/cli.py +1685 -0
- cortex/ui/markdown_render.py +185 -0
- cortex/ui/terminal_app.py +534 -0
- cortex_llm-1.0.0.dist-info/METADATA +275 -0
- cortex_llm-1.0.0.dist-info/RECORD +48 -0
- cortex_llm-1.0.0.dist-info/WHEEL +5 -0
- cortex_llm-1.0.0.dist-info/entry_points.txt +2 -0
- cortex_llm-1.0.0.dist-info/licenses/LICENSE +21 -0
- cortex_llm-1.0.0.dist-info/top_level.txt +1 -0
cortex/template_registry/template_profiles/complex/reasoning.py
@@ -0,0 +1,263 @@
+"""Reasoning-aware template profile for models with internal reasoning."""
+
+import re
+from typing import List, Dict, Any, Tuple, Optional
+from cortex.template_registry.template_profiles.base import BaseTemplateProfile, TemplateConfig, TemplateType
+
+
+class ReasoningProfile(BaseTemplateProfile):
+    """Profile for models with internal reasoning/chain-of-thought outputs."""
+
+    def __init__(self):
+        """Initialize the reasoning profile with streaming state."""
+        super().__init__()
+        self._streaming_state = {
+            'in_final': False,
+            'buffer': '',
+            'final_marker_seen': False
+        }
+
+    def get_default_config(self) -> TemplateConfig:
+        """Return the default reasoning configuration."""
+        return TemplateConfig(
+            name="Reasoning",
+            description="Models with internal reasoning/analysis channels",
+            template_type=TemplateType.REASONING,
+            supports_system_prompt=True,
+            supports_multi_turn=True,
+            strip_special_tokens=True,
+            show_reasoning=False,  # By default, hide internal reasoning
+            custom_filters=[
+                "<|channel|>", "<|message|>", "<|end|>",
+                "<|start|>", "<|return|>", "<|endofprompt|>"
+            ],
+            stop_sequences=["<|return|>", "<|endoftext|>", "<|endofprompt|>"]
+        )
+
+    def format_messages(self, messages: List[Dict[str, str]], add_generation_prompt: bool = True) -> str:
+        """Format messages for reasoning-aware models."""
+        formatted = ""
+
+        # Handle system message
+        system_msg = None
+        for msg in messages:
+            if msg.get('role') == 'system':
+                system_msg = msg.get('content', '')
+                break
+
+        if system_msg:
+            formatted += f"<|start|>system<|message|>{system_msg}<|end|>"
+
+        # Format conversation
+        for msg in messages:
+            role = msg.get('role', 'user')
+            content = msg.get('content', '')
+
+            if role == 'user':
+                formatted += f"<|start|>user<|message|>{content}<|end|>"
+            elif role == 'assistant' and content:
+                # For assistant messages, use the final channel
+                formatted += f"<|start|>assistant<|channel|>final<|message|>{content}<|end|>"
+
+        if add_generation_prompt:
+            formatted += "<|start|>assistant"
+
+        return formatted
+
+    def process_response(self, raw_output: str) -> str:
+        """Process reasoning model output to extract clean response."""
+        output = raw_output
+
+        # If we're showing reasoning, keep everything but clean up formatting
+        if self.config.show_reasoning:
+            return self._clean_reasoning_output(output)
+
+        # Otherwise, extract only the final response
+        return self._extract_final_response(output)
+
+    def _extract_final_response(self, output: str) -> str:
+        """Extract only the final response, hiding internal reasoning."""
+        # Pattern to find final channel content
+        final_pattern = r'<\|channel\|>final<\|message\|>(.*?)(?:<\|end\|>|<\|return\|>|$)'
+        final_matches = re.findall(final_pattern, output, re.DOTALL)
+
+        if final_matches:
+            # Return the last final response
+            return final_matches[-1].strip()
+
+        # Fallback: look for content after channel markers
+        channel_pattern = r'<\|channel\|>\w+<\|message\|>(.*?)(?:<\|end\|>|<\|channel\|>|$)'
+        matches = re.findall(channel_pattern, output, re.DOTALL)
+
+        if matches:
+            # Return the last message
+            return matches[-1].strip()
+
+        # Check for common reasoning patterns in gpt-oss models
+        # These models sometimes output reasoning without proper channel markers
+        if "User says:" in output or "We can comply" in output or "There's no disallowed content" in output:
+            # This looks like leaked internal reasoning
+            # Try to extract a proper response if there is one
+
+            # Look for a response after the reasoning
+            lines = output.split('\n')
+            filtered_lines = []
+            for line in lines:
+                # Skip lines that look like internal reasoning
+                if any(pattern in line for pattern in [
+                    "User says:", "We need to", "We can comply",
+                    "There's no disallowed", "There's no policy",
+                    "So we comply", "It's fine", "The user wants"
+                ]):
+                    continue
+                # Keep lines that look like actual responses
+                if line.strip():
+                    filtered_lines.append(line)
+
+            if filtered_lines:
+                return '\n'.join(filtered_lines).strip()
+
+            # If everything looks like reasoning, return a generic error
+            return "I apologize, but I'm having trouble generating a proper response. Please try rephrasing your request."
+
+        # Last resort: remove all special tokens
+        cleaned = output
+        for token in self.config.custom_filters:
+            cleaned = cleaned.replace(token, " ")
+
+        # Clean up multiple spaces
+        cleaned = re.sub(r'\s+', ' ', cleaned)
+
+        return cleaned.strip()
+
+    def _clean_reasoning_output(self, output: str) -> str:
+        """Clean up reasoning output for display."""
+        # Replace channel markers with readable labels
+        output = re.sub(r'<\|channel\|>analysis<\|message\|>', '\n[Analysis] ', output)
+        output = re.sub(r'<\|channel\|>commentary<\|message\|>', '\n[Commentary] ', output)
+        output = re.sub(r'<\|channel\|>final<\|message\|>', '\n[Response] ', output)
+
+        # Remove other special tokens
+        for token in ["<|end|>", "<|start|>", "<|return|>", "<|message|>"]:
+            output = output.replace(token, "")
+
+        # Clean up role markers
+        output = re.sub(r'<\|start\|>assistant', '', output)
+
+        return output.strip()
+
+    def can_handle(self, model_name: str, tokenizer: Any = None) -> Tuple[bool, float]:
+        """Check if this profile can handle the model."""
+        model_lower = model_name.lower()
+
+        # High confidence for known reasoning models
+        if any(name in model_lower for name in ['gpt-oss', 'reasoning', 'cot', 'chain-of-thought']):
+            return True, 0.9
+
+        # Check tokenizer for reasoning tokens
+        if tokenizer:
+            try:
+                vocab = getattr(tokenizer, 'get_vocab', lambda: {})()
+                reasoning_tokens = ['<|channel|>', '<|message|>', '<|start|>', '<|end|>']
+                if any(token in vocab for token in reasoning_tokens):
+                    return True, 0.85
+
+                # Check special tokens map
+                special_tokens = getattr(tokenizer, 'special_tokens_map', {})
+                special_tokens_str = str(special_tokens)
+                if any(token in special_tokens_str for token in reasoning_tokens):
+                    return True, 0.85
+            except:
+                pass
+
+        return False, 0.0
+
+    def set_show_reasoning(self, show: bool) -> None:
+        """Toggle whether to show internal reasoning."""
+        self.config.show_reasoning = show
+
+    def reset_streaming_state(self):
+        """Reset streaming state for new response."""
+        self._streaming_state = {
+            'in_final': False,
+            'buffer': '',
+            'final_marker_seen': False
+        }
+
+    def process_streaming_response(self, token: str, accumulated: str) -> Tuple[str, bool]:
+        """Process tokens in streaming mode for reasoning models.
+
+        Returns:
+            Tuple of (output_token, should_display)
+            - output_token: The token to display (may be empty)
+            - should_display: Whether this token should be shown to user
+        """
+        # If showing reasoning, pass through everything with formatting
+        if self.config.show_reasoning:
+            # Simple pass-through with basic formatting
+            return token, True
+
+        # Add token to buffer
+        self._streaming_state['buffer'] += token
+        buffer = self._streaming_state['buffer']
+
+        # State machine for filtering
+        if not self._streaming_state['in_final']:
+            # Look for final channel marker
+            if '<|channel|>final<|message|>' in buffer:
+                # Found it! Transition to final state
+                self._streaming_state['in_final'] = True
+                # Clear buffer of everything up to and including the marker
+                idx = buffer.index('<|channel|>final<|message|>')
+                self._streaming_state['buffer'] = buffer[idx + len('<|channel|>final<|message|>'):]
+                # Don't output anything yet
+                return '', False
+            else:
+                # Still accumulating, check if we might be building a marker
+                # Keep last 30 chars in buffer to handle partial markers
+                if len(buffer) > 30:
+                    self._streaming_state['buffer'] = buffer[-30:]
+                return '', False
+        else:
+            # We're in final channel, output everything except end markers
+            output = ''
+            remaining = ''
+
+            # Check for end markers
+            if '<|end|>' in buffer:
+                # Output everything before the end marker
+                idx = buffer.index('<|end|>')
+                output = buffer[:idx]
+                self._streaming_state['buffer'] = ''
+                # Reset for potential next response
+                self.reset_streaming_state()
+            elif '<|return|>' in buffer:
+                # Output everything before the return marker
+                idx = buffer.index('<|return|>')
+                output = buffer[:idx]
+                self._streaming_state['buffer'] = ''
+                # Reset for potential next response
+                self.reset_streaming_state()
+            elif '<|start|>' in buffer:
+                # Another message starting, output what we have
+                idx = buffer.index('<|start|>')
+                output = buffer[:idx]
+                self._streaming_state['buffer'] = buffer[idx:]
+                self._streaming_state['in_final'] = False
+            else:
+                # Check if we might be building an end marker
+                potential_markers = ['<', '<|', '<|e', '<|en', '<|end', '<|end|',
+                                     '<|r', '<|re', '<|ret', '<|retu', '<|retur', '<|return',
+                                     '<|s', '<|st', '<|sta', '<|star', '<|start']
+                for marker in potential_markers:
+                    if buffer.endswith(marker):
+                        # Might be building a marker, output everything except potential marker
+                        output = buffer[:-len(marker)]
+                        self._streaming_state['buffer'] = marker
+                        return output, bool(output)
+
+                # No potential markers, output everything
+                output = buffer
+                self._streaming_state['buffer'] = ''
+
+            return output, bool(output)
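For orientation, and not part of the diff itself: a minimal sketch of how the streaming filter above behaves, assuming the published cortex-llm package is installed and that BaseTemplateProfile (defined in base.py, whose body is not shown in this hunk) populates self.config from get_default_config(). The chunk contents and the printed result are illustrative.

# Illustrative only: exercising ReasoningProfile's streaming state machine.
from cortex.template_registry.template_profiles.complex.reasoning import ReasoningProfile

profile = ReasoningProfile()          # assumes the base class wires up self.config
profile.reset_streaming_state()

chunks = [
    "<|channel|>analysis<|message|>User greets us. ",
    "<|channel|>final<|message|>Hello! How can I help?",
    "<|end|>",
]

shown = ""
accumulated = ""
for chunk in chunks:
    accumulated += chunk
    piece, display = profile.process_streaming_response(chunk, accumulated)
    if display:
        shown += piece

# The analysis channel is suppressed; only final-channel text reaches the user.
print(shown)   # expected: "Hello! How can I help?"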
cortex/template_registry/template_profiles/standard/__init__.py
@@ -0,0 +1,9 @@
+"""Standard template profiles."""
+
+from cortex.template_registry.template_profiles.standard.chatml import ChatMLProfile
+from cortex.template_registry.template_profiles.standard.llama import LlamaProfile
+from cortex.template_registry.template_profiles.standard.alpaca import AlpacaProfile
+from cortex.template_registry.template_profiles.standard.simple import SimpleProfile
+from cortex.template_registry.template_profiles.standard.gemma import GemmaProfile
+
+__all__ = ['ChatMLProfile', 'LlamaProfile', 'AlpacaProfile', 'SimpleProfile', 'GemmaProfile']
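As a hypothetical illustration (not the package's own selection logic, which lives in registry.py and auto_detector.py elsewhere in this release), the can_handle confidence scores these profiles return could be ranked to choose a template for a model name:

# Hypothetical sketch: rank the exported profiles by can_handle() confidence.
from cortex.template_registry.template_profiles.standard import (
    ChatMLProfile, LlamaProfile, AlpacaProfile, SimpleProfile, GemmaProfile,
)

def pick_profile(model_name: str):
    candidates = []
    for cls in (ChatMLProfile, LlamaProfile, AlpacaProfile, SimpleProfile, GemmaProfile):
        profile = cls()                      # assumes the base class needs no arguments
        ok, confidence = profile.can_handle(model_name)
        if ok:
            candidates.append((confidence, profile))
    return max(candidates, key=lambda c: c[0])[1] if candidates else None

print(type(pick_profile("Qwen2.5-7B-Instruct")).__name__)   # likely ChatMLProfile (0.9)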
cortex/template_registry/template_profiles/standard/alpaca.py
@@ -0,0 +1,73 @@
+"""Alpaca template profile implementation."""
+
+from typing import List, Dict, Any, Tuple
+from cortex.template_registry.template_profiles.base import BaseTemplateProfile, TemplateConfig, TemplateType
+
+
+class AlpacaProfile(BaseTemplateProfile):
+    """Alpaca instruction format."""
+
+    def get_default_config(self) -> TemplateConfig:
+        """Return the default Alpaca configuration."""
+        return TemplateConfig(
+            name="Alpaca",
+            description="Alpaca instruction-following format",
+            template_type=TemplateType.ALPACA,
+            supports_system_prompt=False,
+            supports_multi_turn=False,
+            strip_special_tokens=False,
+            stop_sequences=["### Human:", "### Assistant:", "\n\n###"]
+        )
+
+    def format_messages(self, messages: List[Dict[str, str]], add_generation_prompt: bool = True) -> str:
+        """Format messages in Alpaca style."""
+        formatted = ""
+
+        # Alpaca is primarily single-turn, but we'll handle multi-turn
+        instruction = ""
+        input_text = ""
+
+        for msg in messages:
+            if msg.get('role') == 'system':
+                instruction = msg.get('content', '')
+            elif msg.get('role') == 'user':
+                if instruction:
+                    input_text = msg.get('content', '')
+                else:
+                    instruction = msg.get('content', '')
+
+        # Format as Alpaca
+        if instruction and input_text:
+            formatted = f"### Instruction:\n{instruction}\n\n### Input:\n{input_text}\n\n"
+        elif instruction:
+            formatted = f"### Instruction:\n{instruction}\n\n"
+
+        if add_generation_prompt:
+            formatted += "### Response:\n"
+
+        return formatted
+
+    def process_response(self, raw_output: str) -> str:
+        """Process Alpaca output."""
+        output = raw_output
+
+        # Remove any instruction markers that might appear
+        for marker in ["### Instruction:", "### Input:", "### Response:", "### Human:", "### Assistant:"]:
+            if marker in output:
+                output = output.split(marker)[0]
+
+        return output.strip()
+
+    def can_handle(self, model_name: str, tokenizer: Any = None) -> Tuple[bool, float]:
+        """Check if this profile can handle the model."""
+        model_lower = model_name.lower()
+
+        # High confidence for Alpaca models
+        if 'alpaca' in model_lower:
+            return True, 0.95
+
+        # Medium confidence for instruction-tuned models
+        if any(name in model_lower for name in ['instruct', 'instruction']):
+            return True, 0.5
+
+        return False, 0.0
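For reference, an illustrative sketch (not part of the diff) of the prompt the Alpaca formatting above produces for a system plus user pair, assuming the installed package provides the BaseTemplateProfile machinery:

# Illustrative only: what AlpacaProfile.format_messages builds.
from cortex.template_registry.template_profiles.standard.alpaca import AlpacaProfile

profile = AlpacaProfile()
prompt = profile.format_messages([
    {"role": "system", "content": "Summarize the passage."},
    {"role": "user", "content": "Large language models are ..."},
])
print(prompt)
# ### Instruction:
# Summarize the passage.
#
# ### Input:
# Large language models are ...
#
# ### Response: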
cortex/template_registry/template_profiles/standard/chatml.py
@@ -0,0 +1,82 @@
+"""ChatML template profile implementation."""
+
+from typing import List, Dict, Any, Tuple
+from cortex.template_registry.template_profiles.base import BaseTemplateProfile, TemplateConfig, TemplateType
+
+
+class ChatMLProfile(BaseTemplateProfile):
+    """ChatML format used by models like Qwen, OpenHermes, etc."""
+
+    def get_default_config(self) -> TemplateConfig:
+        """Return the default ChatML configuration."""
+        return TemplateConfig(
+            name="ChatML",
+            description="ChatML format with <|im_start|> and <|im_end|> tokens",
+            template_type=TemplateType.CHATML,
+            supports_system_prompt=True,
+            supports_multi_turn=True,
+            strip_special_tokens=False,
+            stop_sequences=["<|im_end|>", "<|endoftext|>"]
+        )
+
+    def format_messages(self, messages: List[Dict[str, str]], add_generation_prompt: bool = True) -> str:
+        """Format messages in ChatML style."""
+        formatted = ""
+
+        # Add default system message if none provided
+        if not messages or messages[0].get('role') != 'system':
+            formatted += "<|im_start|>system\nYou are a helpful assistant.<|im_end|>\n"
+
+        for message in messages:
+            role = message.get('role', 'user')
+            content = message.get('content', '')
+
+            formatted += f"<|im_start|>{role}\n{content}<|im_end|>\n"
+
+        if add_generation_prompt:
+            formatted += "<|im_start|>assistant\n"
+
+        return formatted
+
+    def process_response(self, raw_output: str) -> str:
+        """Process ChatML output."""
+        # Remove special tokens if they appear in output
+        output = raw_output
+
+        # Remove end tokens
+        for token in ["<|im_end|>", "<|endoftext|>", "<|im_start|>"]:
+            output = output.replace(token, "")
+
+        # Clean up any role markers that might appear
+        if output.startswith("assistant\n"):
+            output = output[10:]  # Remove "assistant\n"
+
+        return output.strip()
+
+    def can_handle(self, model_name: str, tokenizer: Any = None) -> Tuple[bool, float]:
+        """Check if this profile can handle the model."""
+        model_lower = model_name.lower()
+
+        # High confidence for known ChatML models
+        if any(name in model_lower for name in ['qwen', 'openhermes', 'neural-chat']):
+            return True, 0.9
+
+        # Check tokenizer for ChatML tokens
+        if tokenizer:
+            try:
+                vocab = getattr(tokenizer, 'get_vocab', lambda: {})()
+                if '<|im_start|>' in vocab or '<|im_end|>' in vocab:
+                    return True, 0.8
+
+                # Check special tokens
+                special_tokens = getattr(tokenizer, 'special_tokens_map', {})
+                if any('<|im_start|>' in str(v) for v in special_tokens.values()):
+                    return True, 0.8
+            except:
+                pass
+
+        # Check for ChatML in model name
+        if 'chatml' in model_lower or 'chat-ml' in model_lower:
+            return True, 0.95
+
+        return False, 0.0
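An illustrative sketch (not part of the diff) of the ChatML prompt built above, including the injected default system message, under the same assumption that the base class from base.py is available:

# Illustrative only: ChatML prompt construction and cleanup.
from cortex.template_registry.template_profiles.standard.chatml import ChatMLProfile

profile = ChatMLProfile()
prompt = profile.format_messages([{"role": "user", "content": "Hi there"}])
# prompt ==
#   "<|im_start|>system\nYou are a helpful assistant.<|im_end|>\n"
#   "<|im_start|>user\nHi there<|im_end|>\n"
#   "<|im_start|>assistant\n"

print(profile.process_response("Hello!<|im_end|>"))   # "Hello!"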
cortex/template_registry/template_profiles/standard/gemma.py
@@ -0,0 +1,103 @@
+"""Gemma template profile for Google Gemma models."""
+
+from typing import List, Dict, Any, Tuple
+from cortex.template_registry.template_profiles.base import BaseTemplateProfile, TemplateConfig, TemplateType
+
+
+class GemmaProfile(BaseTemplateProfile):
+    """Template profile for Google Gemma models."""
+
+    def get_default_config(self) -> TemplateConfig:
+        """Return the default Gemma configuration."""
+        return TemplateConfig(
+            name="Gemma",
+            description="Google Gemma chat template format",
+            template_type=TemplateType.GEMMA,
+            supports_system_prompt=True,
+            supports_multi_turn=True,
+            strip_special_tokens=True,
+            stop_sequences=["<end_of_turn>", "<eos>"]
+        )
+
+    def format_messages(self, messages: List[Dict[str, str]], add_generation_prompt: bool = True) -> str:
+        """Format messages using Gemma chat template format."""
+        formatted = ""
+
+        for msg in messages:
+            role = msg.get('role', 'user')
+            content = msg.get('content', '')
+
+            if role == 'system':
+                # Gemma system messages are treated as user messages with special formatting
+                formatted += f"<start_of_turn>user\n{content}<end_of_turn>\n"
+            elif role == 'user':
+                formatted += f"<start_of_turn>user\n{content}<end_of_turn>\n"
+            elif role == 'assistant':
+                formatted += f"<start_of_turn>model\n{content}<end_of_turn>\n"
+
+        if add_generation_prompt:
+            formatted += "<start_of_turn>model\n"
+
+        return formatted
+
+    def process_response(self, raw_output: str) -> str:
+        """Process Gemma model output to clean it up."""
+        output = raw_output
+
+        # Stop at the first occurrence of any stop token
+        stop_tokens = ["<end_of_turn>", "<eos>", "</s>"]
+        for token in stop_tokens:
+            if token in output:
+                output = output.split(token)[0]
+
+        # Remove Gemma-specific tokens that might have leaked through
+        gemma_tokens = [
+            "<start_of_turn>", "model\n", "user\n", "assistant\n"
+        ]
+
+        for token in gemma_tokens:
+            output = output.replace(token, "")
+
+        # Remove any role markers that might have been added by incorrect templates
+        role_markers = ["Assistant:", "User:", "Human:", "System:", "Model:"]
+        for marker in role_markers:
+            if output.startswith(marker):
+                output = output[len(marker):].strip()
+            output = output.replace(f"\n{marker}", "\n")
+
+        # Clean up extra whitespace
+        lines = output.split('\n')
+        cleaned_lines = []
+        for line in lines:
+            cleaned_line = line.strip()
+            if cleaned_line:
+                cleaned_lines.append(cleaned_line)
+
+        return '\n'.join(cleaned_lines).strip()
+
+    def can_handle(self, model_name: str, tokenizer: Any = None) -> Tuple[bool, float]:
+        """Check if this profile can handle Gemma models."""
+        model_name_lower = model_name.lower()
+
+        # High confidence for Gemma models
+        if 'gemma' in model_name_lower:
+            return True, 0.9
+
+        # Check tokenizer for Gemma-specific tokens
+        if tokenizer and hasattr(tokenizer, 'vocab'):
+            vocab = getattr(tokenizer, 'vocab', {})
+            if isinstance(vocab, dict):
+                vocab_str = str(vocab.keys()).lower()
+                if '<start_of_turn>' in vocab_str or '<end_of_turn>' in vocab_str:
+                    return True, 0.8
+
+        # Check for chat template patterns
+        if tokenizer and hasattr(tokenizer, 'chat_template'):
+            try:
+                template_str = str(tokenizer.chat_template).lower()
+                if '<start_of_turn>' in template_str or 'gemma' in template_str:
+                    return True, 0.8
+            except:
+                pass
+
+        return False, 0.0
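An illustrative sketch (not part of the diff) of the Gemma turn structure and response cleanup shown above, again assuming the base class wiring from base.py:

# Illustrative only: Gemma prompt layout and output cleanup.
from cortex.template_registry.template_profiles.standard.gemma import GemmaProfile

profile = GemmaProfile()
prompt = profile.format_messages([{"role": "user", "content": "Name three fruits."}])
# prompt ==
#   "<start_of_turn>user\nName three fruits.<end_of_turn>\n"
#   "<start_of_turn>model\n"

raw = "Apples, bananas, cherries.<end_of_turn>"
print(profile.process_response(raw))   # "Apples, bananas, cherries."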
cortex/template_registry/template_profiles/standard/llama.py
@@ -0,0 +1,87 @@
+"""Llama template profile implementation."""
+
+from typing import List, Dict, Any, Tuple
+from cortex.template_registry.template_profiles.base import BaseTemplateProfile, TemplateConfig, TemplateType
+
+
+class LlamaProfile(BaseTemplateProfile):
+    """Llama/Llama2 format with [INST] tokens."""
+
+    def get_default_config(self) -> TemplateConfig:
+        """Return the default Llama configuration."""
+        return TemplateConfig(
+            name="Llama",
+            description="Llama format with [INST] tokens",
+            template_type=TemplateType.LLAMA,
+            supports_system_prompt=True,
+            supports_multi_turn=True,
+            strip_special_tokens=False,
+            stop_sequences=["</s>", "[/INST]"]
+        )
+
+    def format_messages(self, messages: List[Dict[str, str]], add_generation_prompt: bool = True) -> str:
+        """Format messages in Llama style."""
+        formatted = ""
+
+        # Handle system message
+        system_msg = "You are a helpful assistant."
+        for msg in messages:
+            if msg.get('role') == 'system':
+                system_msg = msg.get('content', system_msg)
+                break
+
+        # Format conversation
+        conversation = []
+        for msg in messages:
+            if msg.get('role') == 'user':
+                conversation.append(('user', msg.get('content', '')))
+            elif msg.get('role') == 'assistant':
+                conversation.append(('assistant', msg.get('content', '')))
+
+        # Build the prompt
+        if conversation:
+            formatted = f"<s>[INST] <<SYS>>\n{system_msg}\n<</SYS>>\n\n"
+
+            for i, (role, content) in enumerate(conversation):
+                if role == 'user':
+                    if i > 0:
+                        formatted += f"<s>[INST] {content} [/INST] "
+                    else:
+                        formatted += f"{content} [/INST] "
+                elif role == 'assistant':
+                    formatted += f"{content} </s>"
+
+        return formatted
+
+    def process_response(self, raw_output: str) -> str:
+        """Process Llama output."""
+        output = raw_output
+
+        # Remove special tokens
+        for token in ["</s>", "<s>", "[INST]", "[/INST]", "<<SYS>>", "<</SYS>>"]:
+            output = output.replace(token, "")
+
+        return output.strip()
+
+    def can_handle(self, model_name: str, tokenizer: Any = None) -> Tuple[bool, float]:
+        """Check if this profile can handle the model."""
+        model_lower = model_name.lower()
+
+        # High confidence for Llama models
+        if 'llama' in model_lower or 'codellama' in model_lower:
+            return True, 0.9
+
+        # Check for specific Llama-based models
+        if any(name in model_lower for name in ['vicuna', 'alpaca', 'guanaco']):
+            return True, 0.7
+
+        # Check tokenizer
+        if tokenizer:
+            try:
+                vocab = getattr(tokenizer, 'get_vocab', lambda: {})()
+                if '[INST]' in vocab or '[/INST]' in vocab:
+                    return True, 0.8
+            except:
+                pass
+
+        return False, 0.0
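Finally, an illustrative sketch (not part of the diff) of the Llama 2 style [INST]/<<SYS>> prompt built above, under the same assumption about the base class in base.py:

# Illustrative only: Llama 2 prompt construction and output cleanup.
from cortex.template_registry.template_profiles.standard.llama import LlamaProfile

profile = LlamaProfile()
prompt = profile.format_messages([
    {"role": "system", "content": "Answer briefly."},
    {"role": "user", "content": "What is MLX?"},
])
# prompt == "<s>[INST] <<SYS>>\nAnswer briefly.\n<</SYS>>\n\nWhat is MLX? [/INST] "

print(profile.process_response("MLX is Apple's array framework. </s>"))
# "MLX is Apple's array framework."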