mlx-code 0.0.2a1__tar.gz → 0.0.2a3__tar.gz
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- {mlx_code-0.0.2a1/mlx_code.egg-info → mlx_code-0.0.2a3}/PKG-INFO +5 -15
- {mlx_code-0.0.2a1 → mlx_code-0.0.2a3}/README.md +5 -15
- {mlx_code-0.0.2a1 → mlx_code-0.0.2a3}/main.py +350 -101
- {mlx_code-0.0.2a1 → mlx_code-0.0.2a3/mlx_code.egg-info}/PKG-INFO +5 -15
- {mlx_code-0.0.2a1 → mlx_code-0.0.2a3}/mlx_code.egg-info/SOURCES.txt +1 -0
- {mlx_code-0.0.2a1 → mlx_code-0.0.2a3}/mlx_code.egg-info/top_level.txt +1 -0
- mlx_code-0.0.2a3/pie.py +832 -0
- {mlx_code-0.0.2a1 → mlx_code-0.0.2a3}/setup.py +2 -2
- {mlx_code-0.0.2a1 → mlx_code-0.0.2a3}/LICENSE +0 -0
- {mlx_code-0.0.2a1 → mlx_code-0.0.2a3}/mlx_code.egg-info/dependency_links.txt +0 -0
- {mlx_code-0.0.2a1 → mlx_code-0.0.2a3}/mlx_code.egg-info/entry_points.txt +0 -0
- {mlx_code-0.0.2a1 → mlx_code-0.0.2a3}/mlx_code.egg-info/requires.txt +0 -0
- {mlx_code-0.0.2a1 → mlx_code-0.0.2a3}/setup.cfg +0 -0
{mlx_code-0.0.2a1/mlx_code.egg-info → mlx_code-0.0.2a3}/PKG-INFO

@@ -1,6 +1,6 @@
 Metadata-Version: 2.4
 Name: mlx-code
-Version: 0.0.2a1
+Version: 0.0.2a3
 Summary: Local Claude Code for Mac
 Home-page: https://github.com/JosefAlbers/mlx-code
 Author: J Joe
@@ -52,22 +52,12 @@ mlx-code [options] [-- claude options]
 | `--work` | `$CWD` | Working directory mirrored into the Claude session |
 | `--home` | temp dir | Home directory for the Claude process |
 
-Any extra arguments after `--` are forwarded to the `claude` CLI
+Any extra arguments after `--` are forwarded to the `claude` CLI.
 
-
-
-
-| `mlx-code "task"` | Run a one-time task | `mlx-code "fix the build error"` |
-| `mlx-code -p "query"` | Run one-off query, then exit | `mlx-code -p "explain this function"` |
-| `mlx-code -c` | Continue most recent conversation in current directory | `mlx-code -c` |
-| `mlx-code -r` | Resume a previous conversation | `mlx-code -r` |
-| `mlx-code commit` | Create a Git commit | `mlx-code commit` |
-| `/clear` | Clear conversation history | `/clear` |
-| `/help` | Show available commands | `/help` |
-| `exit` or `Ctrl+C` | Exit Claude Code | `exit` |
+### Credits
+
+`pie.py` is based on [pi](https://github.com/badlogic/pi-mono) by Mario Zechner, used under the MIT License.
 
 ### Licence
 
 Apache License 2.0 — see LICENSE for details.
-
-
{mlx_code-0.0.2a1 → mlx_code-0.0.2a3}/README.md

@@ -28,22 +28,12 @@ mlx-code [options] [-- claude options]
 | `--work` | `$CWD` | Working directory mirrored into the Claude session |
 | `--home` | temp dir | Home directory for the Claude process |
 
-Any extra arguments after `--` are forwarded to the `claude` CLI
-
-| Command | What it does | Example |
-|--------|--------------|--------|
-| `mlx-code` | Start interactive mode | `mlx-code` |
-| `mlx-code "task"` | Run a one-time task | `mlx-code "fix the build error"` |
-| `mlx-code -p "query"` | Run one-off query, then exit | `mlx-code -p "explain this function"` |
-| `mlx-code -c` | Continue most recent conversation in current directory | `mlx-code -c` |
-| `mlx-code -r` | Resume a previous conversation | `mlx-code -r` |
-| `mlx-code commit` | Create a Git commit | `mlx-code commit` |
-| `/clear` | Clear conversation history | `/clear` |
-| `/help` | Show available commands | `/help` |
-| `exit` or `Ctrl+C` | Exit Claude Code | `exit` |
+Any extra arguments after `--` are forwarded to the `claude` CLI.
 
-###
+### Credits
 
-
+`pie.py` is based on [pi](https://github.com/badlogic/pi-mono) by Mario Zechner, used under the MIT License.
 
+### Licence
 
+Apache License 2.0 — see LICENSE for details.
{mlx_code-0.0.2a1 → mlx_code-0.0.2a3}/main.py

@@ -25,8 +25,22 @@ from http.server import BaseHTTPRequestHandler, HTTPServer
 from pathlib import Path
 import mlx.core as mx
 import mlx_lm
-
+import numpy as np
+import hashlib
+import contextlib
+import functools
+import mlx.nn as nn
+from typing import (
+    Any,
+    Callable,
+    Generator,
+    List,
+    Optional,
+    Tuple,
+    Union,
+)
 
+generation_stream = mx.new_stream(mx.default_device())
 stream_logger = logging.getLogger("stream")
 stream_logger.setLevel(logging.DEBUG)
 s_handler = logging.FileHandler("mlx_stream.log", mode='w')
@@ -39,7 +53,54 @@ t_handler = logging.FileHandler("mlx_trace.log", mode='w')
 t_handler.setFormatter(logging.Formatter("【%(message)s\n】\n"))
 trace_logger.addHandler(t_handler)
 gen_lock = threading.Lock()
-
+dict_cache = {}
+
+def hash_tokens(tokens):
+    arr = np.array(tokens, dtype=np.uint32)
+    return hashlib.blake2b(arr.tobytes(), digest_size=8).hexdigest()
+
+def get_common_len(a, b):
+    common_len = 0
+    for p, h in zip(a, b):
+        if p == h:
+            common_len += 1
+        else:
+            break
+    return common_len
+
+@contextlib.contextmanager
+def wired_limit(model: nn.Module, streams: Optional[List[mx.Stream]] = None):
+    if not mx.metal.is_available():
+        try:
+            yield
+        finally:
+            pass
+    else:
+        model_bytes = tree_reduce(
+            lambda acc, x: acc + x.nbytes if isinstance(x, mx.array) else acc, model, 0
+        )
+        max_rec_size = mx.device_info()["max_recommended_working_set_size"]
+        if model_bytes > 0.9 * max_rec_size:
+            model_mb = model_bytes // 2**20
+            max_rec_mb = max_rec_size // 2**20
+            print(f"{model_mb=} {max_rec_mb=}")
+        old_limit = mx.set_wired_limit(max_rec_size)
+        try:
+            yield
+        finally:
+            if streams is not None:
+                for s in streams:
+                    mx.synchronize(s)
+            else:
+                mx.synchronize()
+            mx.set_wired_limit(old_limit)
+
+def maybe_quantize_kv_cache(prompt_cache, quantized_kv_start, kv_group_size, kv_bits):
+    if kv_bits is None:
+        return
+    for e, c in enumerate(prompt_cache):
+        if hasattr(c, "to_quantized") and c.offset >= quantized_kv_start:
+            prompt_cache[e] = c.to_quantized(group_size=kv_group_size, bits=kv_bits)
 
 def parse_tool(tools, names):
     qwen_tools = []
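The helpers added in this hunk (`hash_tokens`, `get_common_len`) underpin the prompt-cache reuse introduced further down in main.py: the shared token prefix between the cached history and the incoming prompt decides how much of the KV cache survives, and a short blake2b digest of that prefix names the on-disk checkpoint. A minimal, self-contained illustration (the token ids below are made up):

```python
import hashlib
import numpy as np

def hash_tokens(tokens):
    # stable 8-byte digest over uint32 token ids, as in the hunk above
    arr = np.array(tokens, dtype=np.uint32)
    return hashlib.blake2b(arr.tobytes(), digest_size=8).hexdigest()

def get_common_len(a, b):
    # length of the shared prefix of two token sequences, as in the hunk above
    common_len = 0
    for p, h in zip(a, b):
        if p == h:
            common_len += 1
        else:
            break
    return common_len

# hypothetical token ids: two successive requests that share a system/tool preamble
previous = [101, 7, 7, 42, 9, 13]
current = [101, 7, 7, 42, 55, 60, 61]

prefix_len = get_common_len(previous, current)  # 4
cache_key = hash_tokens(current[:prefix_len])   # later used to name checkpoint files
print(prefix_len, cache_key)
```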
@@ -121,11 +182,17 @@ def encode(body, tokenizer, system, names, skips):
         if parts:
             msgs.append({"role": role}|parts)
     if not msgs[-1].get('content', '').strip():
-        return None
-
+        return None, ''
+    apply_chat_template = lambda x: tokenizer.apply_chat_template(x, tools = parse_tool(body.get("tools", []), names), tokenize=False, add_generation_prompt=True)
+    full = apply_chat_template(msgs)
+    last_user_idx = max((i for i, m in enumerate(msgs) if m.get("role") == "user"), default=None)
+    if last_user_idx is None:
+        return full, ''
+    p_msgs = msgs[:last_user_idx] + [dict(role='user', content='h' if msgs[last_user_idx]['content'][0] != 'h' else 'i')]
+    pref = apply_chat_template(p_msgs)
+    return full, pref
 
 def decode(raw_text, tokenizer, parse_think, single_think=False):
-    # think_id = tokenizer.convert_tokens_to_ids("<think>")
     def escape(text):
         def repl(match):
             inner = match.group(1)
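The reworked `encode` now returns a pair `(full, pref)`: the chat template is rendered once with the real message list and once with the last user turn swapped for a single throw-away character, so everything the two renderings share is, by construction, independent of the newest user input and marks the cacheable prefix. A toy sketch of the idea (the real code uses `tokenizer.apply_chat_template` and later compares token ids, not characters):

```python
def apply_chat_template(msgs):
    # toy stand-in for tokenizer.apply_chat_template(..., add_generation_prompt=True)
    rendered = "".join(f"<|{m['role']}|>{m['content']}<|end|>" for m in msgs)
    return rendered + "<|assistant|>"

msgs = [
    {"role": "system", "content": "You are a coding assistant."},
    {"role": "user", "content": "first question"},
    {"role": "assistant", "content": "first answer"},
    {"role": "user", "content": "second question"},
]

full = apply_chat_template(msgs)
last = max(i for i, m in enumerate(msgs) if m["role"] == "user")
# one-character dummy chosen to differ from the real first character, as in the diff
dummy = "h" if msgs[last]["content"][0] != "h" else "i"
pref = apply_chat_template(msgs[:last] + [{"role": "user", "content": dummy}])

# the shared prefix of `full` and `pref` ends right where the new user content begins
common = next((i for i, (a, b) in enumerate(zip(full, pref)) if a != b), min(len(full), len(pref)))
print(full[:common])
```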
@@ -139,7 +206,6 @@ def decode(raw_text, tokenizer, parse_think, single_think=False):
         parts = re.split(r'(<think>.*?</think>)', raw_text, flags=re.DOTALL, maxsplit=1 if single_think else 0)
     else:
         parts = [raw_text]
-
     for part in parts:
         if not part:
             continue
@@ -192,7 +258,7 @@ def blocks_to_sse(blocks: list[dict], msg_id: str, in_tokens: int, out_tokens: i
         elif bt == "tool_use":
             out += event("content_block_start", {"type": "content_block_start", "index": i,
                 "content_block": {"type": "tool_use", "id": block["id"],
-                "name": block["name"], "input": {}}})
+                "name": block["name"], "input": {}} })
             out += event("content_block_delta", {"type": "content_block_delta", "index": i,
                 "delta": {"type": "input_json_delta", "partial_json": json.dumps(block["input"])}})
             out += event("content_block_stop", {"type": "content_block_stop", "index": i})
@@ -203,6 +269,8 @@ def blocks_to_sse(blocks: list[dict], msg_id: str, in_tokens: int, out_tokens: i
     return bytes(out)
 
 def dmca(p_str):
+    if True: #: False for recording
+        return p_str
     symbols = ["▲", "△", "▶", "▷", "▼", "▽", "◀", "◁", "◆", "◇"]
     def mask_text(text):
         return re.sub(r"\S", lambda _: random.choice(symbols), text)
@@ -218,66 +286,6 @@ def dmca(p_str):
     p_str = re.sub(pattern, lambda m: mask_text(m.group(0)), p_str)
     return p_str
 
-def generate(model, tokenizer, prompt, hook=None, max_tokens=256, helper_max_tokens=64, **kwargs):
-    global prompt_cache
-    if prompt is None:
-        return '', 0, 0
-    if not isinstance(tokenizer, mlx_lm.tokenizer_utils.TokenizerWrapper):
-        tokenizer = mlx_lm.tokenizer_utils.TokenizerWrapper(tokenizer)
-    detokenizer = tokenizer.detokenizer
-    if isinstance(prompt, str):
-        add_special_tokens = tokenizer.bos_token is None or not prompt.startswith(tokenizer.bos_token)
-        prompt_s = prompt
-        prompt = tokenizer.encode(prompt, add_special_tokens=add_special_tokens)
-    else:
-        prompt_s = tokenizer.decode(prompt)
-    stream_logger.debug(dmca(prompt_s))
-    common_len = 0
-    if prompt_cache.get('cache', None):
-        for p, h in zip(prompt, prompt_cache['hx']):
-            if p == h:
-                common_len += 1
-            else:
-                break
-        common_len = min(common_len, len(prompt) - 1)
-    else:
-        prompt_cache['hx'] = []
-        prompt_cache['cache'] = mlx_lm.models.cache.make_prompt_cache(model)
-    hx_len = len(prompt_cache['hx'])
-    trim_len = hx_len - common_len
-    mlx_lm.models.cache.trim_prompt_cache(prompt_cache['cache'], trim_len)
-    token_gen = generate_step(
-        mx.array(prompt[common_len:]),
-        model,
-        prompt_cache=prompt_cache['cache'],
-        max_tokens=max_tokens,
-        **kwargs,
-    )
-    text = ""
-    tic_non = time.perf_counter()
-    gens = []
-    for token, _ in token_gen:
-        gens.append(token)
-        if token in tokenizer.eos_token_ids:
-            break
-        detokenizer.add_token(token)
-        seg = detokenizer.last_segment
-        stream_logger.debug(seg)
-        text += seg
-        if len(gens) == 1:
-            tic_inp = time.perf_counter()
-        if prompt_cache.get('file_name'):
-            _fn = prompt_cache.pop('file_name')
-            mlx_lm.models.cache.save_prompt_cache(_fn, prompt_cache['cache'], metadata=dict(model_name=prompt_cache['model_name'], hx=json.dumps(prompt+gens)))
-        if len(gens) >= max_tokens:
-            break
-    tic_out = time.perf_counter()
-    detokenizer.finalize()
-    text += detokenizer.last_segment
-    prompt_cache['hx'] = prompt+gens
-    trace_logger.debug(f'G {hx_len} {len(prompt)} {common_len} {trim_len} {len(gens)}\n=== TPS ===\n- Processed {len(prompt)} input tokens in {tic_inp-tic_non:.0f} seconds ({len(prompt)/(tic_inp-tic_non):.0f} tokens per second)\n- Generated {len(gens)} new tokens in {tic_out-tic_inp:.0f} seconds ({len(gens)/(tic_out-tic_inp):.0f} tokens per second)\n\n=== INP ===\n{dmca(prompt_s)}\n=== OUT ===\n{text}')
-    return text, len(prompt), len(gens)
-
 def make_handler(model, tokenizer, system, names, skips, parse_think=True):
     class Handler(BaseHTTPRequestHandler):
         def log_message(self, fmt, *args):
@@ -308,12 +316,13 @@ def make_handler(model, tokenizer, system, names, skips, parse_think=True):
                 return
             n = int(self.headers.get("Content-Length", 0))
             body = json.loads(self.rfile.read(n))
-            prompt = encode(body, tokenizer, system, names, skips)
+            prompt, pref = encode(body, tokenizer, system, names, skips)
             with gen_lock:
-                raw, in_tokens, out_tokens = generate(model, tokenizer, prompt=prompt, max_tokens=body.get("max_tokens", 8192))
+                raw, in_tokens, out_tokens = generate(model, tokenizer, pref=pref, prompt=prompt, max_tokens=body.get("max_tokens", 8192))
             blocks, stop_reason = decode(raw, tokenizer, parse_think=parse_think)
             msg_id = f"msg_{uuid.uuid4().hex}"
             sse = blocks_to_sse(blocks, msg_id, in_tokens, out_tokens, stop_reason)
+            trace_logger.debug(sse)
             self.send_response(200)
             self.send_header("Content-Type", "text/event-stream")
             self.send_header("Cache-Control", "no-cache")
@@ -326,15 +335,29 @@ def make_handler(model, tokenizer, system, names, skips, parse_think=True):
             pass
     return Handler
 
+def load_dict_cache(cache_path):
+    cache, metadata = mlx_lm.models.cache.load_prompt_cache(cache_path, return_metadata=True)
+    mx.eval(cache)
+    model_name = metadata.pop("model_name", "")
+    tokens_str = metadata.pop("hx", "[]")
+    tokens = json.loads(tokens_str)
+    global dict_cache
+    dict_cache |= dict(cache=cache, hx=tokens, model_name=model_name)
+
+def save_dict_cache(cache_path, metadata, prompt_cache):
+    mlx_lm.models.cache.save_prompt_cache(cache_path, prompt_cache, metadata=metadata)
+
 def main():
     parser = argparse.ArgumentParser()
+    parser.add_argument("--harness", default=None)
+    # parser.add_argument("--harness", default="claude")
     parser.add_argument("--model", default="mlx-community/Qwen3.5-4B-OptiQ-4bit")
     # parser.add_argument("--model", default="mlx-community/Qwen3.5-2B-OptiQ-4bit")
     # parser.add_argument("--model", default="mlx-community/Qwen3.5-0.8B-MLX-bf16")
     parser.add_argument("--system", type=str, default='')
     # parser.add_argument("--system", type=str, default='# Env\n{env}')
     # parser.add_argument("--system", type=str, default=None)
-    parser.add_argument("--cache", type=str, default='cache
+    parser.add_argument("--cache", type=str, default='cache')
     # parser.add_argument("--names", nargs="+", default=[])
     parser.add_argument("--names", nargs="+", default=['Read','Edit','Write','Grep','Glob','Bash','Agent','Skill'])
     # parser.add_argument("--names", nargs="+", default=None)
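`load_dict_cache` and `save_dict_cache` are thin wrappers over `mlx_lm.models.cache.load_prompt_cache` / `save_prompt_cache`; the string metadata (`model_name`, plus the JSON-encoded token history `hx`) is what lets a later process decide whether a checkpoint is reusable. A hedged sketch of the round trip, assuming the default model named in this diff is available locally (any mlx_lm-loadable model would work the same way):

```python
import json
import mlx.core as mx
import mlx_lm
from mlx_lm.models import cache

model_name = "mlx-community/Qwen3.5-4B-OptiQ-4bit"  # default from the diff, illustrative only
model, tokenizer = mlx_lm.load(model_name)

tokens = tokenizer.encode("You are a coding assistant.")
prompt_cache = cache.make_prompt_cache(model)
model(mx.array(tokens)[None], cache=prompt_cache)  # prefill to populate the KV cache
mx.eval([c.state for c in prompt_cache])

cache.save_prompt_cache(
    "prefix.safetensors",
    prompt_cache,
    metadata={"model_name": model_name, "hx": json.dumps(tokens)},
)

restored, metadata = cache.load_prompt_cache("prefix.safetensors", return_metadata=True)
print(metadata["model_name"], len(json.loads(metadata["hx"])))
```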
@@ -346,39 +369,265 @@ def main():
     parser.add_argument("--host", default="127.0.0.1")
     parser.add_argument("--home", default=tempfile.mkdtemp())
     parser.add_argument("--work", default=os.getcwd())
-
-
-
-
-
-    model_name = metadata.pop("model_name", "")
-    tokens_str = metadata.pop("hx", "[]")
-    tokens = json.loads(tokens_str)
-    prompt_cache = dict(cache=cache, hx=tokens, model_name=model_name)
-    if prompt_cache.get('model_name') != args.model:
-        prompt_cache = dict(model_name=args.model)
-    else:
-        Path(args.cache).parent.mkdir(parents=True, exist_ok=True)
-        prompt_cache = dict(model_name=args.model)
-        prompt_cache['file_name']=args.cache
+    parser.add_argument("--nocc", action="store_true", help="Disable Claude Code subprocess and run server only")
+    args, harness_args = parser.parse_known_args()
+    Path(args.cache).mkdir(parents=True, exist_ok=True)
+    global dict_cache
+    dict_cache = dict(model_name=args.model, cache_dir = args.cache)
     model, tokenizer = mlx_lm.load(args.model)
     server = HTTPServer((args.host, args.port), make_handler(model, tokenizer, args.system, args.names, args.skips))
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
+    if args.nocc:
+        try:
+            server.serve_forever()
+        except KeyboardInterrupt:
+            print("\nShutting down server...")
+            server.server_close()
+    else:
+        threading.Thread(target=server.serve_forever, daemon=True).start()
+        env = os.environ.copy()
+        env["ANTHROPIC_BASE_URL"] = f"http://{args.host}:{args.port}"
+        env["ANTHROPIC_AUTH_TOKEN"] = "local"
+        env["ANTHROPIC_MODEL"] = args.model
+        env["ANTHROPIC_SMALL_FAST_MODEL"] = args.model
+        env["HOME"] = args.home
+        def mirror_workspace(src: str, dst: str):
+            for root, dirs, files in os.walk(src):
+                rel = os.path.relpath(root, src)
+                os.makedirs(os.path.join(dst, rel), exist_ok=True)
+                for f in files:
+                    os.link(os.path.join(root, f), os.path.join(dst, rel, f))
+        workspace = os.path.join(args.home, "workspace")
+        mirror_workspace(args.work, workspace)
+        if args.harness is None:
+            from pie import run_repl
+            run_repl(base_url=f"http://{args.host}:{args.port}/v1")
+        else:
+            sys.exit(subprocess.run([args.harness] + harness_args, env=env, cwd=workspace).returncode)
+
+def generate_step(
+    prompt: mx.array,
+    model: nn.Module,
+    *,
+    max_tokens: int = 256,
+    sampler: Optional[Callable[[mx.array], mx.array]] = None,
+    logits_processors: Optional[List[Callable[[mx.array, mx.array], mx.array]]] = None,
+    max_kv_size: Optional[int] = None,
+    prompt_cache: Optional[Any] = None,
+    prefill_step_size: int = 2048,
+    kv_bits: Optional[int] = None,
+    kv_group_size: int = 64,
+    quantized_kv_start: int = 0,
+    prompt_progress_callback: Optional[Callable[[int, int], None]] = None,
+    input_embeddings: Optional[mx.array] = None,
+    save_at: int = -1,
+    save_fn = None,
+) -> Generator[Tuple[mx.array, mx.array], None, None]:
+    if input_embeddings is not None:
+        if not does_model_support_input_embeddings(model):
+            raise ValueError("Model does not support input embeddings.")
+        elif len(prompt) > 0 and len(prompt) != len(input_embeddings):
+            raise ValueError(f"{len(input_embeddings)=} {len(prompt)=}")
+    elif len(prompt) == 0:
+        raise ValueError("Either input_embeddings or prompt (or both) must be provided.")
+
+    tokens = None
+
+    if prompt_cache is None:
+        prompt_cache = cache.make_prompt_cache(
+            model,
+            max_kv_size=max_kv_size,
+        )
+
+    prompt_progress_callback = prompt_progress_callback or (lambda *_: None)
+
+    quantize_cache_fn = functools.partial(
+        maybe_quantize_kv_cache,
+        quantized_kv_start=quantized_kv_start,
+        kv_group_size=kv_group_size,
+        kv_bits=kv_bits,
+    )
+
+    sampler = sampler or (lambda x: mx.argmax(x, axis=-1))
+
+    def _model_call(input_tokens: mx.array, input_embeddings: Optional[mx.array]):
+        if input_embeddings is not None:
+            return model(
+                input_tokens, cache=prompt_cache, input_embeddings=input_embeddings
+            )
+        else:
+            return model(input_tokens, cache=prompt_cache)
+
+    def _step(input_tokens: mx.array, input_embeddings: Optional[mx.array] = None):
+        nonlocal tokens
+
+        with mx.stream(generation_stream):
+            logits = _model_call(
+                input_tokens=input_tokens[None],
+                input_embeddings=(
+                    input_embeddings[None] if input_embeddings is not None else None
+                ),
+            )
+
+            logits = logits[:, -1, :]
+
+            if logits_processors and len(input_tokens) > 0:
+                tokens = (
+                    mx.concat([tokens, input_tokens])
+                    if tokens is not None
+                    else input_tokens
+                )
+                for processor in logits_processors:
+                    logits = processor(tokens, logits)
+
+            quantize_cache_fn(prompt_cache)
+            logprobs = logits - mx.logsumexp(logits, keepdims=True)
+            sampled = sampler(logprobs)
+            return sampled, logprobs.squeeze(0)
+
+    with mx.stream(generation_stream):
+        total_prompt_tokens = (
+            len(input_embeddings) if input_embeddings is not None else len(prompt)
+        )
+        prompt_processed_tokens = 0
+        prompt_progress_callback(prompt_processed_tokens, total_prompt_tokens)
+        while total_prompt_tokens - prompt_processed_tokens > 1:
+            remaining = (total_prompt_tokens - prompt_processed_tokens) - 1
+            n_to_process = min(prefill_step_size, remaining)
+            if prompt_processed_tokens < save_at:
+                n_to_process = min(n_to_process, save_at - prompt_processed_tokens)
+
+            _model_call(
+                input_tokens=prompt[:n_to_process][None],
+                input_embeddings=(
+                    input_embeddings[:n_to_process][None]
+                    if input_embeddings is not None
+                    else None
+                ),
+            )
+            quantize_cache_fn(prompt_cache)
+            mx.eval([c.state for c in prompt_cache])
+            prompt_processed_tokens += n_to_process
+            prompt_progress_callback(prompt_processed_tokens, total_prompt_tokens)
+            prompt = prompt[n_to_process:]
+            input_embeddings = (
+                input_embeddings[n_to_process:]
+                if input_embeddings is not None
+                else input_embeddings
+            )
+            mx.clear_cache()
+            if save_fn is not None and prompt_processed_tokens == save_at:
+                save_fn(prompt_cache)
+
+        y, logprobs = _step(input_tokens=prompt, input_embeddings=input_embeddings)
+
+    mx.async_eval(y, logprobs)
+    n = 0
+    while True:
+        if n != max_tokens:
+            next_y, next_logprobs = _step(y)
+            mx.async_eval(next_y, next_logprobs)
+        if n == 0:
+            mx.eval(y)
+            prompt_progress_callback(total_prompt_tokens, total_prompt_tokens)
+        if n == max_tokens:
+            break
+        yield y.item(), logprobs
+        if n % 256 == 0:
+            mx.clear_cache()
+        y, logprobs = next_y, next_logprobs
+        n += 1
+
+def generate(model, tokenizer, prompt, pref, hook=None, max_tokens=256, helper_max_tokens=64, **kwargs):
+    global dict_cache
+    if prompt is None:
+        return '', 0, 0
+    if not isinstance(tokenizer, mlx_lm.tokenizer_utils.TokenizerWrapper):
+        tokenizer = mlx_lm.tokenizer_utils.TokenizerWrapper(tokenizer)
+    detokenizer = tokenizer.detokenizer
+    if isinstance(prompt, str):
+        add_special_tokens = tokenizer.bos_token is None or not prompt.startswith(tokenizer.bos_token)
+        prompt_s = prompt
+        prompt = tokenizer.encode(prompt, add_special_tokens=add_special_tokens)
+        _pref = tokenizer.encode(pref, add_special_tokens=add_special_tokens)
+        save_at = get_common_len(prompt, _pref)
+    else:
+        prompt_s = tokenizer.decode(prompt)
+        save_at = -1 # □ for now
+    stream_logger.debug(dmca(prompt_s))
+    text = ''
+    gens = []
+    common_len = 0
+    hx_len = None
+    trim_len = None
+    save_fn = None
+    if not dict_cache.get('cache'):
+        ckpt_path = Path(dict_cache['cache_dir'])/f'{"".join(c for c in dict_cache["model_name"] if c.isalnum())}_{save_at}_{hash_tokens(prompt[:save_at])}.safetensors'
+        if os.path.exists(ckpt_path):
+            load_dict_cache(ckpt_path)
+        else:
+            dict_cache |= dict(cache=mlx_lm.models.cache.make_prompt_cache(model), hx=[])
+            save_fn = functools.partial(save_dict_cache, ckpt_path, dict(model_name=dict_cache['model_name'], hx=json.dumps(prompt[:save_at+1])))
+
+    if (hx := dict_cache.get('hx')):
+        _hx = hx[:-1]
+        common_len = get_common_len(prompt, _hx)
+        hx_len = len(_hx)
+        trim_len = hx_len - common_len
+        if trim_len > 0:
+            if all(c.is_trimmable() for c in dict_cache['cache']):
+                mlx_lm.models.cache.trim_prompt_cache(dict_cache['cache'], trim_len)
+            else:
+                ckpt_path = Path(dict_cache['cache_dir'])/f'{"".join(c for c in dict_cache["model_name"] if c.isalnum())}_{save_at}_{hash_tokens(prompt[:save_at])}.safetensors'
+                if os.path.exists(ckpt_path):
+                    load_dict_cache(ckpt_path)
+                    common_len = save_at
+                else:
+                    dict_cache |= dict(cache=mlx_lm.models.cache.make_prompt_cache(model), hx=[])
+                    save_fn = functools.partial(save_dict_cache, ckpt_path, dict(model_name=dict_cache['model_name'], hx=json.dumps(prompt[:save_at+1])))
+                    common_len = 0
+
+    if save_at > common_len and not all(c.is_trimmable() for c in dict_cache['cache']):
+        ckpt_path = Path(dict_cache['cache_dir'])/f'{"".join(c for c in dict_cache["model_name"] if c.isalnum())}_{save_at}_{hash_tokens(prompt[:save_at])}.safetensors'
+        save_fn = functools.partial(save_dict_cache, ckpt_path, dict(model_name=dict_cache['model_name'], hx=json.dumps(prompt[:save_at+1])))
+    else:
+        save_at = -1
+
+    if common_len==len(prompt):
+        _last_gen = dict_cache['hx'][common_len]
+        prompt_arr = mx.array([_last_gen])
+        gens.append(_last_gen)
+        detokenizer.add(_last_gen)
+    else:
+        prompt_arr = mx.array(prompt[common_len:])
+
+    token_gen = generate_step(
+        prompt_arr,
+        model,
+        prompt_cache=dict_cache['cache'],
+        max_tokens=max_tokens,
+        save_at=save_at-common_len,
+        save_fn=save_fn,
+        **kwargs,
+    )
+    tic_non = time.perf_counter()
+    for token, _ in token_gen:
+        gens.append(token)
+        if token in tokenizer.eos_token_ids:
+            break
+        detokenizer.add_token(token)
+        seg = detokenizer.last_segment
+        stream_logger.debug(seg)
+        text += seg
+        if len(gens) == 1:
+            tic_inp = time.perf_counter()
+        if len(gens) >= max_tokens:
+            break
+    tic_out = time.perf_counter()
+    detokenizer.finalize()
+    text += detokenizer.last_segment
+    dict_cache['hx'] = prompt+gens
+    trace_logger.debug(f'G {hx_len} {len(prompt)} {common_len} {trim_len} {len(gens)}\n=== TPS ===\n- Processed {len(prompt)} input tokens in {tic_inp-tic_non:.0f} seconds ({len(prompt)/(tic_inp-tic_non):.0f} tokens per second)\n- Generated {len(gens)} new tokens in {tic_out-tic_inp:.0f} seconds ({len(gens)/(tic_out-tic_inp):.0f} tokens per second)\n\n=== INP ===\n{dmca(prompt_s)}\n=== OUT ===\n{text}')
+    return text, len(prompt), len(gens)
 
 if __name__ == "__main__":
     main()
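The `save_at` plumbing added to `generate_step` exists so the prefill loop stops exactly at the boundary computed from `pref` and snapshots the KV cache there via `save_fn`, even when `prefill_step_size` would otherwise step across it. A worked example of the clamping with made-up sizes:

```python
# prompt of 5000 tokens, checkpoint wanted after the first 3100 (the static prefix),
# prefill chunks of 2048 -- mirrors the loop structure added to generate_step above
prefill_step_size, save_at, total = 2048, 3100, 5000
processed = 0
while total - processed > 1:
    remaining = (total - processed) - 1
    n = min(prefill_step_size, remaining)
    if processed < save_at:
        n = min(n, save_at - processed)  # never step across the checkpoint boundary
    processed += n
    print(processed, "<- save_fn(prompt_cache) fires here" if processed == save_at else "")
```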
{mlx_code-0.0.2a1 → mlx_code-0.0.2a3/mlx_code.egg-info}/PKG-INFO

@@ -1,6 +1,6 @@
 Metadata-Version: 2.4
 Name: mlx-code
-Version: 0.0.2a1
+Version: 0.0.2a3
 Summary: Local Claude Code for Mac
 Home-page: https://github.com/JosefAlbers/mlx-code
 Author: J Joe

@@ -52,22 +52,12 @@ mlx-code [options] [-- claude options]
 | `--work` | `$CWD` | Working directory mirrored into the Claude session |
 | `--home` | temp dir | Home directory for the Claude process |
 
-Any extra arguments after `--` are forwarded to the `claude` CLI
+Any extra arguments after `--` are forwarded to the `claude` CLI.
 
-
-
-
-| `mlx-code "task"` | Run a one-time task | `mlx-code "fix the build error"` |
-| `mlx-code -p "query"` | Run one-off query, then exit | `mlx-code -p "explain this function"` |
-| `mlx-code -c` | Continue most recent conversation in current directory | `mlx-code -c` |
-| `mlx-code -r` | Resume a previous conversation | `mlx-code -r` |
-| `mlx-code commit` | Create a Git commit | `mlx-code commit` |
-| `/clear` | Clear conversation history | `/clear` |
-| `/help` | Show available commands | `/help` |
-| `exit` or `Ctrl+C` | Exit Claude Code | `exit` |
+### Credits
+
+`pie.py` is based on [pi](https://github.com/badlogic/pi-mono) by Mario Zechner, used under the MIT License.
 
 ### Licence
 
 Apache License 2.0 — see LICENSE for details.
-
-