mlx-code 0.0.2a1__tar.gz → 0.0.2a2__tar.gz
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- {mlx_code-0.0.2a1/mlx_code.egg-info → mlx_code-0.0.2a2}/PKG-INFO +2 -14
- {mlx_code-0.0.2a1 → mlx_code-0.0.2a2}/README.md +1 -13
- {mlx_code-0.0.2a1 → mlx_code-0.0.2a2}/main.py +318 -84
- {mlx_code-0.0.2a1 → mlx_code-0.0.2a2/mlx_code.egg-info}/PKG-INFO +2 -14
- {mlx_code-0.0.2a1 → mlx_code-0.0.2a2}/setup.py +1 -1
- {mlx_code-0.0.2a1 → mlx_code-0.0.2a2}/LICENSE +0 -0
- {mlx_code-0.0.2a1 → mlx_code-0.0.2a2}/mlx_code.egg-info/SOURCES.txt +0 -0
- {mlx_code-0.0.2a1 → mlx_code-0.0.2a2}/mlx_code.egg-info/dependency_links.txt +0 -0
- {mlx_code-0.0.2a1 → mlx_code-0.0.2a2}/mlx_code.egg-info/entry_points.txt +0 -0
- {mlx_code-0.0.2a1 → mlx_code-0.0.2a2}/mlx_code.egg-info/requires.txt +0 -0
- {mlx_code-0.0.2a1 → mlx_code-0.0.2a2}/mlx_code.egg-info/top_level.txt +0 -0
- {mlx_code-0.0.2a1 → mlx_code-0.0.2a2}/setup.cfg +0 -0
|
@@ -1,6 +1,6 @@
|
|
|
1
1
|
Metadata-Version: 2.4
|
|
2
2
|
Name: mlx-code
|
|
3
|
-
Version: 0.0.2a1
|
|
3
|
+
Version: 0.0.2a2
|
|
4
4
|
Summary: Local Claude Code for Mac
|
|
5
5
|
Home-page: https://github.com/JosefAlbers/mlx-code
|
|
6
6
|
Author: J Joe
|
|
@@ -52,19 +52,7 @@ mlx-code [options] [-- claude options]
|
|
|
52
52
|
| `--work` | `$CWD` | Working directory mirrored into the Claude session |
|
|
53
53
|
| `--home` | temp dir | Home directory for the Claude process |
|
|
54
54
|
|
|
55
|
-
Any extra arguments after `--` are forwarded to the `claude` CLI
|
|
56
|
-
|
|
57
|
-
| Command | What it does | Example |
|
|
58
|
-
|--------|--------------|--------|
|
|
59
|
-
| `mlx-code` | Start interactive mode | `mlx-code` |
|
|
60
|
-
| `mlx-code "task"` | Run a one-time task | `mlx-code "fix the build error"` |
|
|
61
|
-
| `mlx-code -p "query"` | Run one-off query, then exit | `mlx-code -p "explain this function"` |
|
|
62
|
-
| `mlx-code -c` | Continue most recent conversation in current directory | `mlx-code -c` |
|
|
63
|
-
| `mlx-code -r` | Resume a previous conversation | `mlx-code -r` |
|
|
64
|
-
| `mlx-code commit` | Create a Git commit | `mlx-code commit` |
|
|
65
|
-
| `/clear` | Clear conversation history | `/clear` |
|
|
66
|
-
| `/help` | Show available commands | `/help` |
|
|
67
|
-
| `exit` or `Ctrl+C` | Exit Claude Code | `exit` |
|
|
55
|
+
Any extra arguments after `--` are forwarded to the `claude` CLI.
|
|
68
56
|
|
|
69
57
|
### Licence
|
|
70
58
|
|
|
@@ -28,19 +28,7 @@ mlx-code [options] [-- claude options]
|
|
|
28
28
|
| `--work` | `$CWD` | Working directory mirrored into the Claude session |
|
|
29
29
|
| `--home` | temp dir | Home directory for the Claude process |
|
|
30
30
|
|
|
31
|
-
Any extra arguments after `--` are forwarded to the `claude` CLI
|
|
32
|
-
|
|
33
|
-
| Command | What it does | Example |
|
|
34
|
-
|--------|--------------|--------|
|
|
35
|
-
| `mlx-code` | Start interactive mode | `mlx-code` |
|
|
36
|
-
| `mlx-code "task"` | Run a one-time task | `mlx-code "fix the build error"` |
|
|
37
|
-
| `mlx-code -p "query"` | Run one-off query, then exit | `mlx-code -p "explain this function"` |
|
|
38
|
-
| `mlx-code -c` | Continue most recent conversation in current directory | `mlx-code -c` |
|
|
39
|
-
| `mlx-code -r` | Resume a previous conversation | `mlx-code -r` |
|
|
40
|
-
| `mlx-code commit` | Create a Git commit | `mlx-code commit` |
|
|
41
|
-
| `/clear` | Clear conversation history | `/clear` |
|
|
42
|
-
| `/help` | Show available commands | `/help` |
|
|
43
|
-
| `exit` or `Ctrl+C` | Exit Claude Code | `exit` |
|
|
31
|
+
Any extra arguments after `--` are forwarded to the `claude` CLI.
|
|
44
32
|
|
|
45
33
|
### Licence
|
|
46
34
|
|
|
@@ -1,3 +1,4 @@
|
|
|
1
|
+
# {{{
|
|
1
2
|
# Copyright 2026 J Joe
|
|
2
3
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
3
4
|
# you may not use this file except in compliance with the License.
|
|
@@ -25,8 +26,22 @@ from http.server import BaseHTTPRequestHandler, HTTPServer
|
|
|
25
26
|
from pathlib import Path
|
|
26
27
|
import mlx.core as mx
|
|
27
28
|
import mlx_lm
|
|
28
|
-
|
|
29
|
+
import numpy as np
|
|
30
|
+
import hashlib
|
|
31
|
+
import contextlib
|
|
32
|
+
import functools
|
|
33
|
+
import mlx.nn as nn
|
|
34
|
+
from typing import (
|
|
35
|
+
Any,
|
|
36
|
+
Callable,
|
|
37
|
+
Generator,
|
|
38
|
+
List,
|
|
39
|
+
Optional,
|
|
40
|
+
Tuple,
|
|
41
|
+
Union,
|
|
42
|
+
)
|
|
29
43
|
|
|
44
|
+
generation_stream = mx.new_stream(mx.default_device())
|
|
30
45
|
stream_logger = logging.getLogger("stream")
|
|
31
46
|
stream_logger.setLevel(logging.DEBUG)
|
|
32
47
|
s_handler = logging.FileHandler("mlx_stream.log", mode='w')
|
|
@@ -39,7 +54,54 @@ t_handler = logging.FileHandler("mlx_trace.log", mode='w')
|
|
|
39
54
|
t_handler.setFormatter(logging.Formatter("【%(message)s\n】\n"))
|
|
40
55
|
trace_logger.addHandler(t_handler)
|
|
41
56
|
gen_lock = threading.Lock()
|
|
42
|
-
|
|
57
|
+
dict_cache = {}
|
|
58
|
+
|
|
59
|
+
def hash_tokens(tokens):
    """Return a short hexadecimal fingerprint of a token-id sequence.

    The ids are packed into an unsigned 32-bit NumPy array and the raw bytes
    of that array are hashed with BLAKE2b using an 8-byte digest, so the
    result is always a 16-character hex string.
    """
    packed = np.array(tokens, dtype=np.uint32).tobytes()
    digest = hashlib.blake2b(packed, digest_size=8)
    return digest.hexdigest()
