mlx-code 0.0.2a0__tar.gz → 0.0.2a2__tar.gz
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- {mlx_code-0.0.2a0/mlx_code.egg-info → mlx_code-0.0.2a2}/PKG-INFO +2 -14
- {mlx_code-0.0.2a0 → mlx_code-0.0.2a2}/README.md +1 -13
- {mlx_code-0.0.2a0 → mlx_code-0.0.2a2}/main.py +343 -98
- {mlx_code-0.0.2a0 → mlx_code-0.0.2a2/mlx_code.egg-info}/PKG-INFO +2 -14
- {mlx_code-0.0.2a0 → mlx_code-0.0.2a2}/setup.py +1 -1
- {mlx_code-0.0.2a0 → mlx_code-0.0.2a2}/LICENSE +0 -0
- {mlx_code-0.0.2a0 → mlx_code-0.0.2a2}/mlx_code.egg-info/SOURCES.txt +0 -0
- {mlx_code-0.0.2a0 → mlx_code-0.0.2a2}/mlx_code.egg-info/dependency_links.txt +0 -0
- {mlx_code-0.0.2a0 → mlx_code-0.0.2a2}/mlx_code.egg-info/entry_points.txt +0 -0
- {mlx_code-0.0.2a0 → mlx_code-0.0.2a2}/mlx_code.egg-info/requires.txt +0 -0
- {mlx_code-0.0.2a0 → mlx_code-0.0.2a2}/mlx_code.egg-info/top_level.txt +0 -0
- {mlx_code-0.0.2a0 → mlx_code-0.0.2a2}/setup.cfg +0 -0
|
@@ -1,6 +1,6 @@
|
|
|
1
1
|
Metadata-Version: 2.4
|
|
2
2
|
Name: mlx-code
|
|
3
|
-
Version: 0.0.2a0
|
|
3
|
+
Version: 0.0.2a2
|
|
4
4
|
Summary: Local Claude Code for Mac
|
|
5
5
|
Home-page: https://github.com/JosefAlbers/mlx-code
|
|
6
6
|
Author: J Joe
|
|
@@ -52,19 +52,7 @@ mlx-code [options] [-- claude options]
|
|
|
52
52
|
| `--work` | `$CWD` | Working directory mirrored into the Claude session |
|
|
53
53
|
| `--home` | temp dir | Home directory for the Claude process |
|
|
54
54
|
|
|
55
|
-
Any extra arguments after `--` are forwarded to the `claude` CLI
|
|
56
|
-
|
|
57
|
-
| Command | What it does | Example |
|
|
58
|
-
|--------|--------------|--------|
|
|
59
|
-
| `mlx-code` | Start interactive mode | `mlx-code` |
|
|
60
|
-
| `mlx-code "task"` | Run a one-time task | `mlx-code "fix the build error"` |
|
|
61
|
-
| `mlx-code -p "query"` | Run one-off query, then exit | `mlx-code -p "explain this function"` |
|
|
62
|
-
| `mlx-code -c` | Continue most recent conversation in current directory | `mlx-code -c` |
|
|
63
|
-
| `mlx-code -r` | Resume a previous conversation | `mlx-code -r` |
|
|
64
|
-
| `mlx-code commit` | Create a Git commit | `mlx-code commit` |
|
|
65
|
-
| `/clear` | Clear conversation history | `/clear` |
|
|
66
|
-
| `/help` | Show available commands | `/help` |
|
|
67
|
-
| `exit` or `Ctrl+C` | Exit Claude Code | `exit` |
|
|
55
|
+
Any extra arguments after `--` are forwarded to the `claude` CLI.
|
|
68
56
|
|
|
69
57
|
### Licence
|
|
70
58
|
|
|
@@ -28,19 +28,7 @@ mlx-code [options] [-- claude options]
|
|
|
28
28
|
| `--work` | `$CWD` | Working directory mirrored into the Claude session |
|
|
29
29
|
| `--home` | temp dir | Home directory for the Claude process |
|
|
30
30
|
|
|
31
|
-
Any extra arguments after `--` are forwarded to the `claude` CLI
|
|
32
|
-
|
|
33
|
-
| Command | What it does | Example |
|
|
34
|
-
|--------|--------------|--------|
|
|
35
|
-
| `mlx-code` | Start interactive mode | `mlx-code` |
|
|
36
|
-
| `mlx-code "task"` | Run a one-time task | `mlx-code "fix the build error"` |
|
|
37
|
-
| `mlx-code -p "query"` | Run one-off query, then exit | `mlx-code -p "explain this function"` |
|
|
38
|
-
| `mlx-code -c` | Continue most recent conversation in current directory | `mlx-code -c` |
|
|
39
|
-
| `mlx-code -r` | Resume a previous conversation | `mlx-code -r` |
|
|
40
|
-
| `mlx-code commit` | Create a Git commit | `mlx-code commit` |
|
|
41
|
-
| `/clear` | Clear conversation history | `/clear` |
|
|
42
|
-
| `/help` | Show available commands | `/help` |
|
|
43
|
-
| `exit` or `Ctrl+C` | Exit Claude Code | `exit` |
|
|
31
|
+
Any extra arguments after `--` are forwarded to the `claude` CLI.
|
|
44
32
|
|
|
45
33
|
### Licence
|
|
46
34
|
|
|
@@ -1,3 +1,4 @@
|
|
|
1
|
+
# {{{
|
|
1
2
|
# Copyright 2026 J Joe
|
|
2
3
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
3
4
|
# you may not use this file except in compliance with the License.
|
|
@@ -25,8 +26,22 @@ from http.server import BaseHTTPRequestHandler, HTTPServer
|
|
|
25
26
|
from pathlib import Path
|
|
26
27
|
import mlx.core as mx
|
|
27
28
|
import mlx_lm
|
|
28
|
-
|
|
29
|
+
import numpy as np
|
|
30
|
+
import hashlib
|
|
31
|
+
import contextlib
|
|
32
|
+
import functools
|
|
33
|
+
import mlx.nn as nn
|
|
34
|
+
from typing import (
|
|
35
|
+
Any,
|
|
36
|
+
Callable,
|
|
37
|
+
Generator,
|
|
38
|
+
List,
|
|
39
|
+
Optional,
|
|
40
|
+
Tuple,
|
|
41
|
+
Union,
|
|
42
|
+
)
|
|
29
43
|
|
|
44
|
+
generation_stream = mx.new_stream(mx.default_device())
|
|
30
45
|
stream_logger = logging.getLogger("stream")
|
|
31
46
|
stream_logger.setLevel(logging.DEBUG)
|
|
32
47
|
s_handler = logging.FileHandler("mlx_stream.log", mode='w')
|
|
@@ -39,7 +54,54 @@ t_handler = logging.FileHandler("mlx_trace.log", mode='w')
|
|
|
39
54
|
t_handler.setFormatter(logging.Formatter("【%(message)s\n】\n"))
|
|
40
55
|
trace_logger.addHandler(t_handler)
|
|
41
56
|
gen_lock = threading.Lock()
|
|
42
|
-
|
|
57
|
+
dict_cache = {}
|
|
58
|
+
|
|
59
|
+
def hash_tokens(tokens):
|
|
60
|
+
arr = np.array(tokens, dtype=np.uint32)
|
|
61
|
+
return hashlib.blake2b(arr.tobytes(), digest_size=8).hexdigest()
|
|
62
|
+
|
|
63
|
+
def get_common_len(a, b):
|
|
64
|
+
common_len = 0
|
|
65
|
+
for p, h in zip(a, b):
|
|
66
|
+
if p == h:
|
|
67
|
+
common_len += 1
|
|
68
|
+
else:
|
|
69
|
+
break
|
|
70
|
+
return common_len
|
|
71
|
+
|
|
72
|
+
@contextlib.contextmanager
|
|
73
|
+
def wired_limit(model: nn.Module, streams: Optional[List[mx.Stream]] = None):
|
|
74
|
+
if not mx.metal.is_available():
|
|
75
|
+
try:
|
|
76
|
+
yield
|
|
77
|
+
finally:
|
|
78
|
+
pass
|
|
79
|
+
else:
|
|
80
|
+
model_bytes = tree_reduce(
|
|
81
|
+
lambda acc, x: acc + x.nbytes if isinstance(x, mx.array) else acc, model, 0
|
|
82
|
+
)
|
|
83
|
+
max_rec_size = mx.device_info()["max_recommended_working_set_size"]
|
|
84
|
+
if model_bytes > 0.9 * max_rec_size:
|
|
85
|
+
model_mb = model_bytes // 2**20
|
|
86
|
+
max_rec_mb = max_rec_size // 2**20
|
|
87
|
+
print(f"{model_mb=} {max_rec_mb=}")
|
|
88
|
+
old_limit = mx.set_wired_limit(max_rec_size)
|
|
89
|
+
try:
|
|
90
|
+
yield
|
|
91
|
+
finally:
|
|
92
|
+
if streams is not None:
|
|
93
|
+
for s in streams:
|
|
94
|
+
mx.synchronize(s)
|
|
95
|
+
else:
|
|
96
|
+
mx.synchronize()
|
|
97
|
+
mx.set_wired_limit(old_limit)
|
|
98
|
+
|
|
99
|
+
def maybe_quantize_kv_cache(prompt_cache, quantized_kv_start, kv_group_size, kv_bits):
|
|
100
|
+
if kv_bits is None:
|
|
101
|
+
return
|
|
102
|
+
for e, c in enumerate(prompt_cache):
|
|
103
|
+
if hasattr(c, "to_quantized") and c.offset >= quantized_kv_start:
|
|
104
|
+
prompt_cache[e] = c.to_quantized(group_size=kv_group_size, bits=kv_bits)
|
|
43
105
|
|
|
44
106
|
def parse_tool(tools, names):
|
|
45
107
|
qwen_tools = []
|
|
@@ -47,17 +109,18 @@ def parse_tool(tools, names):
|
|
|
47
109
|
if names is not None and tool["name"] not in names:
|
|
48
110
|
continue
|
|
49
111
|
qwen_tool = {
|
|
50
|
-
"type": "function",
|
|
51
|
-
"function": {
|
|
112
|
+
# "type": "function",
|
|
113
|
+
# "function": {
|
|
52
114
|
"name": tool["name"],
|
|
53
115
|
"description": tool["description"],
|
|
54
116
|
"parameters": tool.get("input_schema", {
|
|
55
117
|
"type": "object",
|
|
56
118
|
"properties": {}
|
|
57
119
|
})
|
|
58
|
-
}
|
|
120
|
+
# }
|
|
59
121
|
}
|
|
60
|
-
params = qwen_tool["function"]["parameters"]
|
|
122
|
+
# params = qwen_tool["function"]["parameters"]
|
|
123
|
+
params = qwen_tool["parameters"]
|
|
61
124
|
params.pop("$schema", None)
|
|
62
125
|
qwen_tools.append(qwen_tool)
|
|
63
126
|
return qwen_tools
|
|
@@ -77,24 +140,21 @@ def encode(body, tokenizer, system, names, skips):
|
|
|
77
140
|
if block.get("type") != "text":
|
|
78
141
|
continue
|
|
79
142
|
text = block.get("text", "").strip()
|
|
80
|
-
if re.match(r'
|
|
143
|
+
if re.match(r'^x-anthropic-billing-header:\s?.*;$', text) and '\n' not in text:
|
|
81
144
|
continue
|
|
82
145
|
if text:
|
|
83
146
|
sys_parts.append(text)
|
|
84
147
|
if sys_parts:
|
|
85
148
|
msgs.append({"role": "system", "content": "\n\n".join(sys_parts)})
|
|
86
149
|
calls = {}
|
|
87
|
-
def skip(text, show_skipped=
|
|
150
|
+
def skip(text, show_skipped=False):
|
|
88
151
|
if skips is None:
|
|
89
152
|
return text
|
|
90
153
|
lines = []
|
|
91
154
|
for pattern in skips:
|
|
92
155
|
found = re.findall(pattern, text)
|
|
93
156
|
if found:
|
|
94
|
-
lines.append(
|
|
95
|
-
f"{pattern}\n" +
|
|
96
|
-
"\n".join(found)
|
|
97
|
-
)
|
|
157
|
+
lines.append(f"{pattern}\n" + "\n".join(found))
|
|
98
158
|
if lines and show_skipped:
|
|
99
159
|
trace_logger.debug("\n".join(["S"]+lines))
|
|
100
160
|
for pattern in skips:
|
|
@@ -109,7 +169,7 @@ def encode(body, tokenizer, system, names, skips):
|
|
|
109
169
|
for block in content:
|
|
110
170
|
t = block.get("type")
|
|
111
171
|
if t == "text":
|
|
112
|
-
parts['content'] = parts.get('content', '')
|
|
172
|
+
parts['content'] = parts.get('content', '') + skip(block['text'])
|
|
113
173
|
elif t == "thinking":
|
|
114
174
|
parts['reasoning_content'] = block['thinking']
|
|
115
175
|
elif t == "tool_use":
|
|
@@ -123,25 +183,39 @@ def encode(body, tokenizer, system, names, skips):
|
|
|
123
183
|
if parts:
|
|
124
184
|
msgs.append({"role": role}|parts)
|
|
125
185
|
if not msgs[-1].get('content', '').strip():
|
|
126
|
-
return None
|
|
127
|
-
|
|
186
|
+
return None, -1
|
|
187
|
+
apply_chat_template = lambda x: tokenizer.apply_chat_template(x, tools = parse_tool(body.get("tools", []), names), tokenize=False, add_generation_prompt=True)
|
|
188
|
+
full = apply_chat_template(msgs)
|
|
189
|
+
last_user_idx = max((i for i, m in enumerate(msgs) if m.get("role") == "user"), default=None)
|
|
190
|
+
if last_user_idx is None:
|
|
191
|
+
return full, -1
|
|
192
|
+
p_msgs = msgs[:last_user_idx] + [dict(role='user', content='h' if msgs[last_user_idx]['content'][0] != 'h' else 'i')]
|
|
193
|
+
pref = apply_chat_template(p_msgs)
|
|
194
|
+
return full, pref
|
|
128
195
|
|
|
129
|
-
def decode(raw_text, tokenizer, parse_think=
|
|
196
|
+
def decode(raw_text, tokenizer, parse_think, single_think=False):
|
|
197
|
+
def escape(text):
|
|
198
|
+
def repl(match):
|
|
199
|
+
inner = match.group(1)
|
|
200
|
+
inner = inner.replace('<', '‹').replace('>', '›')
|
|
201
|
+
return f'`{inner}`'
|
|
202
|
+
return re.sub(r'`([^\n`]*)`', repl, text)
|
|
203
|
+
raw_text = escape(raw_text)
|
|
130
204
|
raw_text = '<think>' + raw_text if (c := raw_text.find('</think>')) != -1 and ((o := raw_text.find('<think>')) == -1 or c < o) else raw_text
|
|
131
205
|
blocks = []
|
|
132
206
|
if parse_think:
|
|
133
|
-
parts = re.split(r'(<think>.*?</think>)', raw_text, flags=re.DOTALL)
|
|
207
|
+
parts = re.split(r'(<think>.*?</think>)', raw_text, flags=re.DOTALL, maxsplit=1 if single_think else 0)
|
|
134
208
|
else:
|
|
135
209
|
parts = [raw_text]
|
|
136
210
|
for part in parts:
|
|
137
211
|
if not part:
|
|
138
|
-
continue
|
|
139
|
-
if parse_think and part.startswith('<think>') and part.endswith('</think>'):
|
|
212
|
+
continue
|
|
213
|
+
if parse_think and not single_think and part.startswith('<think>') and part.endswith('</think>'):
|
|
140
214
|
thinking_content = part[7:-8].strip()
|
|
141
215
|
if thinking_content:
|
|
142
216
|
blocks.append({"type": "thinking", "thinking": thinking_content})
|
|
143
217
|
else:
|
|
144
|
-
blocks.append({"type": "text", "text": part})
|
|
218
|
+
blocks.append({"type": "text", "text": re.sub(r'</?think>', '‹think›', part)}) #: show tool call
|
|
145
219
|
tool_pattern = re.compile(r'<tool_call>(.*?)</tool_call>', re.DOTALL)
|
|
146
220
|
for match in tool_pattern.finditer(part):
|
|
147
221
|
content = match.group(1).strip()
|
|
@@ -185,7 +259,7 @@ def blocks_to_sse(blocks: list[dict], msg_id: str, in_tokens: int, out_tokens: i
|
|
|
185
259
|
elif bt == "tool_use":
|
|
186
260
|
out += event("content_block_start", {"type": "content_block_start", "index": i,
|
|
187
261
|
"content_block": {"type": "tool_use", "id": block["id"],
|
|
188
|
-
"name": block["name"], "input": {}}})
|
|
262
|
+
"name": block["name"], "input": {}} })
|
|
189
263
|
out += event("content_block_delta", {"type": "content_block_delta", "index": i,
|
|
190
264
|
"delta": {"type": "input_json_delta", "partial_json": json.dumps(block["input"])}})
|
|
191
265
|
out += event("content_block_stop", {"type": "content_block_stop", "index": i})
|
|
@@ -196,6 +270,8 @@ def blocks_to_sse(blocks: list[dict], msg_id: str, in_tokens: int, out_tokens: i
|
|
|
196
270
|
return bytes(out)
|
|
197
271
|
|
|
198
272
|
def dmca(p_str):
|
|
273
|
+
if True: #: False for recording
|
|
274
|
+
return p_str
|
|
199
275
|
symbols = ["▲", "△", "▶", "▷", "▼", "▽", "◀", "◁", "◆", "◇"]
|
|
200
276
|
def mask_text(text):
|
|
201
277
|
return re.sub(r"\S", lambda _: random.choice(symbols), text)
|
|
@@ -211,64 +287,6 @@ def dmca(p_str):
|
|
|
211
287
|
p_str = re.sub(pattern, lambda m: mask_text(m.group(0)), p_str)
|
|
212
288
|
return p_str
|
|
213
289
|
|
|
214
|
-
def generate(model, tokenizer, prompt, hook=None, max_tokens=256, helper_max_tokens=64, **kwargs):
|
|
215
|
-
global prompt_cache
|
|
216
|
-
if prompt is None:
|
|
217
|
-
return '', 0, 0
|
|
218
|
-
if not isinstance(tokenizer, mlx_lm.tokenizer_utils.TokenizerWrapper):
|
|
219
|
-
tokenizer = mlx_lm.tokenizer_utils.TokenizerWrapper(tokenizer)
|
|
220
|
-
detokenizer = tokenizer.detokenizer
|
|
221
|
-
if isinstance(prompt, str):
|
|
222
|
-
add_special_tokens = tokenizer.bos_token is None or not prompt.startswith(tokenizer.bos_token)
|
|
223
|
-
prompt_s = prompt
|
|
224
|
-
prompt = tokenizer.encode(prompt, add_special_tokens=add_special_tokens)
|
|
225
|
-
else:
|
|
226
|
-
prompt_s = tokenizer.decode(prompt)
|
|
227
|
-
stream_logger.debug(dmca(prompt_s))
|
|
228
|
-
common_len = 0
|
|
229
|
-
if prompt_cache.get('cache', None):
|
|
230
|
-
for p, h in zip(prompt, prompt_cache['hx']):
|
|
231
|
-
if p == h:
|
|
232
|
-
common_len += 1
|
|
233
|
-
else:
|
|
234
|
-
break
|
|
235
|
-
else:
|
|
236
|
-
prompt_cache['hx'] = []
|
|
237
|
-
prompt_cache['cache'] = mlx_lm.models.cache.make_prompt_cache(model)
|
|
238
|
-
trim_len = len(prompt_cache['hx']) - common_len
|
|
239
|
-
mlx_lm.models.cache.trim_prompt_cache(prompt_cache['cache'], trim_len)
|
|
240
|
-
token_gen = generate_step(
|
|
241
|
-
mx.array(prompt[common_len:]),
|
|
242
|
-
model,
|
|
243
|
-
prompt_cache=prompt_cache['cache'],
|
|
244
|
-
max_tokens=max_tokens,
|
|
245
|
-
**kwargs,
|
|
246
|
-
)
|
|
247
|
-
text = ""
|
|
248
|
-
tic_non = time.perf_counter()
|
|
249
|
-
gens = []
|
|
250
|
-
for token, _ in token_gen:
|
|
251
|
-
gens.append(token)
|
|
252
|
-
if token in tokenizer.eos_token_ids:
|
|
253
|
-
break
|
|
254
|
-
detokenizer.add_token(token)
|
|
255
|
-
seg = detokenizer.last_segment
|
|
256
|
-
stream_logger.debug(seg)
|
|
257
|
-
text += seg
|
|
258
|
-
if len(gens) == 1:
|
|
259
|
-
tic_inp = time.perf_counter()
|
|
260
|
-
if prompt_cache.get('file_name'):
|
|
261
|
-
_fn = prompt_cache.pop('file_name')
|
|
262
|
-
mlx_lm.models.cache.save_prompt_cache(_fn, prompt_cache['cache'], metadata=dict(model_name=prompt_cache['model_name'], hx=json.dumps(prompt)))
|
|
263
|
-
if len(gens) >= max_tokens:
|
|
264
|
-
break
|
|
265
|
-
tic_out = time.perf_counter()
|
|
266
|
-
detokenizer.finalize()
|
|
267
|
-
text += detokenizer.last_segment
|
|
268
|
-
prompt_cache['hx'] = prompt+gens
|
|
269
|
-
trace_logger.debug(f'G {common_len} {trim_len}\n=== TPS ===\n- Processed {len(prompt)} input tokens in {tic_inp-tic_non:.0f} seconds ({len(prompt)/(tic_inp-tic_non):.0f} tokens per second)\n- Generated {len(gens)} new tokens in {tic_out-tic_inp:.0f} seconds ({len(gens)/(tic_out-tic_inp):.0f} tokens per second)\n\n=== INP ===\n{prompt_s}\n=== OUT ===\n{text}')
|
|
270
|
-
return text, len(prompt), len(gens)
|
|
271
|
-
|
|
272
290
|
def make_handler(model, tokenizer, system, names, skips, parse_think=True):
|
|
273
291
|
class Handler(BaseHTTPRequestHandler):
|
|
274
292
|
def log_message(self, fmt, *args):
|
|
@@ -299,9 +317,9 @@ def make_handler(model, tokenizer, system, names, skips, parse_think=True):
|
|
|
299
317
|
return
|
|
300
318
|
n = int(self.headers.get("Content-Length", 0))
|
|
301
319
|
body = json.loads(self.rfile.read(n))
|
|
302
|
-
prompt = encode(body, tokenizer, system, names, skips)
|
|
320
|
+
prompt, pref = encode(body, tokenizer, system, names, skips)
|
|
303
321
|
with gen_lock:
|
|
304
|
-
raw, in_tokens, out_tokens = generate(model, tokenizer, prompt=prompt, max_tokens=body.get("max_tokens", 8192))
|
|
322
|
+
raw, in_tokens, out_tokens = generate(model, tokenizer, pref=pref, prompt=prompt, max_tokens=body.get("max_tokens", 8192))
|
|
305
323
|
blocks, stop_reason = decode(raw, tokenizer, parse_think=parse_think)
|
|
306
324
|
msg_id = f"msg_{uuid.uuid4().hex}"
|
|
307
325
|
sse = blocks_to_sse(blocks, msg_id, in_tokens, out_tokens, stop_reason)
|
|
@@ -317,18 +335,33 @@ def make_handler(model, tokenizer, system, names, skips, parse_think=True):
|
|
|
317
335
|
pass
|
|
318
336
|
return Handler
|
|
319
337
|
|
|
338
|
+
def load_dict_cache(cache_path):
|
|
339
|
+
global dict_cache
|
|
340
|
+
cache, metadata = mlx_lm.models.cache.load_prompt_cache(cache_path, return_metadata=True)
|
|
341
|
+
mx.eval(cache)
|
|
342
|
+
model_name = metadata.pop("model_name", "")
|
|
343
|
+
tokens_str = metadata.pop("hx", "[]")
|
|
344
|
+
tokens = json.loads(tokens_str)
|
|
345
|
+
dict_cache = dict(cache=cache, hx=tokens, model_name=model_name)
|
|
346
|
+
|
|
347
|
+
def save_dict_cache(cache_path, metadata, prompt_cache):
|
|
348
|
+
mlx_lm.models.cache.save_prompt_cache(cache_path, prompt_cache, metadata=metadata)
|
|
349
|
+
|
|
320
350
|
def main():
|
|
321
351
|
parser = argparse.ArgumentParser()
|
|
322
352
|
parser.add_argument("--model", default="mlx-community/Qwen3.5-4B-OptiQ-4bit")
|
|
323
353
|
# parser.add_argument("--model", default="mlx-community/Qwen3.5-2B-OptiQ-4bit")
|
|
324
354
|
# parser.add_argument("--model", default="mlx-community/Qwen3.5-0.8B-MLX-bf16")
|
|
325
|
-
parser.add_argument("--system", type=str, default='
|
|
355
|
+
parser.add_argument("--system", type=str, default='')
|
|
356
|
+
# parser.add_argument("--system", type=str, default='# Env\n{env}')
|
|
326
357
|
# parser.add_argument("--system", type=str, default=None)
|
|
327
|
-
parser.add_argument("--cache", type=str, default='cache/cache.safetensors')
|
|
358
|
+
# parser.add_argument("--cache", type=str, default='cache/cache.safetensors')
|
|
359
|
+
parser.add_argument("--cache", type=str, default='cache')
|
|
360
|
+
# parser.add_argument("--names", nargs="+", default=[])
|
|
328
361
|
parser.add_argument("--names", nargs="+", default=['Read','Edit','Write','Grep','Glob','Bash','Agent','Skill'])
|
|
329
362
|
# parser.add_argument("--names", nargs="+", default=None)
|
|
330
363
|
parser.add_argument("--skips", nargs="+", default=[
|
|
331
|
-
r'(?m)^\[SUGGESTION MODE[\s\S]*'
|
|
364
|
+
r'(?m)^\[SUGGESTION MODE[\s\S]*',
|
|
332
365
|
r'(?m)^<system-reminder>[\s\S]*?^</system-reminder>\s*',
|
|
333
366
|
])
|
|
334
367
|
parser.add_argument("--port", type=int, default=8000)
|
|
@@ -336,20 +369,9 @@ def main():
|
|
|
336
369
|
parser.add_argument("--home", default=tempfile.mkdtemp())
|
|
337
370
|
parser.add_argument("--work", default=os.getcwd())
|
|
338
371
|
args, claude_args = parser.parse_known_args()
|
|
339
|
-
|
|
340
|
-
|
|
341
|
-
|
|
342
|
-
mx.eval(cache)
|
|
343
|
-
model_name = metadata.pop("model_name", "")
|
|
344
|
-
tokens_str = metadata.pop("hx", "[]")
|
|
345
|
-
tokens = json.loads(tokens_str)
|
|
346
|
-
prompt_cache = dict(cache=cache, hx=tokens, model_name=model_name)
|
|
347
|
-
if prompt_cache.get('model_name') != args.model:
|
|
348
|
-
prompt_cache = dict(model_name=args.model)
|
|
349
|
-
else:
|
|
350
|
-
Path(args.cache).parent.mkdir(parents=True, exist_ok=True)
|
|
351
|
-
prompt_cache = dict(model_name=args.model)
|
|
352
|
-
prompt_cache['file_name']=args.cache
|
|
372
|
+
Path(args.cache).mkdir(parents=True, exist_ok=True)
|
|
373
|
+
global dict_cache
|
|
374
|
+
dict_cache = dict(model_name=args.model, cache_dir = args.cache)
|
|
353
375
|
model, tokenizer = mlx_lm.load(args.model)
|
|
354
376
|
server = HTTPServer((args.host, args.port), make_handler(model, tokenizer, args.system, args.names, args.skips))
|
|
355
377
|
threading.Thread(target=server.serve_forever, daemon=True).start()
|
|
@@ -369,5 +391,228 @@ def main():
|
|
|
369
391
|
mirror_workspace(args.work, workspace)
|
|
370
392
|
sys.exit(subprocess.run(["claude"] + claude_args, env=env, cwd=workspace).returncode)
|
|
371
393
|
|
|
394
|
+
def generate_step(
|
|
395
|
+
prompt: mx.array,
|
|
396
|
+
model: nn.Module,
|
|
397
|
+
*,
|
|
398
|
+
max_tokens: int = 256,
|
|
399
|
+
sampler: Optional[Callable[[mx.array], mx.array]] = None,
|
|
400
|
+
logits_processors: Optional[List[Callable[[mx.array, mx.array], mx.array]]] = None,
|
|
401
|
+
max_kv_size: Optional[int] = None,
|
|
402
|
+
prompt_cache: Optional[Any] = None,
|
|
403
|
+
prefill_step_size: int = 2048,
|
|
404
|
+
kv_bits: Optional[int] = None,
|
|
405
|
+
kv_group_size: int = 64,
|
|
406
|
+
quantized_kv_start: int = 0,
|
|
407
|
+
prompt_progress_callback: Optional[Callable[[int, int], None]] = None,
|
|
408
|
+
input_embeddings: Optional[mx.array] = None,
|
|
409
|
+
save_at: int = -1,
|
|
410
|
+
save_fn = None,
|
|
411
|
+
) -> Generator[Tuple[mx.array, mx.array], None, None]:
|
|
412
|
+
if input_embeddings is not None:
|
|
413
|
+
if not does_model_support_input_embeddings(model):
|
|
414
|
+
raise ValueError("Model does not support input embeddings.")
|
|
415
|
+
elif len(prompt) > 0 and len(prompt) != len(input_embeddings):
|
|
416
|
+
raise ValueError(f"{len(input_embeddings)=} {len(prompt)=}")
|
|
417
|
+
elif len(prompt) == 0:
|
|
418
|
+
raise ValueError("Either input_embeddings or prompt (or both) must be provided.")
|
|
419
|
+
|
|
420
|
+
tokens = None
|
|
421
|
+
|
|
422
|
+
if prompt_cache is None:
|
|
423
|
+
prompt_cache = cache.make_prompt_cache(
|
|
424
|
+
model,
|
|
425
|
+
max_kv_size=max_kv_size,
|
|
426
|
+
)
|
|
427
|
+
|
|
428
|
+
prompt_progress_callback = prompt_progress_callback or (lambda *_: None)
|
|
429
|
+
|
|
430
|
+
quantize_cache_fn = functools.partial(
|
|
431
|
+
maybe_quantize_kv_cache,
|
|
432
|
+
quantized_kv_start=quantized_kv_start,
|
|
433
|
+
kv_group_size=kv_group_size,
|
|
434
|
+
kv_bits=kv_bits,
|
|
435
|
+
)
|
|
436
|
+
|
|
437
|
+
sampler = sampler or (lambda x: mx.argmax(x, axis=-1))
|
|
438
|
+
|
|
439
|
+
def _model_call(input_tokens: mx.array, input_embeddings: Optional[mx.array]):
|
|
440
|
+
if input_embeddings is not None:
|
|
441
|
+
return model(
|
|
442
|
+
input_tokens, cache=prompt_cache, input_embeddings=input_embeddings
|
|
443
|
+
)
|
|
444
|
+
else:
|
|
445
|
+
return model(input_tokens, cache=prompt_cache)
|
|
446
|
+
|
|
447
|
+
def _step(input_tokens: mx.array, input_embeddings: Optional[mx.array] = None):
|
|
448
|
+
nonlocal tokens
|
|
449
|
+
|
|
450
|
+
with mx.stream(generation_stream):
|
|
451
|
+
logits = _model_call(
|
|
452
|
+
input_tokens=input_tokens[None],
|
|
453
|
+
input_embeddings=(
|
|
454
|
+
input_embeddings[None] if input_embeddings is not None else None
|
|
455
|
+
),
|
|
456
|
+
)
|
|
457
|
+
|
|
458
|
+
logits = logits[:, -1, :]
|
|
459
|
+
|
|
460
|
+
if logits_processors and len(input_tokens) > 0:
|
|
461
|
+
tokens = (
|
|
462
|
+
mx.concat([tokens, input_tokens])
|
|
463
|
+
if tokens is not None
|
|
464
|
+
else input_tokens
|
|
465
|
+
)
|
|
466
|
+
for processor in logits_processors:
|
|
467
|
+
logits = processor(tokens, logits)
|
|
468
|
+
|
|
469
|
+
quantize_cache_fn(prompt_cache)
|
|
470
|
+
logprobs = logits - mx.logsumexp(logits, keepdims=True)
|
|
471
|
+
sampled = sampler(logprobs)
|
|
472
|
+
return sampled, logprobs.squeeze(0)
|
|
473
|
+
|
|
474
|
+
with mx.stream(generation_stream):
|
|
475
|
+
total_prompt_tokens = (
|
|
476
|
+
len(input_embeddings) if input_embeddings is not None else len(prompt)
|
|
477
|
+
)
|
|
478
|
+
prompt_processed_tokens = 0
|
|
479
|
+
prompt_progress_callback(prompt_processed_tokens, total_prompt_tokens)
|
|
480
|
+
while total_prompt_tokens - prompt_processed_tokens > 1:
|
|
481
|
+
remaining = (total_prompt_tokens - prompt_processed_tokens) - 1
|
|
482
|
+
n_to_process = min(prefill_step_size, remaining)
|
|
483
|
+
if prompt_processed_tokens < save_at:
|
|
484
|
+
n_to_process = min(n_to_process, save_at - prompt_processed_tokens)
|
|
485
|
+
|
|
486
|
+
_model_call(
|
|
487
|
+
input_tokens=prompt[:n_to_process][None],
|
|
488
|
+
input_embeddings=(
|
|
489
|
+
input_embeddings[:n_to_process][None]
|
|
490
|
+
if input_embeddings is not None
|
|
491
|
+
else None
|
|
492
|
+
),
|
|
493
|
+
)
|
|
494
|
+
quantize_cache_fn(prompt_cache)
|
|
495
|
+
mx.eval([c.state for c in prompt_cache])
|
|
496
|
+
prompt_processed_tokens += n_to_process
|
|
497
|
+
prompt_progress_callback(prompt_processed_tokens, total_prompt_tokens)
|
|
498
|
+
prompt = prompt[n_to_process:]
|
|
499
|
+
input_embeddings = (
|
|
500
|
+
input_embeddings[n_to_process:]
|
|
501
|
+
if input_embeddings is not None
|
|
502
|
+
else input_embeddings
|
|
503
|
+
)
|
|
504
|
+
mx.clear_cache()
|
|
505
|
+
if save_fn is not None and prompt_processed_tokens == save_at:
|
|
506
|
+
save_fn(prompt_cache)
|
|
507
|
+
|
|
508
|
+
y, logprobs = _step(input_tokens=prompt, input_embeddings=input_embeddings)
|
|
509
|
+
|
|
510
|
+
mx.async_eval(y, logprobs)
|
|
511
|
+
n = 0
|
|
512
|
+
while True:
|
|
513
|
+
if n != max_tokens:
|
|
514
|
+
next_y, next_logprobs = _step(y)
|
|
515
|
+
mx.async_eval(next_y, next_logprobs)
|
|
516
|
+
if n == 0:
|
|
517
|
+
mx.eval(y)
|
|
518
|
+
prompt_progress_callback(total_prompt_tokens, total_prompt_tokens)
|
|
519
|
+
if n == max_tokens:
|
|
520
|
+
break
|
|
521
|
+
yield y.item(), logprobs
|
|
522
|
+
if n % 256 == 0:
|
|
523
|
+
mx.clear_cache()
|
|
524
|
+
y, logprobs = next_y, next_logprobs
|
|
525
|
+
n += 1
|
|
526
|
+
|
|
527
|
+
def generate(model, tokenizer, prompt, pref, hook=None, max_tokens=256, helper_max_tokens=64, **kwargs):
|
|
528
|
+
global dict_cache
|
|
529
|
+
if prompt is None:
|
|
530
|
+
return '', 0, 0
|
|
531
|
+
if not isinstance(tokenizer, mlx_lm.tokenizer_utils.TokenizerWrapper):
|
|
532
|
+
tokenizer = mlx_lm.tokenizer_utils.TokenizerWrapper(tokenizer)
|
|
533
|
+
detokenizer = tokenizer.detokenizer
|
|
534
|
+
if isinstance(prompt, str):
|
|
535
|
+
add_special_tokens = tokenizer.bos_token is None or not prompt.startswith(tokenizer.bos_token)
|
|
536
|
+
prompt_s = prompt
|
|
537
|
+
prompt = tokenizer.encode(prompt, add_special_tokens=add_special_tokens)
|
|
538
|
+
_pref = tokenizer.encode(pref, add_special_tokens=add_special_tokens)
|
|
539
|
+
save_at = get_common_len(prompt, _pref)
|
|
540
|
+
else:
|
|
541
|
+
prompt_s = tokenizer.decode(prompt)
|
|
542
|
+
save_at = -1 # □ for now
|
|
543
|
+
stream_logger.debug(dmca(prompt_s))
|
|
544
|
+
text = ''
|
|
545
|
+
gens = []
|
|
546
|
+
common_len = 0
|
|
547
|
+
hx_len = None
|
|
548
|
+
trim_len = None
|
|
549
|
+
save_fn = None
|
|
550
|
+
if not dict_cache.get('cache'):
|
|
551
|
+
ckpt_path = Path(dict_cache['cache_dir'])/f'{"".join(c for c in dict_cache["model_name"] if c.isalnum())}_{save_at}_{hash_tokens(prompt[:save_at])}.safetensors'
|
|
552
|
+
trace_logger.debug(ckpt_path.resolve())
|
|
553
|
+
trace_logger.debug(ckpt_path.absolute())
|
|
554
|
+
if os.path.exists(ckpt_path):
|
|
555
|
+
load_dict_cache(ckpt_path)
|
|
556
|
+
else:
|
|
557
|
+
dict_cache |= dict(cache=mlx_lm.models.cache.make_prompt_cache(model), hx=[])
|
|
558
|
+
save_fn = functools.partial(save_dict_cache, ckpt_path, dict(model_name=dict_cache['model_name'], hx=json.dumps(prompt[:save_at+1])))
|
|
559
|
+
|
|
560
|
+
if (hx := dict_cache.get('hx')):
|
|
561
|
+
_hx = hx[:-1]
|
|
562
|
+
common_len = get_common_len(prompt, _hx)
|
|
563
|
+
hx_len = len(_hx)
|
|
564
|
+
trim_len = hx_len - common_len
|
|
565
|
+
if trim_len > 0:
|
|
566
|
+
if all(c.is_trimmable() for c in dict_cache['cache']):
|
|
567
|
+
mlx_lm.models.cache.trim_prompt_cache(dict_cache['cache'], trim_len)
|
|
568
|
+
else:
|
|
569
|
+
ckpt_path = Path(dict_cache['cache_dir'])/f'{"".join(c for c in dict_cache["model_name"] if c.isalnum())}_{save_at}_{hash_tokens(prompt[:save_at])}.safetensors'
|
|
570
|
+
if os.path.exists(ckpt_path):
|
|
571
|
+
load_dict_cache(ckpt_path)
|
|
572
|
+
common_len = save_at
|
|
573
|
+
if save_at > common_len and not all(c.is_trimmable() for c in dict_cache['cache']):
|
|
574
|
+
ckpt_path = Path(dict_cache['cache_dir'])/f'{"".join(c for c in dict_cache["model_name"] if c.isalnum())}_{save_at}_{hash_tokens(prompt[:save_at])}.safetensors'
|
|
575
|
+
save_fn = functools.partial(save_dict_cache, ckpt_path, dict(model_name=dict_cache['model_name'], hx=json.dumps(prompt[:save_at+1])))
|
|
576
|
+
else:
|
|
577
|
+
save_at = -1
|
|
578
|
+
|
|
579
|
+
if common_len==len(prompt):
|
|
580
|
+
_last_gen = dict_cache['hx'][common_len]
|
|
581
|
+
prompt_arr = mx.array([_last_gen])
|
|
582
|
+
gens.append(_last_gen)
|
|
583
|
+
detokenizer.add(_last_gen)
|
|
584
|
+
else:
|
|
585
|
+
prompt_arr = mx.array(prompt[common_len:])
|
|
586
|
+
|
|
587
|
+
trace_logger.debug(f'{save_at=} {common_len=}')
|
|
588
|
+
token_gen = generate_step(
|
|
589
|
+
prompt_arr,
|
|
590
|
+
model,
|
|
591
|
+
prompt_cache=dict_cache['cache'],
|
|
592
|
+
max_tokens=max_tokens,
|
|
593
|
+
save_at=save_at-common_len,
|
|
594
|
+
save_fn=save_fn,
|
|
595
|
+
**kwargs,
|
|
596
|
+
)
|
|
597
|
+
tic_non = time.perf_counter()
|
|
598
|
+
for token, _ in token_gen:
|
|
599
|
+
gens.append(token)
|
|
600
|
+
if token in tokenizer.eos_token_ids:
|
|
601
|
+
break
|
|
602
|
+
detokenizer.add_token(token)
|
|
603
|
+
seg = detokenizer.last_segment
|
|
604
|
+
stream_logger.debug(seg)
|
|
605
|
+
text += seg
|
|
606
|
+
if len(gens) == 1:
|
|
607
|
+
tic_inp = time.perf_counter()
|
|
608
|
+
if len(gens) >= max_tokens:
|
|
609
|
+
break
|
|
610
|
+
tic_out = time.perf_counter()
|
|
611
|
+
detokenizer.finalize()
|
|
612
|
+
text += detokenizer.last_segment
|
|
613
|
+
dict_cache['hx'] = prompt+gens
|
|
614
|
+
trace_logger.debug(f'G {hx_len} {len(prompt)} {common_len} {trim_len} {len(gens)}\n=== TPS ===\n- Processed {len(prompt)} input tokens in {tic_inp-tic_non:.0f} seconds ({len(prompt)/(tic_inp-tic_non):.0f} tokens per second)\n- Generated {len(gens)} new tokens in {tic_out-tic_inp:.0f} seconds ({len(gens)/(tic_out-tic_inp):.0f} tokens per second)\n\n=== INP ===\n{dmca(prompt_s)}\n=== OUT ===\n{text}')
|
|
615
|
+
return text, len(prompt), len(gens)
|
|
616
|
+
|
|
372
617
|
if __name__ == "__main__":
|
|
373
618
|
main()
|
|
@@ -1,6 +1,6 @@
|
|
|
1
1
|
Metadata-Version: 2.4
|
|
2
2
|
Name: mlx-code
|
|
3
|
-
Version: 0.0.2a0
|
|
3
|
+
Version: 0.0.2a2
|
|
4
4
|
Summary: Local Claude Code for Mac
|
|
5
5
|
Home-page: https://github.com/JosefAlbers/mlx-code
|
|
6
6
|
Author: J Joe
|
|
@@ -52,19 +52,7 @@ mlx-code [options] [-- claude options]
|
|
|
52
52
|
| `--work` | `$CWD` | Working directory mirrored into the Claude session |
|
|
53
53
|
| `--home` | temp dir | Home directory for the Claude process |
|
|
54
54
|
|
|
55
|
-
Any extra arguments after `--` are forwarded to the `claude` CLI
|
|
56
|
-
|
|
57
|
-
| Command | What it does | Example |
|
|
58
|
-
|--------|--------------|--------|
|
|
59
|
-
| `mlx-code` | Start interactive mode | `mlx-code` |
|
|
60
|
-
| `mlx-code "task"` | Run a one-time task | `mlx-code "fix the build error"` |
|
|
61
|
-
| `mlx-code -p "query"` | Run one-off query, then exit | `mlx-code -p "explain this function"` |
|
|
62
|
-
| `mlx-code -c` | Continue most recent conversation in current directory | `mlx-code -c` |
|
|
63
|
-
| `mlx-code -r` | Resume a previous conversation | `mlx-code -r` |
|
|
64
|
-
| `mlx-code commit` | Create a Git commit | `mlx-code commit` |
|
|
65
|
-
| `/clear` | Clear conversation history | `/clear` |
|
|
66
|
-
| `/help` | Show available commands | `/help` |
|
|
67
|
-
| `exit` or `Ctrl+C` | Exit Claude Code | `exit` |
|
|
55
|
+
Any extra arguments after `--` are forwarded to the `claude` CLI.
|
|
68
56
|
|
|
69
57
|
### Licence
|
|
70
58
|
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|