cortex_llm-1.0.0-py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (48)
  1. cortex/__init__.py +73 -0
  2. cortex/__main__.py +83 -0
  3. cortex/config.py +329 -0
  4. cortex/conversation_manager.py +468 -0
  5. cortex/fine_tuning/__init__.py +8 -0
  6. cortex/fine_tuning/dataset.py +332 -0
  7. cortex/fine_tuning/mlx_lora_trainer.py +502 -0
  8. cortex/fine_tuning/trainer.py +957 -0
  9. cortex/fine_tuning/wizard.py +707 -0
  10. cortex/gpu_validator.py +467 -0
  11. cortex/inference_engine.py +727 -0
  12. cortex/metal/__init__.py +275 -0
  13. cortex/metal/gpu_validator.py +177 -0
  14. cortex/metal/memory_pool.py +886 -0
  15. cortex/metal/mlx_accelerator.py +678 -0
  16. cortex/metal/mlx_converter.py +638 -0
  17. cortex/metal/mps_optimizer.py +417 -0
  18. cortex/metal/optimizer.py +665 -0
  19. cortex/metal/performance_profiler.py +364 -0
  20. cortex/model_downloader.py +130 -0
  21. cortex/model_manager.py +2187 -0
  22. cortex/quantization/__init__.py +5 -0
  23. cortex/quantization/dynamic_quantizer.py +736 -0
  24. cortex/template_registry/__init__.py +15 -0
  25. cortex/template_registry/auto_detector.py +144 -0
  26. cortex/template_registry/config_manager.py +234 -0
  27. cortex/template_registry/interactive.py +260 -0
  28. cortex/template_registry/registry.py +347 -0
  29. cortex/template_registry/template_profiles/__init__.py +5 -0
  30. cortex/template_registry/template_profiles/base.py +142 -0
  31. cortex/template_registry/template_profiles/complex/__init__.py +5 -0
  32. cortex/template_registry/template_profiles/complex/reasoning.py +263 -0
  33. cortex/template_registry/template_profiles/standard/__init__.py +9 -0
  34. cortex/template_registry/template_profiles/standard/alpaca.py +73 -0
  35. cortex/template_registry/template_profiles/standard/chatml.py +82 -0
  36. cortex/template_registry/template_profiles/standard/gemma.py +103 -0
  37. cortex/template_registry/template_profiles/standard/llama.py +87 -0
  38. cortex/template_registry/template_profiles/standard/simple.py +65 -0
  39. cortex/ui/__init__.py +120 -0
  40. cortex/ui/cli.py +1685 -0
  41. cortex/ui/markdown_render.py +185 -0
  42. cortex/ui/terminal_app.py +534 -0
  43. cortex_llm-1.0.0.dist-info/METADATA +275 -0
  44. cortex_llm-1.0.0.dist-info/RECORD +48 -0
  45. cortex_llm-1.0.0.dist-info/WHEEL +5 -0
  46. cortex_llm-1.0.0.dist-info/entry_points.txt +2 -0
  47. cortex_llm-1.0.0.dist-info/licenses/LICENSE +21 -0
  48. cortex_llm-1.0.0.dist-info/top_level.txt +1 -0
cortex/template_registry/template_profiles/complex/reasoning.py
@@ -0,0 +1,263 @@
+ """Reasoning-aware template profile for models with internal reasoning."""
+
+ import re
+ from typing import List, Dict, Any, Tuple, Optional
+ from cortex.template_registry.template_profiles.base import BaseTemplateProfile, TemplateConfig, TemplateType
+
+
+ class ReasoningProfile(BaseTemplateProfile):
+     """Profile for models with internal reasoning/chain-of-thought outputs."""
+
+     def __init__(self):
+         """Initialize the reasoning profile with streaming state."""
+         super().__init__()
+         self._streaming_state = {
+             'in_final': False,
+             'buffer': '',
+             'final_marker_seen': False
+         }
+
+     def get_default_config(self) -> TemplateConfig:
+         """Return the default reasoning configuration."""
+         return TemplateConfig(
+             name="Reasoning",
+             description="Models with internal reasoning/analysis channels",
+             template_type=TemplateType.REASONING,
+             supports_system_prompt=True,
+             supports_multi_turn=True,
+             strip_special_tokens=True,
+             show_reasoning=False,  # By default, hide internal reasoning
+             custom_filters=[
+                 "<|channel|>", "<|message|>", "<|end|>",
+                 "<|start|>", "<|return|>", "<|endofprompt|>"
+             ],
+             stop_sequences=["<|return|>", "<|endoftext|>", "<|endofprompt|>"]
+         )
+
+     def format_messages(self, messages: List[Dict[str, str]], add_generation_prompt: bool = True) -> str:
+         """Format messages for reasoning-aware models."""
+         formatted = ""
+
+         # Handle system message
+         system_msg = None
+         for msg in messages:
+             if msg.get('role') == 'system':
+                 system_msg = msg.get('content', '')
+                 break
+
+         if system_msg:
+             formatted += f"<|start|>system<|message|>{system_msg}<|end|>"
+
+         # Format conversation
+         for msg in messages:
+             role = msg.get('role', 'user')
+             content = msg.get('content', '')
+
+             if role == 'user':
+                 formatted += f"<|start|>user<|message|>{content}<|end|>"
+             elif role == 'assistant' and content:
+                 # For assistant messages, use the final channel
+                 formatted += f"<|start|>assistant<|channel|>final<|message|>{content}<|end|>"
+
+         if add_generation_prompt:
+             formatted += "<|start|>assistant"
+
+         return formatted
+
+     def process_response(self, raw_output: str) -> str:
+         """Process reasoning model output to extract a clean response."""
+         output = raw_output
+
+         # If we're showing reasoning, keep everything but clean up formatting
+         if self.config.show_reasoning:
+             return self._clean_reasoning_output(output)
+
+         # Otherwise, extract only the final response
+         return self._extract_final_response(output)
+
+     def _extract_final_response(self, output: str) -> str:
+         """Extract only the final response, hiding internal reasoning."""
+         # Pattern to find final channel content
+         final_pattern = r'<\|channel\|>final<\|message\|>(.*?)(?:<\|end\|>|<\|return\|>|$)'
+         final_matches = re.findall(final_pattern, output, re.DOTALL)
+
+         if final_matches:
+             # Return the last final response
+             return final_matches[-1].strip()
+
+         # Fallback: look for content after any channel marker
+         channel_pattern = r'<\|channel\|>\w+<\|message\|>(.*?)(?:<\|end\|>|<\|channel\|>|$)'
+         matches = re.findall(channel_pattern, output, re.DOTALL)
+
+         if matches:
+             # Return the last message
+             return matches[-1].strip()
+
+         # Check for common reasoning patterns in gpt-oss models.
+         # These models sometimes emit reasoning without proper channel markers.
+         if "User says:" in output or "We can comply" in output or "There's no disallowed content" in output:
+             # This looks like leaked internal reasoning;
+             # try to extract a proper response if there is one.
+             lines = output.split('\n')
+             filtered_lines = []
+             for line in lines:
+                 # Skip lines that look like internal reasoning
+                 if any(pattern in line for pattern in [
+                     "User says:", "We need to", "We can comply",
+                     "There's no disallowed", "There's no policy",
+                     "So we comply", "It's fine", "The user wants"
+                 ]):
+                     continue
+                 # Keep lines that look like actual responses
+                 if line.strip():
+                     filtered_lines.append(line)
+
+             if filtered_lines:
+                 return '\n'.join(filtered_lines).strip()
+
+             # If everything looks like reasoning, return a generic error
+             return "I apologize, but I'm having trouble generating a proper response. Please try rephrasing your request."
+
+         # Last resort: remove all special tokens
+         cleaned = output
+         for token in self.config.custom_filters:
+             cleaned = cleaned.replace(token, " ")
+
+         # Collapse runs of whitespace
+         cleaned = re.sub(r'\s+', ' ', cleaned)
+
+         return cleaned.strip()
+
+     def _clean_reasoning_output(self, output: str) -> str:
+         """Clean up reasoning output for display."""
+         # Replace channel markers with readable labels
+         output = re.sub(r'<\|channel\|>analysis<\|message\|>', '\n[Analysis] ', output)
+         output = re.sub(r'<\|channel\|>commentary<\|message\|>', '\n[Commentary] ', output)
+         output = re.sub(r'<\|channel\|>final<\|message\|>', '\n[Response] ', output)
+
+         # Clean up role markers before stripping the bare tokens,
+         # otherwise '<|start|>assistant' can never match
+         output = re.sub(r'<\|start\|>assistant', '', output)
+
+         # Remove other special tokens
+         for token in ["<|end|>", "<|start|>", "<|return|>", "<|message|>"]:
+             output = output.replace(token, "")
+
+         return output.strip()
+
+     def can_handle(self, model_name: str, tokenizer: Any = None) -> Tuple[bool, float]:
+         """Check if this profile can handle the model."""
+         model_lower = model_name.lower()
+
+         # High confidence for known reasoning models
+         if any(name in model_lower for name in ['gpt-oss', 'reasoning', 'cot', 'chain-of-thought']):
+             return True, 0.9
+
+         # Check tokenizer for reasoning tokens
+         if tokenizer:
+             try:
+                 vocab = getattr(tokenizer, 'get_vocab', lambda: {})()
+                 reasoning_tokens = ['<|channel|>', '<|message|>', '<|start|>', '<|end|>']
+                 if any(token in vocab for token in reasoning_tokens):
+                     return True, 0.85
+
+                 # Check special tokens map
+                 special_tokens = getattr(tokenizer, 'special_tokens_map', {})
+                 special_tokens_str = str(special_tokens)
+                 if any(token in special_tokens_str for token in reasoning_tokens):
+                     return True, 0.85
+             except Exception:
+                 pass
+
+         return False, 0.0
+
+     def set_show_reasoning(self, show: bool) -> None:
+         """Toggle whether to show internal reasoning."""
+         self.config.show_reasoning = show
+
+     def reset_streaming_state(self):
+         """Reset streaming state for a new response."""
+         self._streaming_state = {
+             'in_final': False,
+             'buffer': '',
+             'final_marker_seen': False
+         }
+
+     def process_streaming_response(self, token: str, accumulated: str) -> Tuple[str, bool]:
+         """Process tokens in streaming mode for reasoning models.
+
+         Returns:
+             Tuple of (output_token, should_display):
+             - output_token: the text to display (may be empty)
+             - should_display: whether this text should be shown to the user
+         """
+         # If showing reasoning, pass every token through unfiltered
+         if self.config.show_reasoning:
+             return token, True
+
+         # Add token to buffer
+         self._streaming_state['buffer'] += token
+         buffer = self._streaming_state['buffer']
+
+         # State machine for filtering
+         if not self._streaming_state['in_final']:
+             # Look for the final channel marker
+             marker = '<|channel|>final<|message|>'
+             if marker in buffer:
+                 # Found it: transition to the final state and drop
+                 # everything up to and including the marker
+                 self._streaming_state['in_final'] = True
+                 idx = buffer.index(marker)
+                 self._streaming_state['buffer'] = buffer[idx + len(marker):]
+                 # Don't output anything yet
+                 return '', False
+             else:
+                 # Still accumulating; keep only the last 30 chars so a
+                 # partially streamed marker is never truncated away
+                 if len(buffer) > 30:
+                     self._streaming_state['buffer'] = buffer[-30:]
+                 return '', False
+         else:
+             # We're in the final channel: output everything except end markers
+             output = ''
+
+             if '<|end|>' in buffer:
+                 # Output everything before the end marker
+                 idx = buffer.index('<|end|>')
+                 output = buffer[:idx]
+                 # Reset for a potential next response
+                 self.reset_streaming_state()
+             elif '<|return|>' in buffer:
+                 # Output everything before the return marker
+                 idx = buffer.index('<|return|>')
+                 output = buffer[:idx]
+                 # Reset for a potential next response
+                 self.reset_streaming_state()
+             elif '<|start|>' in buffer:
+                 # Another message is starting; output what we have
+                 idx = buffer.index('<|start|>')
+                 output = buffer[:idx]
+                 self._streaming_state['buffer'] = buffer[idx:]
+                 self._streaming_state['in_final'] = False
+             else:
+                 # Check whether the buffer ends in a partial end marker
+                 potential_markers = ['<', '<|', '<|e', '<|en', '<|end', '<|end|',
+                                      '<|r', '<|re', '<|ret', '<|retu', '<|retur', '<|return',
+                                      '<|s', '<|st', '<|sta', '<|star', '<|start']
+                 for marker in potential_markers:
+                     if buffer.endswith(marker):
+                         # Might be building a marker: hold it back, emit the rest
+                         output = buffer[:-len(marker)]
+                         self._streaming_state['buffer'] = marker
+                         return output, bool(output)
+
+                 # No potential markers, output everything
+                 output = buffer
+                 self._streaming_state['buffer'] = ''
+
+             return output, bool(output)
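A minimal usage sketch of the streaming filter above, assuming the package imports as `cortex` and that `BaseTemplateProfile` populates `self.config` from `get_default_config()` (that base class is not shown in this diff). With the default `show_reasoning=False`, only the final channel should reach the user:

```python
# Hypothetical sketch, not part of the package: drive the streaming
# state machine with a simulated token stream and collect visible text.
from cortex.template_registry.template_profiles.complex.reasoning import ReasoningProfile

profile = ReasoningProfile()
profile.reset_streaming_state()

raw = ("<|channel|>analysis<|message|>the user greets us; reply warmly<|end|>"
       "<|channel|>final<|message|>Hello! How can I help?<|return|>")

accumulated, visible = "", ""
for i in range(0, len(raw), 4):  # 4-char chunks stand in for model tokens
    token = raw[i:i + 4]
    accumulated += token
    out, show = profile.process_streaming_response(token, accumulated)
    if show:
        visible += out

print(visible)  # expected: "Hello! How can I help?" -- analysis suppressed
```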
cortex/template_registry/template_profiles/standard/__init__.py
@@ -0,0 +1,9 @@
+ """Standard template profiles."""
+
+ from cortex.template_registry.template_profiles.standard.chatml import ChatMLProfile
+ from cortex.template_registry.template_profiles.standard.llama import LlamaProfile
+ from cortex.template_registry.template_profiles.standard.alpaca import AlpacaProfile
+ from cortex.template_registry.template_profiles.standard.simple import SimpleProfile
+ from cortex.template_registry.template_profiles.standard.gemma import GemmaProfile
+
+ __all__ = ['ChatMLProfile', 'LlamaProfile', 'AlpacaProfile', 'SimpleProfile', 'GemmaProfile']
cortex/template_registry/template_profiles/standard/alpaca.py
@@ -0,0 +1,73 @@
+ """Alpaca template profile implementation."""
+
+ from typing import List, Dict, Any, Tuple
+ from cortex.template_registry.template_profiles.base import BaseTemplateProfile, TemplateConfig, TemplateType
+
+
+ class AlpacaProfile(BaseTemplateProfile):
+     """Alpaca instruction format."""
+
+     def get_default_config(self) -> TemplateConfig:
+         """Return the default Alpaca configuration."""
+         return TemplateConfig(
+             name="Alpaca",
+             description="Alpaca instruction-following format",
+             template_type=TemplateType.ALPACA,
+             supports_system_prompt=False,
+             supports_multi_turn=False,
+             strip_special_tokens=False,
+             stop_sequences=["### Human:", "### Assistant:", "\n\n###"]
+         )
+
+     def format_messages(self, messages: List[Dict[str, str]], add_generation_prompt: bool = True) -> str:
+         """Format messages in Alpaca style."""
+         formatted = ""
+
+         # Alpaca is single-turn: fold the messages into instruction/input,
+         # with the system prompt (if any) becoming the instruction and the
+         # last user message winning
+         instruction = ""
+         input_text = ""
+
+         for msg in messages:
+             if msg.get('role') == 'system':
+                 instruction = msg.get('content', '')
+             elif msg.get('role') == 'user':
+                 if instruction:
+                     input_text = msg.get('content', '')
+                 else:
+                     instruction = msg.get('content', '')
+
+         # Format as Alpaca
+         if instruction and input_text:
+             formatted = f"### Instruction:\n{instruction}\n\n### Input:\n{input_text}\n\n"
+         elif instruction:
+             formatted = f"### Instruction:\n{instruction}\n\n"
+
+         if add_generation_prompt:
+             formatted += "### Response:\n"
+
+         return formatted
+
+     def process_response(self, raw_output: str) -> str:
+         """Process Alpaca output."""
+         output = raw_output
+
+         # Truncate at any instruction markers that appear in the output
+         for marker in ["### Instruction:", "### Input:", "### Response:", "### Human:", "### Assistant:"]:
+             if marker in output:
+                 output = output.split(marker)[0]
+
+         return output.strip()
+
+     def can_handle(self, model_name: str, tokenizer: Any = None) -> Tuple[bool, float]:
+         """Check if this profile can handle the model."""
+         model_lower = model_name.lower()
+
+         # High confidence for Alpaca models
+         if 'alpaca' in model_lower:
+             return True, 0.95
+
+         # Medium confidence for instruction-tuned models
+         if any(name in model_lower for name in ['instruct', 'instruction']):
+             return True, 0.5
+
+         return False, 0.0
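For reference, a quick sketch of what the formatter above produces (hypothetical usage, assuming the package imports as `cortex`):

```python
# Hypothetical sketch: a system message becomes the instruction and the
# user message becomes the input block.
from cortex.template_registry.template_profiles.standard.alpaca import AlpacaProfile

prompt = AlpacaProfile().format_messages([
    {"role": "system", "content": "Summarize the text below."},
    {"role": "user", "content": "Alpacas are domesticated South American camelids."},
])
print(prompt)
# ### Instruction:
# Summarize the text below.
#
# ### Input:
# Alpacas are domesticated South American camelids.
#
# ### Response:
```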
cortex/template_registry/template_profiles/standard/chatml.py
@@ -0,0 +1,82 @@
+ """ChatML template profile implementation."""
+
+ from typing import List, Dict, Any, Tuple
+ from cortex.template_registry.template_profiles.base import BaseTemplateProfile, TemplateConfig, TemplateType
+
+
+ class ChatMLProfile(BaseTemplateProfile):
+     """ChatML format used by models like Qwen, OpenHermes, etc."""
+
+     def get_default_config(self) -> TemplateConfig:
+         """Return the default ChatML configuration."""
+         return TemplateConfig(
+             name="ChatML",
+             description="ChatML format with <|im_start|> and <|im_end|> tokens",
+             template_type=TemplateType.CHATML,
+             supports_system_prompt=True,
+             supports_multi_turn=True,
+             strip_special_tokens=False,
+             stop_sequences=["<|im_end|>", "<|endoftext|>"]
+         )
+
+     def format_messages(self, messages: List[Dict[str, str]], add_generation_prompt: bool = True) -> str:
+         """Format messages in ChatML style."""
+         formatted = ""
+
+         # Add a default system message if none is provided
+         if not messages or messages[0].get('role') != 'system':
+             formatted += "<|im_start|>system\nYou are a helpful assistant.<|im_end|>\n"
+
+         for message in messages:
+             role = message.get('role', 'user')
+             content = message.get('content', '')
+
+             formatted += f"<|im_start|>{role}\n{content}<|im_end|>\n"
+
+         if add_generation_prompt:
+             formatted += "<|im_start|>assistant\n"
+
+         return formatted
+
+     def process_response(self, raw_output: str) -> str:
+         """Process ChatML output."""
+         # Remove special tokens if they appear in the output
+         output = raw_output
+
+         # Remove end tokens
+         for token in ["<|im_end|>", "<|endoftext|>", "<|im_start|>"]:
+             output = output.replace(token, "")
+
+         # Clean up a leading role marker if one remains
+         if output.startswith("assistant\n"):
+             output = output[len("assistant\n"):]
+
+         return output.strip()
+
+     def can_handle(self, model_name: str, tokenizer: Any = None) -> Tuple[bool, float]:
+         """Check if this profile can handle the model."""
+         model_lower = model_name.lower()
+
+         # Highest confidence when ChatML is named explicitly
+         # (checked first so the tokenizer probe below can't shadow it)
+         if 'chatml' in model_lower or 'chat-ml' in model_lower:
+             return True, 0.95
+
+         # High confidence for known ChatML models
+         if any(name in model_lower for name in ['qwen', 'openhermes', 'neural-chat']):
+             return True, 0.9
+
+         # Check tokenizer for ChatML tokens
+         if tokenizer:
+             try:
+                 vocab = getattr(tokenizer, 'get_vocab', lambda: {})()
+                 if '<|im_start|>' in vocab or '<|im_end|>' in vocab:
+                     return True, 0.8
+
+                 # Check special tokens
+                 special_tokens = getattr(tokenizer, 'special_tokens_map', {})
+                 if any('<|im_start|>' in str(v) for v in special_tokens.values()):
+                     return True, 0.8
+             except Exception:
+                 pass
+
+         return False, 0.0
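A short sketch of the default-system-message behavior above (hypothetical usage, assuming the package imports as `cortex`):

```python
# Hypothetical sketch: when no system message is supplied, the profile
# injects a default one before the conversation.
from cortex.template_registry.template_profiles.standard.chatml import ChatMLProfile

prompt = ChatMLProfile().format_messages([{"role": "user", "content": "Hi"}])
print(prompt)
# <|im_start|>system
# You are a helpful assistant.<|im_end|>
# <|im_start|>user
# Hi<|im_end|>
# <|im_start|>assistant
```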
cortex/template_registry/template_profiles/standard/gemma.py
@@ -0,0 +1,103 @@
+ """Gemma template profile for Google Gemma models."""
+
+ from typing import List, Dict, Any, Tuple
+ from cortex.template_registry.template_profiles.base import BaseTemplateProfile, TemplateConfig, TemplateType
+
+
+ class GemmaProfile(BaseTemplateProfile):
+     """Template profile for Google Gemma models."""
+
+     def get_default_config(self) -> TemplateConfig:
+         """Return the default Gemma configuration."""
+         return TemplateConfig(
+             name="Gemma",
+             description="Google Gemma chat template format",
+             template_type=TemplateType.GEMMA,
+             supports_system_prompt=True,
+             supports_multi_turn=True,
+             strip_special_tokens=True,
+             stop_sequences=["<end_of_turn>", "<eos>"]
+         )
+
+     def format_messages(self, messages: List[Dict[str, str]], add_generation_prompt: bool = True) -> str:
+         """Format messages using the Gemma chat template format."""
+         formatted = ""
+
+         for msg in messages:
+             role = msg.get('role', 'user')
+             content = msg.get('content', '')
+
+             if role == 'system':
+                 # Gemma has no system role; system messages are sent as user turns
+                 formatted += f"<start_of_turn>user\n{content}<end_of_turn>\n"
+             elif role == 'user':
+                 formatted += f"<start_of_turn>user\n{content}<end_of_turn>\n"
+             elif role == 'assistant':
+                 formatted += f"<start_of_turn>model\n{content}<end_of_turn>\n"
+
+         if add_generation_prompt:
+             formatted += "<start_of_turn>model\n"
+
+         return formatted
+
+     def process_response(self, raw_output: str) -> str:
+         """Process Gemma model output to clean it up."""
+         output = raw_output
+
+         # Stop at the first occurrence of any stop token
+         stop_tokens = ["<end_of_turn>", "<eos>", "</s>"]
+         for token in stop_tokens:
+             if token in output:
+                 output = output.split(token)[0]
+
+         # Remove Gemma-specific tokens that might have leaked through
+         gemma_tokens = [
+             "<start_of_turn>", "model\n", "user\n", "assistant\n"
+         ]
+
+         for token in gemma_tokens:
+             output = output.replace(token, "")
+
+         # Remove any role markers that might have been added by incorrect templates
+         role_markers = ["Assistant:", "User:", "Human:", "System:", "Model:"]
+         for marker in role_markers:
+             if output.startswith(marker):
+                 output = output[len(marker):].strip()
+             output = output.replace(f"\n{marker}", "\n")
+
+         # Clean up extra whitespace
+         lines = output.split('\n')
+         cleaned_lines = []
+         for line in lines:
+             cleaned_line = line.strip()
+             if cleaned_line:
+                 cleaned_lines.append(cleaned_line)
+
+         return '\n'.join(cleaned_lines).strip()
+
+     def can_handle(self, model_name: str, tokenizer: Any = None) -> Tuple[bool, float]:
+         """Check if this profile can handle Gemma models."""
+         model_name_lower = model_name.lower()
+
+         # High confidence for Gemma models
+         if 'gemma' in model_name_lower:
+             return True, 0.9
+
+         # Check tokenizer for Gemma-specific tokens
+         if tokenizer and hasattr(tokenizer, 'vocab'):
+             vocab = getattr(tokenizer, 'vocab', {})
+             if isinstance(vocab, dict):
+                 vocab_str = str(vocab.keys()).lower()
+                 if '<start_of_turn>' in vocab_str or '<end_of_turn>' in vocab_str:
+                     return True, 0.8
+
+         # Check for chat template patterns
+         if tokenizer and hasattr(tokenizer, 'chat_template'):
+             try:
+                 template_str = str(tokenizer.chat_template).lower()
+                 if '<start_of_turn>' in template_str or 'gemma' in template_str:
+                     return True, 0.8
+             except Exception:
+                 pass
+
+         return False, 0.0
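A brief sketch of detection and cleanup for this profile (hypothetical usage, assuming the package imports as `cortex`):

```python
# Hypothetical sketch: detection by model name, then output cleanup that
# truncates at the first stop token.
from cortex.template_registry.template_profiles.standard.gemma import GemmaProfile

profile = GemmaProfile()
print(profile.can_handle("google/gemma-2-9b-it"))             # (True, 0.9)
print(profile.process_response("It is 4.<end_of_turn>junk"))  # "It is 4."
```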
cortex/template_registry/template_profiles/standard/llama.py
@@ -0,0 +1,87 @@
+ """Llama template profile implementation."""
+
+ from typing import List, Dict, Any, Tuple
+ from cortex.template_registry.template_profiles.base import BaseTemplateProfile, TemplateConfig, TemplateType
+
+
+ class LlamaProfile(BaseTemplateProfile):
+     """Llama/Llama 2 format with [INST] tokens."""
+
+     def get_default_config(self) -> TemplateConfig:
+         """Return the default Llama configuration."""
+         return TemplateConfig(
+             name="Llama",
+             description="Llama format with [INST] tokens",
+             template_type=TemplateType.LLAMA,
+             supports_system_prompt=True,
+             supports_multi_turn=True,
+             strip_special_tokens=False,
+             stop_sequences=["</s>", "[/INST]"]
+         )
+
+     def format_messages(self, messages: List[Dict[str, str]], add_generation_prompt: bool = True) -> str:
+         """Format messages in Llama style."""
+         formatted = ""
+
+         # Handle system message
+         system_msg = "You are a helpful assistant."
+         for msg in messages:
+             if msg.get('role') == 'system':
+                 system_msg = msg.get('content', system_msg)
+                 break
+
+         # Collect the conversation turns
+         conversation = []
+         for msg in messages:
+             if msg.get('role') == 'user':
+                 conversation.append(('user', msg.get('content', '')))
+             elif msg.get('role') == 'assistant':
+                 conversation.append(('assistant', msg.get('content', '')))
+
+         # Build the prompt
+         if conversation:
+             formatted = f"<s>[INST] <<SYS>>\n{system_msg}\n<</SYS>>\n\n"
+
+             for i, (role, content) in enumerate(conversation):
+                 if role == 'user':
+                     if i > 0:
+                         formatted += f"<s>[INST] {content} [/INST] "
+                     else:
+                         formatted += f"{content} [/INST] "
+                 elif role == 'assistant':
+                     formatted += f"{content} </s>"
+
+         return formatted
+
+     def process_response(self, raw_output: str) -> str:
+         """Process Llama output."""
+         output = raw_output
+
+         # Remove special tokens
+         for token in ["</s>", "<s>", "[INST]", "[/INST]", "<<SYS>>", "<</SYS>>"]:
+             output = output.replace(token, "")
+
+         return output.strip()
+
+     def can_handle(self, model_name: str, tokenizer: Any = None) -> Tuple[bool, float]:
+         """Check if this profile can handle the model."""
+         model_lower = model_name.lower()
+
+         # High confidence for Llama models
+         if 'llama' in model_lower or 'codellama' in model_lower:
+             return True, 0.9
+
+         # Check for specific Llama-based models
+         if any(name in model_lower for name in ['vicuna', 'alpaca', 'guanaco']):
+             return True, 0.7
+
+         # Check tokenizer
+         if tokenizer:
+             try:
+                 vocab = getattr(tokenizer, 'get_vocab', lambda: {})()
+                 if '[INST]' in vocab or '[/INST]' in vocab:
+                     return True, 0.8
+             except Exception:
+                 pass
+
+         return False, 0.0
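A final sketch of the multi-turn prompt this formatter builds (hypothetical usage, assuming the package imports as `cortex`):

```python
# Hypothetical sketch: multi-turn Llama 2 prompt with the system prompt
# folded into the first [INST] block.
from cortex.template_registry.template_profiles.standard.llama import LlamaProfile

prompt = LlamaProfile().format_messages([
    {"role": "system", "content": "Be concise."},
    {"role": "user", "content": "What is 2+2?"},
    {"role": "assistant", "content": "4"},
    {"role": "user", "content": "And 3+3?"},
])
print(repr(prompt))
# '<s>[INST] <<SYS>>\nBe concise.\n<</SYS>>\n\nWhat is 2+2? [/INST] 4 </s><s>[INST] And 3+3? [/INST] '
```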