|
|
62
|
+
|
|
63
|
+
def get_common_len(a, b):
    """Return the length of the longest common prefix of ``a`` and ``b``.

    The two iterables are walked in lockstep; counting stops at the first
    position where the elements are not equal, or when the shorter one ends.
    """
    matched = 0
    for left, right in zip(a, b):
        if not (left == right):
            return matched
        matched += 1
    return matched
|
|
71
|
+
|
|
72
|
+
@contextlib.contextmanager
def wired_limit(model: nn.Module, streams: Optional[List[mx.Stream]] = None):
    """Temporarily set the wired-memory limit while the ``with`` body runs.

    When Metal is not available the context manager does nothing. Otherwise
    the wired limit is set to the device's
    ``max_recommended_working_set_size`` for the duration of the body and the
    previous limit is restored on exit.

    Args:
        model: Module whose ``mx.array`` leaves are summed (by ``nbytes``) to
            estimate the model size.
        streams: Streams to synchronize before the old limit is restored.
            When ``None``, ``mx.synchronize()`` is called with no argument.
    """
    if not mx.metal.is_available():
        # Nothing to configure without Metal; just run the body.
        yield
        return

    # Function-scope import: no module-level import of ``tree_reduce`` is
    # visible among the file's import lines, and ``mlx`` is already a
    # dependency of this file.
    from mlx.utils import tree_reduce

    model_bytes = tree_reduce(
        lambda acc, x: acc + x.nbytes if isinstance(x, mx.array) else acc, model, 0
    )
    max_rec_size = mx.device_info()["max_recommended_working_set_size"]
    if model_bytes > 0.9 * max_rec_size:
        # Report sizes (in MiB) when the model takes more than 90% of the
        # recommended working set.
        model_mb = model_bytes // 2**20
        max_rec_mb = max_rec_size // 2**20
        print(f"{model_mb=} {max_rec_mb=}")
    old_limit = mx.set_wired_limit(max_rec_size)
    try:
        yield
    finally:
        # Synchronize before restoring the previous limit.
        if streams is not None:
            for s in streams:
                mx.synchronize(s)
        else:
            mx.synchronize()
        mx.set_wired_limit(old_limit)
|
|
98
|
+
|
|
99
|
+
def maybe_quantize_kv_cache(prompt_cache, quantized_kv_start, kv_group_size, kv_bits):
    """Replace eligible cache entries with their quantized form, in place.

    Does nothing when ``kv_bits`` is ``None``. Otherwise every entry that
    exposes ``to_quantized`` and whose ``offset`` has reached
    ``quantized_kv_start`` is swapped for
    ``entry.to_quantized(group_size=kv_group_size, bits=kv_bits)``.
    """
    if kv_bits is None:
        return
    for idx, entry in enumerate(prompt_cache):
        if not hasattr(entry, "to_quantized"):
            continue
        if entry.offset < quantized_kv_start:
            continue
        prompt_cache[idx] = entry.to_quantized(group_size=kv_group_size, bits=kv_bits)
|
|
43
105
|
|
|
44
106
|
def parse_tool(tools, names):
|
|
45
107
|
qwen_tools = []
|
|
@@ -121,11 +183,17 @@ def encode(body, tokenizer, system, names, skips):
|
|
|
121
183
|
if parts:
|
|
122
184
|
msgs.append({"role": role}|parts)
|
|
123
185
|
if not msgs[-1].get('content', '').strip():
|
|
124
|
-
return None
|
|
125
|
-
|
|
186
|
+
return None, -1
|
|
187
|
+
apply_chat_template = lambda x: tokenizer.apply_chat_template(x, tools = parse_tool(body.get("tools", []), names), tokenize=False, add_generation_prompt=True)
|
|
188
|
+
full = apply_chat_template(msgs)
|
|
189
|
+
last_user_idx = max((i for i, m in enumerate(msgs) if m.get("role") == "user"), default=None)
|
|
190
|
+
if last_user_idx is None:
|
|
191
|
+
return full, -1
|
|
192
|
+
p_msgs = msgs[:last_user_idx] + [dict(role='user', content='h' if msgs[last_user_idx]['content'][0] != 'h' else 'i')]
|
|
193
|
+
pref = apply_chat_template(p_msgs)
|
|
194
|
+
return full, pref
|
|
126
195
|
|
|
127
196
|
def decode(raw_text, tokenizer, parse_think, single_think=False):
|
|
128
|
-
# think_id = tokenizer.convert_tokens_to_ids("<think>")
|
|
129
197
|
def escape(text):
|
|
130
198
|
def repl(match):
|
|
131
199
|
inner = match.group(1)
|
|
@@ -139,7 +207,6 @@ def decode(raw_text, tokenizer, parse_think, single_think=False):
|
|
|
139
207
|
parts = re.split(r'(<think>.*?</think>)', raw_text, flags=re.DOTALL, maxsplit=1 if single_think else 0)
|
|
140
208
|
else:
|
|
141
209
|
parts = [raw_text]
|
|
142
|
-
|
|
143
210
|
for part in parts:
|
|
144
211
|
if not part:
|
|
145
212
|
continue
|
|
@@ -192,7 +259,7 @@ def blocks_to_sse(blocks: list[dict], msg_id: str, in_tokens: int, out_tokens: i
|
|
|
192
259
|
elif bt == "tool_use":
|
|
193
260
|
out += event("content_block_start", {"type": "content_block_start", "index": i,
|
|
194
261
|
"content_block": {"type": "tool_use", "id": block["id"],
|
|
195
|
-
"name": block["name"], "input": {}}})
|
|
262
|
+
"name": block["name"], "input": {}} })
|
|
196
263
|
out += event("content_block_delta", {"type": "content_block_delta", "index": i,
|
|
197
264
|
"delta": {"type": "input_json_delta", "partial_json": json.dumps(block["input"])}})
|
|
198
265
|
out += event("content_block_stop", {"type": "content_block_stop", "index": i})
|
|
@@ -203,6 +270,8 @@ def blocks_to_sse(blocks: list[dict], msg_id: str, in_tokens: int, out_tokens: i
|
|
|
203
270
|
return bytes(out)
|
|
204
271
|
|
|
205
272
|
def dmca(p_str):
|
|
273
|
+
if True: #: False for recording
|
|
274
|
+
return p_str
|
|
206
275
|
symbols = ["▲", "△", "▶", "▷", "▼", "▽", "◀", "◁", "◆", "◇"]
|
|
207
276
|
def mask_text(text):
|
|
208
277
|
return re.sub(r"\S", lambda _: random.choice(symbols), text)
|
|
@@ -218,66 +287,6 @@ def dmca(p_str):
|
|
|
218
287
|
p_str = re.sub(pattern, lambda m: mask_text(m.group(0)), p_str)
|
|
219
288
|
return p_str
|
|
220
289
|
|
|
221
|
-
def generate(model, tokenizer, prompt, hook=None, max_tokens=256, helper_max_tokens=64, **kwargs):
|
|
222
|
-
global prompt_cache
|
|
223
|
-
if prompt is None:
|
|
224
|
-
return '', 0, 0
|
|
225
|
-
if not isinstance(tokenizer, mlx_lm.tokenizer_utils.TokenizerWrapper):
|
|
226
|
-
tokenizer = mlx_lm.tokenizer_utils.TokenizerWrapper(tokenizer)
|
|
227
|
-
detokenizer = tokenizer.detokenizer
|
|
228
|
-
if isinstance(prompt, str):
|
|
229
|
-
add_special_tokens = tokenizer.bos_token is None or not prompt.startswith(tokenizer.bos_token)
|
|
230
|
-
prompt_s = prompt
|
|
231
|
-
prompt = tokenizer.encode(prompt, add_special_tokens=add_special_tokens)
|
|
232
|
-
else:
|
|
233
|
-
prompt_s = tokenizer.decode(prompt)
|
|
234
|
-
stream_logger.debug(dmca(prompt_s))
|
|
235
|
-
common_len = 0
|
|
236
|
-
if prompt_cache.get('cache', None):
|
|
237
|
-
for p, h in zip(prompt, prompt_cache['hx']):
|
|
238
|
-
if p == h:
|
|
239
|
-
common_len += 1
|
|
240
|
-
else:
|
|
241
|
-
break
|
|
242
|
-
common_len = min(common_len, len(prompt) - 1)
|
|
243
|
-
else:
|
|
244
|
-
prompt_cache['hx'] = []
|
|
245
|
-
prompt_cache['cache'] = mlx_lm.models.cache.make_prompt_cache(model)
|
|
246
|
-
hx_len = len(prompt_cache['hx'])
|
|
247
|
-
trim_len = hx_len - common_len
|
|
248
|
-
mlx_lm.models.cache.trim_prompt_cache(prompt_cache['cache'], trim_len)
|
|
249
|
-
token_gen = generate_step(
|
|
250
|
-
mx.array(prompt[common_len:]),
|
|
251
|
-
model,
|
|
252
|
-
prompt_cache=prompt_cache['cache'],
|
|
253
|
-
max_tokens=max_tokens,
|
|
254
|
-
**kwargs,
|
|
255
|
-
)
|
|
256
|
-
text = ""
|
|
257
|
-
tic_non = time.perf_counter()
|
|
258
|
-
gens = []
|
|
259
|
-
for token, _ in token_gen:
|
|
260
|
-
gens.append(token)
|
|
261
|
-
if token in tokenizer.eos_token_ids:
|
|
262
|
-
break
|
|
263
|
-
detokenizer.add_token(token)
|
|
264
|
-
seg = detokenizer.last_segment
|
|
265
|
-
stream_logger.debug(seg)
|
|
266
|
-
text += seg
|
|
267
|
-
if len(gens) == 1:
|
|
268
|
-
tic_inp = time.perf_counter()
|
|
269
|
-
if prompt_cache.get('file_name'):
|
|
270
|
-
_fn = prompt_cache.pop('file_name')
|
|
271
|
-
mlx_lm.models.cache.save_prompt_cache(_fn, prompt_cache['cache'], metadata=dict(model_name=prompt_cache['model_name'], hx=json.dumps(prompt+gens)))
|
|
272
|
-
if len(gens) >= max_tokens:
|
|
273
|
-
break
|
|
274
|
-
tic_out = time.perf_counter()
|
|
275
|
-
detokenizer.finalize()
|
|
276
|
-
text += detokenizer.last_segment
|
|
277
|
-
prompt_cache['hx'] = prompt+gens
|
|
278
|
-
trace_logger.debug(f'G {hx_len} {len(prompt)} {common_len} {trim_len} {len(gens)}\n=== TPS ===\n- Processed {len(prompt)} input tokens in {tic_inp-tic_non:.0f} seconds ({len(prompt)/(tic_inp-tic_non):.0f} tokens per second)\n- Generated {len(gens)} new tokens in {tic_out-tic_inp:.0f} seconds ({len(gens)/(tic_out-tic_inp):.0f} tokens per second)\n\n=== INP ===\n{dmca(prompt_s)}\n=== OUT ===\n{text}')
|
|
279
|
-
return text, len(prompt), len(gens)
|
|
280
|
-
|
|
281
290
|
def make_handler(model, tokenizer, system, names, skips, parse_think=True):
|
|
282
291
|
class Handler(BaseHTTPRequestHandler):
|
|
283
292
|
def log_message(self, fmt, *args):
|
|
@@ -308,9 +317,9 @@ def make_handler(model, tokenizer, system, names, skips, parse_think=True):
|
|
|
308
317
|
return
|
|
309
318
|
n = int(self.headers.get("Content-Length", 0))
|
|
310
319
|
body = json.loads(self.rfile.read(n))
|
|
311
|
-
prompt = encode(body, tokenizer, system, names, skips)
|
|
320
|
+
prompt, pref = encode(body, tokenizer, system, names, skips)
|
|
312
321
|
with gen_lock:
|
|
313
|
-
raw, in_tokens, out_tokens = generate(model, tokenizer, prompt=prompt, max_tokens=body.get("max_tokens", 8192))
|
|
322
|
+
raw, in_tokens, out_tokens = generate(model, tokenizer, pref=pref, prompt=prompt, max_tokens=body.get("max_tokens", 8192))
|
|
314
323
|
blocks, stop_reason = decode(raw, tokenizer, parse_think=parse_think)
|
|
315
324
|
msg_id = f"msg_{uuid.uuid4().hex}"
|
|
316
325
|
sse = blocks_to_sse(blocks, msg_id, in_tokens, out_tokens, stop_reason)
|
|
@@ -326,6 +335,18 @@ def make_handler(model, tokenizer, system, names, skips, parse_think=True):
|
|
|
326
335
|
pass
|
|
327
336
|
return Handler
|
|
328
337
|
|
|
338
|
+
def load_dict_cache(cache_path):
    """Load a saved prompt cache from ``cache_path`` into the global ``dict_cache``.

    The file is read with ``mlx_lm.models.cache.load_prompt_cache`` and the
    loaded cache is evaluated immediately. Two metadata entries are used:
    ``"model_name"`` (default ``""``) and ``"hx"``, a JSON-encoded token list
    (default ``"[]"``).

    The ``cache``, ``hx`` and ``model_name`` entries of ``dict_cache`` are
    overwritten; every other entry already present (for example the
    ``cache_dir`` set at start-up) is kept, because later code still reads
    it after a checkpoint has been loaded.
    """
    global dict_cache
    cache, metadata = mlx_lm.models.cache.load_prompt_cache(cache_path, return_metadata=True)
    mx.eval(cache)
    model_name = metadata.pop("model_name", "")
    tokens = json.loads(metadata.pop("hx", "[]"))
    # Merge instead of rebinding to a fresh dict so unrelated keys survive.
    dict_cache.update(cache=cache, hx=tokens, model_name=model_name)
|
|
346
|
+
|
|
347
|
+
def save_dict_cache(cache_path, metadata, prompt_cache):
    """Write ``prompt_cache`` to ``cache_path`` together with ``metadata``.

    The argument order (path and metadata first, cache last) lets callers
    pre-bind the first two and supply only the cache later.
    """
    writer = mlx_lm.models.cache.save_prompt_cache
    writer(cache_path, prompt_cache, metadata=metadata)
|
|
349
|
+
|
|
329
350
|
def main():
|
|
330
351
|
parser = argparse.ArgumentParser()
|
|
331
352
|
parser.add_argument("--model", default="mlx-community/Qwen3.5-4B-OptiQ-4bit")
|
|
@@ -334,7 +355,8 @@ def main():
|
|
|
334
355
|
parser.add_argument("--system", type=str, default='')
|
|
335
356
|
# parser.add_argument("--system", type=str, default='# Env\n{env}')
|
|
336
357
|
# parser.add_argument("--system", type=str, default=None)
|
|
337
|
-
parser.add_argument("--cache", type=str, default='cache/cache.safetensors')
|
|
358
|
+
# parser.add_argument("--cache", type=str, default='cache/cache.safetensors')
|
|
359
|
+
parser.add_argument("--cache", type=str, default='cache')
|
|
338
360
|
# parser.add_argument("--names", nargs="+", default=[])
|
|
339
361
|
parser.add_argument("--names", nargs="+", default=['Read','Edit','Write','Grep','Glob','Bash','Agent','Skill'])
|
|
340
362
|
# parser.add_argument("--names", nargs="+", default=None)
|
|
@@ -347,20 +369,9 @@ def main():
|
|
|
347
369
|
parser.add_argument("--home", default=tempfile.mkdtemp())
|
|
348
370
|
parser.add_argument("--work", default=os.getcwd())
|
|
349
371
|
args, claude_args = parser.parse_known_args()
|
|
350
|
-
|
|
351
|
-
|
|
352
|
-
|
|
353
|
-
mx.eval(cache)
|
|
354
|
-
model_name = metadata.pop("model_name", "")
|
|
355
|
-
tokens_str = metadata.pop("hx", "[]")
|
|
356
|
-
tokens = json.loads(tokens_str)
|
|
357
|
-
prompt_cache = dict(cache=cache, hx=tokens, model_name=model_name)
|
|
358
|
-
if prompt_cache.get('model_name') != args.model:
|
|
359
|
-
prompt_cache = dict(model_name=args.model)
|
|
360
|
-
else:
|
|
361
|
-
Path(args.cache).parent.mkdir(parents=True, exist_ok=True)
|
|
362
|
-
prompt_cache = dict(model_name=args.model)
|
|
363
|
-
prompt_cache['file_name']=args.cache
|
|
372
|
+
Path(args.cache).mkdir(parents=True, exist_ok=True)
|
|
373
|
+
global dict_cache
|
|
374
|
+
dict_cache = dict(model_name=args.model, cache_dir = args.cache)
|
|
364
375
|
model, tokenizer = mlx_lm.load(args.model)
|
|
365
376
|
server = HTTPServer((args.host, args.port), make_handler(model, tokenizer, args.system, args.names, args.skips))
|
|
366
377
|
threading.Thread(target=server.serve_forever, daemon=True).start()
|
|
@@ -380,5 +391,228 @@ def main():
|
|
|
380
391
|
mirror_workspace(args.work, workspace)
|
|
381
392
|
sys.exit(subprocess.run(["claude"] + claude_args, env=env, cwd=workspace).returncode)
|
|
382
393
|
|
|
394
|
+
def generate_step(
    prompt: mx.array,
    model: nn.Module,
    *,
    max_tokens: int = 256,
    sampler: Optional[Callable[[mx.array], mx.array]] = None,
    logits_processors: Optional[List[Callable[[mx.array, mx.array], mx.array]]] = None,
    max_kv_size: Optional[int] = None,
    prompt_cache: Optional[Any] = None,
    prefill_step_size: int = 2048,
    kv_bits: Optional[int] = None,
    kv_group_size: int = 64,
    quantized_kv_start: int = 0,
    prompt_progress_callback: Optional[Callable[[int, int], None]] = None,
    input_embeddings: Optional[mx.array] = None,
    save_at: int = -1,
    save_fn=None,
) -> Generator[Tuple[mx.array, mx.array], None, None]:
    """Prefill ``prompt`` into the cache, then yield sampled tokens one at a time.

    Args:
        prompt: Token ids to process.
        model: Model called as ``model(tokens, cache=prompt_cache)``.
        max_tokens: Maximum number of tokens to yield.
        sampler: Maps log-probabilities to a sampled token; defaults to argmax
            over the last axis.
        logits_processors: Callables applied in order as
            ``processor(tokens_so_far, logits)``.
        max_kv_size: Passed to ``make_prompt_cache`` when no cache is given.
        prompt_cache: Existing cache to extend; one is created when ``None``.
        prefill_step_size: Maximum number of prompt tokens per prefill chunk.
        kv_bits: Bit width for KV-cache quantization; ``None`` disables it.
        kv_group_size: Group size used when quantizing the KV cache.
        quantized_kv_start: Cache offset from which entries are quantized.
        prompt_progress_callback: Called as ``callback(processed, total)``
            during prefill.
        input_embeddings: Optional embeddings processed alongside ``prompt``.
        save_at: Number of tokens of *this* ``prompt`` after which ``save_fn``
            is invoked. Prefill chunks are shortened so a chunk ends exactly
            at this count. Values that prefill never reaches (non-positive, or
            ``>= len(prompt)`` since the final token is left for the first
            sampling step) never trigger ``save_fn``.
        save_fn: Callable invoked as ``save_fn(prompt_cache)`` when the
            ``save_at`` boundary is reached.

    Yields:
        ``(token, logprobs)`` where ``token`` is ``y.item()`` for the sampled
        token and ``logprobs`` is the log-probability array for that step.

    Raises:
        ValueError: If embeddings are given but unsupported by the model, if
            prompt and embeddings lengths differ, or if neither is provided.
    """
    if input_embeddings is not None:
        # NOTE(review): does_model_support_input_embeddings is not defined in
        # any part of the file visible here -- confirm it is in scope.
        if not does_model_support_input_embeddings(model):
            raise ValueError("Model does not support input embeddings.")
        elif len(prompt) > 0 and len(prompt) != len(input_embeddings):
            raise ValueError(f"{len(input_embeddings)=} {len(prompt)=}")
    elif len(prompt) == 0:
        raise ValueError("Either input_embeddings or prompt (or both) must be provided.")

    tokens = None

    if prompt_cache is None:
        # Use the fully qualified module path, as the rest of this file does,
        # rather than relying on a bare ``cache`` name being in scope.
        prompt_cache = mlx_lm.models.cache.make_prompt_cache(
            model,
            max_kv_size=max_kv_size,
        )

    prompt_progress_callback = prompt_progress_callback or (lambda *_: None)

    quantize_cache_fn = functools.partial(
        maybe_quantize_kv_cache,
        quantized_kv_start=quantized_kv_start,
        kv_group_size=kv_group_size,
        kv_bits=kv_bits,
    )

    sampler = sampler or (lambda x: mx.argmax(x, axis=-1))

    def _model_call(input_tokens: mx.array, input_embeddings: Optional[mx.array]):
        # Only pass ``input_embeddings`` to the model when there are any.
        if input_embeddings is not None:
            return model(
                input_tokens, cache=prompt_cache, input_embeddings=input_embeddings
            )
        else:
            return model(input_tokens, cache=prompt_cache)

    def _step(input_tokens: mx.array, input_embeddings: Optional[mx.array] = None):
        # Run one forward pass and sample from the logits of the last position.
        nonlocal tokens

        with mx.stream(generation_stream):
            logits = _model_call(
                input_tokens=input_tokens[None],
                input_embeddings=(
                    input_embeddings[None] if input_embeddings is not None else None
                ),
            )

            logits = logits[:, -1, :]

            if logits_processors and len(input_tokens) > 0:
                # Processors receive the running token history.
                tokens = (
                    mx.concat([tokens, input_tokens])
                    if tokens is not None
                    else input_tokens
                )
                for processor in logits_processors:
                    logits = processor(tokens, logits)

            quantize_cache_fn(prompt_cache)
            logprobs = logits - mx.logsumexp(logits, keepdims=True)
            sampled = sampler(logprobs)
            return sampled, logprobs.squeeze(0)

    with mx.stream(generation_stream):
        total_prompt_tokens = (
            len(input_embeddings) if input_embeddings is not None else len(prompt)
        )
        prompt_processed_tokens = 0
        prompt_progress_callback(prompt_processed_tokens, total_prompt_tokens)
        # Prefill everything except the final token, which goes through _step.
        while total_prompt_tokens - prompt_processed_tokens > 1:
            remaining = (total_prompt_tokens - prompt_processed_tokens) - 1
            n_to_process = min(prefill_step_size, remaining)
            if prompt_processed_tokens < save_at:
                # Shorten the chunk so it ends exactly on the save boundary.
                n_to_process = min(n_to_process, save_at - prompt_processed_tokens)

            _model_call(
                input_tokens=prompt[:n_to_process][None],
                input_embeddings=(
                    input_embeddings[:n_to_process][None]
                    if input_embeddings is not None
                    else None
                ),
            )
            quantize_cache_fn(prompt_cache)
            mx.eval([c.state for c in prompt_cache])
            prompt_processed_tokens += n_to_process
            prompt_progress_callback(prompt_processed_tokens, total_prompt_tokens)
            prompt = prompt[n_to_process:]
            input_embeddings = (
                input_embeddings[n_to_process:]
                if input_embeddings is not None
                else input_embeddings
            )
            mx.clear_cache()
            if save_fn is not None and prompt_processed_tokens == save_at:
                save_fn(prompt_cache)

        y, logprobs = _step(input_tokens=prompt, input_embeddings=input_embeddings)

    mx.async_eval(y, logprobs)
    n = 0
    while True:
        if n != max_tokens:
            # Schedule the next step before the current token is consumed.
            next_y, next_logprobs = _step(y)
            mx.async_eval(next_y, next_logprobs)
        if n == 0:
            mx.eval(y)
            prompt_progress_callback(total_prompt_tokens, total_prompt_tokens)
        if n == max_tokens:
            break
        yield y.item(), logprobs
        if n % 256 == 0:
            mx.clear_cache()
        y, logprobs = next_y, next_logprobs
        n += 1
|
|
526
|
+
|
|
527
|
+
def generate(model, tokenizer, prompt, pref, hook=None, max_tokens=256, helper_max_tokens=64, **kwargs):
|
|
528
|
+
global dict_cache
|
|
529
|
+
if prompt is None:
|
|
530
|
+
return '', 0, 0
|
|
531
|
+
if not isinstance(tokenizer, mlx_lm.tokenizer_utils.TokenizerWrapper):
|
|
532
|
+
tokenizer = mlx_lm.tokenizer_utils.TokenizerWrapper(tokenizer)
|
|
533
|
+
detokenizer = tokenizer.detokenizer
|
|
534
|
+
if isinstance(prompt, str):
|
|
535
|
+
add_special_tokens = tokenizer.bos_token is None or not prompt.startswith(tokenizer.bos_token)
|
|
536
|
+
prompt_s = prompt
|
|
537
|
+
prompt = tokenizer.encode(prompt, add_special_tokens=add_special_tokens)
|
|
538
|
+
_pref = tokenizer.encode(pref, add_special_tokens=add_special_tokens)
|
|
539
|
+
save_at = get_common_len(prompt, _pref)
|
|
540
|
+
else:
|
|
541
|
+
prompt_s = tokenizer.decode(prompt)
|
|
542
|
+
save_at = -1 # □ for now
|
|
543
|
+
stream_logger.debug(dmca(prompt_s))
|
|
544
|
+
text = ''
|
|
545
|
+
gens = []
|
|
546
|
+
common_len = 0
|
|
547
|
+
hx_len = None
|
|
548
|
+
trim_len = None
|
|
549
|
+
save_fn = None
|
|
550
|
+
if not dict_cache.get('cache'):
|
|
551
|
+
ckpt_path = Path(dict_cache['cache_dir'])/f'{"".join(c for c in dict_cache["model_name"] if c.isalnum())}_{save_at}_{hash_tokens(prompt[:save_at])}.safetensors'
|
|
552
|
+
trace_logger.debug(ckpt_path.resolve())
|
|
553
|
+
trace_logger.debug(ckpt_path.absolute())
|
|
554
|
+
if os.path.exists(ckpt_path):
|
|
555
|
+
load_dict_cache(ckpt_path)
|
|
556
|
+
else:
|
|
557
|
+
dict_cache |= dict(cache=mlx_lm.models.cache.make_prompt_cache(model), hx=[])
|
|
558
|
+
save_fn = functools.partial(save_dict_cache, ckpt_path, dict(model_name=dict_cache['model_name'], hx=json.dumps(prompt[:save_at+1])))
|
|
559
|
+
|
|
560
|
+
if (hx := dict_cache.get('hx')):
|
|
561
|
+
_hx = hx[:-1]
|
|
562
|
+
common_len = get_common_len(prompt, _hx)
|
|
563
|
+
hx_len = len(_hx)
|
|
564
|
+
trim_len = hx_len - common_len
|
|
565
|
+
if trim_len > 0:
|
|
566
|
+
if all(c.is_trimmable() for c in dict_cache['cache']):
|
|
567
|
+
mlx_lm.models.cache.trim_prompt_cache(dict_cache['cache'], trim_len)
|
|
568
|
+
else:
|
|
569
|
+
ckpt_path = Path(dict_cache['cache_dir'])/f'{"".join(c for c in dict_cache["model_name"] if c.isalnum())}_{save_at}_{hash_tokens(prompt[:save_at])}.safetensors'
|
|
570
|
+
if os.path.exists(ckpt_path):
|
|
571
|
+
load_dict_cache(ckpt_path)
|
|
572
|
+
common_len = save_at
|
|
573
|
+
if save_at > common_len and not all(c.is_trimmable() for c in dict_cache['cache']):
|
|
574
|
+
ckpt_path = Path(dict_cache['cache_dir'])/f'{"".join(c for c in dict_cache["model_name"] if c.isalnum())}_{save_at}_{hash_tokens(prompt[:save_at])}.safetensors'
|
|
575
|
+
save_fn = functools.partial(save_dict_cache, ckpt_path, dict(model_name=dict_cache['model_name'], hx=json.dumps(prompt[:save_at+1])))
|
|
576
|
+
else:
|
|
577
|
+
save_at = -1
|
|
578
|
+
|
|
579
|
+
if common_len==len(prompt):
|
|
580
|
+
_last_gen = dict_cache['hx'][common_len]
|
|
581
|
+
prompt_arr = mx.array([_last_gen])
|
|
582
|
+
gens.append(_last_gen)
|
|
583
|
+
detokenizer.add(_last_gen)
|
|
584
|
+
else:
|
|
585
|
+
prompt_arr = mx.array(prompt[common_len:])
|
|
586
|
+
|
|
587
|
+
trace_logger.debug(f'{save_at=} {common_len=}')
|
|
588
|
+
token_gen = generate_step(
|
|
589
|
+
prompt_arr,
|
|
590
|
+
model,
|
|
591
|
+
prompt_cache=dict_cache['cache'],
|
|
592
|
+
max_tokens=max_tokens,
|
|
593
|
+
save_at=save_at-common_len,
|
|
594
|
+
save_fn=save_fn,
|
|
595
|
+
**kwargs,
|
|
596
|
+
)
|
|
597
|
+
tic_non = time.perf_counter()
|
|
598
|
+
for token, _ in token_gen:
|
|
599
|
+
gens.append(token)
|
|
600
|
+
if token in tokenizer.eos_token_ids:
|
|
601
|
+
break
|
|
602
|
+
detokenizer.add_token(token)
|
|
603
|
+
seg = detokenizer.last_segment
|
|
604
|
+
stream_logger.debug(seg)
|
|
605
|
+
text += seg
|
|
606
|
+
if len(gens) == 1:
|
|
607
|
+
tic_inp = time.perf_counter()
|
|
608
|
+
if len(gens) >= max_tokens:
|
|
609
|
+
break
|
|
610
|
+
tic_out = time.perf_counter()
|
|
611
|
+
detokenizer.finalize()
|
|
612
|
+
text += detokenizer.last_segment
|
|
613
|
+
dict_cache['hx'] = prompt+gens
|
|
614
|
+
trace_logger.debug(f'G {hx_len} {len(prompt)} {common_len} {trim_len} {len(gens)}\n=== TPS ===\n- Processed {len(prompt)} input tokens in {tic_inp-tic_non:.0f} seconds ({len(prompt)/(tic_inp-tic_non):.0f} tokens per second)\n- Generated {len(gens)} new tokens in {tic_out-tic_inp:.0f} seconds ({len(gens)/(tic_out-tic_inp):.0f} tokens per second)\n\n=== INP ===\n{dmca(prompt_s)}\n=== OUT ===\n{text}')
|
|
615
|
+
return text, len(prompt), len(gens)
|
|
616
|
+
|
|
383
617
|
if __name__ == "__main__":
|
|
384
618
|
main()
|
|
@@ -1,6 +1,6 @@
|
|
|
1
1
|
Metadata-Version: 2.4
|
|
2
2
|
Name: mlx-code
|
|
3
|
-
Version: 0.0.2a1
|
|
3
|
+
Version: 0.0.2a2
|
|
4
4
|
Summary: Local Claude Code for Mac
|
|
5
5
|
Home-page: https://github.com/JosefAlbers/mlx-code
|
|
6
6
|
Author: J Joe
|
|
@@ -52,19 +52,7 @@ mlx-code [options] [-- claude options]
|
|
|
52
52
|
| `--work` | `$CWD` | Working directory mirrored into the Claude session |
|
|
53
53
|
| `--home` | temp dir | Home directory for the Claude process |
|
|
54
54
|
|
|
55
|
-
Any extra arguments after `--` are forwarded to the `claude` CLI
|
|
56
|
-
|
|
57
|
-
| Command | What it does | Example |
|
|
58
|
-
|--------|--------------|--------|
|
|
59
|
-
| `mlx-code` | Start interactive mode | `mlx-code` |
|
|
60
|
-
| `mlx-code "task"` | Run a one-time task | `mlx-code "fix the build error"` |
|
|
61
|
-
| `mlx-code -p "query"` | Run one-off query, then exit | `mlx-code -p "explain this function"` |
|
|
62
|
-
| `mlx-code -c` | Continue most recent conversation in current directory | `mlx-code -c` |
|
|
63
|
-
| `mlx-code -r` | Resume a previous conversation | `mlx-code -r` |
|
|
64
|
-
| `mlx-code commit` | Create a Git commit | `mlx-code commit` |
|
|
65
|
-
| `/clear` | Clear conversation history | `/clear` |
|
|
66
|
-
| `/help` | Show available commands | `/help` |
|
|
67
|
-
| `exit` or `Ctrl+C` | Exit Claude Code | `exit` |
|
|
55
|
+
Any extra arguments after `--` are forwarded to the `claude` CLI.
|
|
68
56
|
|
|
69
57
|
### Licence
|
|
70
58
|
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|