mlx-code 0.0.2a0__tar.gz → 0.0.2a2__tar.gz

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.4
2
2
  Name: mlx-code
3
- Version: 0.0.2a0
3
+ Version: 0.0.2a2
4
4
  Summary: Local Claude Code for Mac
5
5
  Home-page: https://github.com/JosefAlbers/mlx-code
6
6
  Author: J Joe
@@ -52,19 +52,7 @@ mlx-code [options] [-- claude options]
52
52
  | `--work` | `$CWD` | Working directory mirrored into the Claude session |
53
53
  | `--home` | temp dir | Home directory for the Claude process |
54
54
 
55
- Any extra arguments after `--` are forwarded to the `claude` CLI:
56
-
57
- | Command | What it does | Example |
58
- |--------|--------------|--------|
59
- | `mlx-code` | Start interactive mode | `mlx-code` |
60
- | `mlx-code "task"` | Run a one-time task | `mlx-code "fix the build error"` |
61
- | `mlx-code -p "query"` | Run one-off query, then exit | `mlx-code -p "explain this function"` |
62
- | `mlx-code -c` | Continue most recent conversation in current directory | `mlx-code -c` |
63
- | `mlx-code -r` | Resume a previous conversation | `mlx-code -r` |
64
- | `mlx-code commit` | Create a Git commit | `mlx-code commit` |
65
- | `/clear` | Clear conversation history | `/clear` |
66
- | `/help` | Show available commands | `/help` |
67
- | `exit` or `Ctrl+C` | Exit Claude Code | `exit` |
55
+ Any extra arguments after `--` are forwarded to the `claude` CLI.
68
56
 
69
57
  ### Licence
70
58
 
@@ -28,19 +28,7 @@ mlx-code [options] [-- claude options]
28
28
  | `--work` | `$CWD` | Working directory mirrored into the Claude session |
29
29
  | `--home` | temp dir | Home directory for the Claude process |
30
30
 
31
- Any extra arguments after `--` are forwarded to the `claude` CLI:
32
-
33
- | Command | What it does | Example |
34
- |--------|--------------|--------|
35
- | `mlx-code` | Start interactive mode | `mlx-code` |
36
- | `mlx-code "task"` | Run a one-time task | `mlx-code "fix the build error"` |
37
- | `mlx-code -p "query"` | Run one-off query, then exit | `mlx-code -p "explain this function"` |
38
- | `mlx-code -c` | Continue most recent conversation in current directory | `mlx-code -c` |
39
- | `mlx-code -r` | Resume a previous conversation | `mlx-code -r` |
40
- | `mlx-code commit` | Create a Git commit | `mlx-code commit` |
41
- | `/clear` | Clear conversation history | `/clear` |
42
- | `/help` | Show available commands | `/help` |
43
- | `exit` or `Ctrl+C` | Exit Claude Code | `exit` |
31
+ Any extra arguments after `--` are forwarded to the `claude` CLI.
44
32
 
45
33
  ### Licence
46
34
 
@@ -1,3 +1,4 @@
1
+ # {{{
1
2
  # Copyright 2026 J Joe
2
3
  # Licensed under the Apache License, Version 2.0 (the "License");
3
4
  # you may not use this file except in compliance with the License.
@@ -25,8 +26,22 @@ from http.server import BaseHTTPRequestHandler, HTTPServer
25
26
  from pathlib import Path
26
27
  import mlx.core as mx
27
28
  import mlx_lm
28
- from mlx_lm.generate import generate_step
29
+ import numpy as np
30
+ import hashlib
31
+ import contextlib
32
+ import functools
33
+ import mlx.nn as nn
34
+ from typing import (
35
+ Any,
36
+ Callable,
37
+ Generator,
38
+ List,
39
+ Optional,
40
+ Tuple,
41
+ Union,
42
+ )
29
43
 
44
+ generation_stream = mx.new_stream(mx.default_device())
30
45
  stream_logger = logging.getLogger("stream")
31
46
  stream_logger.setLevel(logging.DEBUG)
32
47
  s_handler = logging.FileHandler("mlx_stream.log", mode='w')
@@ -39,7 +54,54 @@ t_handler = logging.FileHandler("mlx_trace.log", mode='w')
39
54
  t_handler.setFormatter(logging.Formatter("【%(message)s\n】\n"))
40
55
  trace_logger.addHandler(t_handler)
41
56
  gen_lock = threading.Lock()
42
- prompt_cache = {}
57
+ dict_cache = {}
58
+
59
def hash_tokens(tokens):
    """Return a short hexadecimal fingerprint of a token-id sequence.

    The ids are packed as unsigned 32-bit integers and hashed with BLAKE2b
    using an 8-byte digest, so the result is a 16-character hex string.
    """
    packed = np.array(tokens, dtype=np.uint32).tobytes()
    digest = hashlib.blake2b(packed, digest_size=8)
    return digest.hexdigest()
62
+
63
def get_common_len(a, b):
    """Return the length of the longest common prefix of ``a`` and ``b``.

    Elements are compared pairwise with ``==``; counting stops at the first
    mismatch or when the shorter sequence is exhausted.
    """
    matched = 0
    for left, right in zip(a, b):
        if not (left == right):
            return matched
        matched += 1
    return matched
71
+
72
@contextlib.contextmanager
def wired_limit(model: nn.Module, streams: Optional[List[mx.Stream]] = None):
    """Context manager that raises MLX's wired-memory limit for the ``with`` body.

    When Metal is not available this is a no-op. Otherwise the wired limit is
    set to the device's ``max_recommended_working_set_size`` and the previous
    limit is restored on exit, after synchronizing.

    Args:
        model: Model whose ``mx.array`` leaves are summed (``nbytes``) to
            estimate its size.
        streams: Streams to synchronize before restoring the limit. When
            ``None``, ``mx.synchronize()`` is called with no argument.
    """
    if not mx.metal.is_available():
        yield
        return

    # ``tree_reduce`` lives in ``mlx.utils``; import it here so the name is
    # always resolvable when this branch runs.
    from mlx.utils import tree_reduce

    model_bytes = tree_reduce(
        lambda acc, x: acc + x.nbytes if isinstance(x, mx.array) else acc, model, 0
    )
    max_rec_size = mx.device_info()["max_recommended_working_set_size"]
    if model_bytes > 0.9 * max_rec_size:
        # Model is close to (or above) the recommended working set: report
        # both sizes in MiB.
        model_mb = model_bytes // 2**20
        max_rec_mb = max_rec_size // 2**20
        print(f"{model_mb=} {max_rec_mb=}")
    old_limit = mx.set_wired_limit(max_rec_size)
    try:
        yield
    finally:
        # Synchronize before restoring the previous limit.
        if streams is not None:
            for s in streams:
                mx.synchronize(s)
        else:
            mx.synchronize()
        mx.set_wired_limit(old_limit)
98
+
99
def maybe_quantize_kv_cache(prompt_cache, quantized_kv_start, kv_group_size, kv_bits):
    """Replace eligible cache entries with their quantized form, in place.

    Nothing happens when ``kv_bits`` is ``None``. Otherwise every entry that
    exposes ``to_quantized`` and whose ``offset`` is at least
    ``quantized_kv_start`` is swapped for
    ``entry.to_quantized(group_size=kv_group_size, bits=kv_bits)``.
    Always returns ``None``.
    """
    if kv_bits is None:
        return
    for slot, layer in enumerate(prompt_cache):
        if not hasattr(layer, "to_quantized"):
            continue
        if not layer.offset >= quantized_kv_start:
            continue
        prompt_cache[slot] = layer.to_quantized(group_size=kv_group_size, bits=kv_bits)
43
105
 
44
106
  def parse_tool(tools, names):
45
107
  qwen_tools = []
@@ -47,17 +109,18 @@ def parse_tool(tools, names):
47
109
  if names is not None and tool["name"] not in names:
48
110
  continue
49
111
  qwen_tool = {
50
- "type": "function",
51
- "function": {
112
+ # "type": "function",
113
+ # "function": {
52
114
  "name": tool["name"],
53
115
  "description": tool["description"],
54
116
  "parameters": tool.get("input_schema", {
55
117
  "type": "object",
56
118
  "properties": {}
57
119
  })
58
- }
120
+ # }
59
121
  }
60
- params = qwen_tool["function"]["parameters"]
122
+ # params = qwen_tool["function"]["parameters"]
123
+ params = qwen_tool["parameters"]
61
124
  params.pop("$schema", None)
62
125
  qwen_tools.append(qwen_tool)
63
126
  return qwen_tools
@@ -77,24 +140,21 @@ def encode(body, tokenizer, system, names, skips):
77
140
  if block.get("type") != "text":
78
141
  continue
79
142
  text = block.get("text", "").strip()
80
- if re.match(r'^\S+:\s', text) and '\n' not in text:
143
+ if re.match(r'^x-anthropic-billing-header:\s?.*;$', text) and '\n' not in text:
81
144
  continue
82
145
  if text:
83
146
  sys_parts.append(text)
84
147
  if sys_parts:
85
148
  msgs.append({"role": "system", "content": "\n\n".join(sys_parts)})
86
149
  calls = {}
87
- def skip(text, show_skipped=True):
150
+ def skip(text, show_skipped=False):
88
151
  if skips is None:
89
152
  return text
90
153
  lines = []
91
154
  for pattern in skips:
92
155
  found = re.findall(pattern, text)
93
156
  if found:
94
- lines.append(
95
- f"{pattern}\n" +
96
- "\n".join(found)
97
- )
157
+ lines.append(f"{pattern}\n" + "\n".join(found))
98
158
  if lines and show_skipped:
99
159
  trace_logger.debug("\n".join(["S"]+lines))
100
160
  for pattern in skips:
@@ -109,7 +169,7 @@ def encode(body, tokenizer, system, names, skips):
109
169
  for block in content:
110
170
  t = block.get("type")
111
171
  if t == "text":
112
- parts['content'] = parts.get('content', '').rstrip() + '\n' + skip(block['text']).rstrip()
172
+ parts['content'] = parts.get('content', '') + skip(block['text'])
113
173
  elif t == "thinking":
114
174
  parts['reasoning_content'] = block['thinking']
115
175
  elif t == "tool_use":
@@ -123,25 +183,39 @@ def encode(body, tokenizer, system, names, skips):
123
183
  if parts:
124
184
  msgs.append({"role": role}|parts)
125
185
  if not msgs[-1].get('content', '').strip():
126
- return None
127
- return tokenizer.apply_chat_template(msgs, tools = parse_tool(body.get("tools", []), names), tokenize=False, add_generation_prompt=True)
186
+ return None, -1
187
+ apply_chat_template = lambda x: tokenizer.apply_chat_template(x, tools = parse_tool(body.get("tools", []), names), tokenize=False, add_generation_prompt=True)
188
+ full = apply_chat_template(msgs)
189
+ last_user_idx = max((i for i, m in enumerate(msgs) if m.get("role") == "user"), default=None)
190
+ if last_user_idx is None:
191
+ return full, -1
192
+ p_msgs = msgs[:last_user_idx] + [dict(role='user', content='h' if msgs[last_user_idx]['content'][0] != 'h' else 'i')]
193
+ pref = apply_chat_template(p_msgs)
194
+ return full, pref
128
195
 
129
- def decode(raw_text, tokenizer, parse_think=True):
196
+ def decode(raw_text, tokenizer, parse_think, single_think=False):
197
+ def escape(text):
198
+ def repl(match):
199
+ inner = match.group(1)
200
+ inner = inner.replace('<', '‹').replace('>', '›')
201
+ return f'`{inner}`'
202
+ return re.sub(r'`([^\n`]*)`', repl, text)
203
+ raw_text = escape(raw_text)
130
204
  raw_text = '<think>' + raw_text if (c := raw_text.find('</think>')) != -1 and ((o := raw_text.find('<think>')) == -1 or c < o) else raw_text
131
205
  blocks = []
132
206
  if parse_think:
133
- parts = re.split(r'(<think>.*?</think>)', raw_text, flags=re.DOTALL)
207
+ parts = re.split(r'(<think>.*?</think>)', raw_text, flags=re.DOTALL, maxsplit=1 if single_think else 0)
134
208
  else:
135
209
  parts = [raw_text]
136
210
  for part in parts:
137
211
  if not part:
138
- continue
139
- if parse_think and part.startswith('<think>') and part.endswith('</think>'):
212
+ continue
213
+ if parse_think and not single_think and part.startswith('<think>') and part.endswith('</think>'):
140
214
  thinking_content = part[7:-8].strip()
141
215
  if thinking_content:
142
216
  blocks.append({"type": "thinking", "thinking": thinking_content})
143
217
  else:
144
- blocks.append({"type": "text", "text": part})
218
+ blocks.append({"type": "text", "text": re.sub(r'</?think>', '‹think›', part)}) #: show tool call
145
219
  tool_pattern = re.compile(r'<tool_call>(.*?)</tool_call>', re.DOTALL)
146
220
  for match in tool_pattern.finditer(part):
147
221
  content = match.group(1).strip()
@@ -185,7 +259,7 @@ def blocks_to_sse(blocks: list[dict], msg_id: str, in_tokens: int, out_tokens: i
185
259
  elif bt == "tool_use":
186
260
  out += event("content_block_start", {"type": "content_block_start", "index": i,
187
261
  "content_block": {"type": "tool_use", "id": block["id"],
188
- "name": block["name"], "input": {}}})
262
+ "name": block["name"], "input": {}} })
189
263
  out += event("content_block_delta", {"type": "content_block_delta", "index": i,
190
264
  "delta": {"type": "input_json_delta", "partial_json": json.dumps(block["input"])}})
191
265
  out += event("content_block_stop", {"type": "content_block_stop", "index": i})
@@ -196,6 +270,8 @@ def blocks_to_sse(blocks: list[dict], msg_id: str, in_tokens: int, out_tokens: i
196
270
  return bytes(out)
197
271
 
198
272
  def dmca(p_str):
273
+ if True: #: False for recording
274
+ return p_str
199
275
  symbols = ["▲", "△", "▶", "▷", "▼", "▽", "◀", "◁", "◆", "◇"]
200
276
  def mask_text(text):
201
277
  return re.sub(r"\S", lambda _: random.choice(symbols), text)
@@ -211,64 +287,6 @@ def dmca(p_str):
211
287
  p_str = re.sub(pattern, lambda m: mask_text(m.group(0)), p_str)
212
288
  return p_str
213
289
 
214
- def generate(model, tokenizer, prompt, hook=None, max_tokens=256, helper_max_tokens=64, **kwargs):
215
- global prompt_cache
216
- if prompt is None:
217
- return '', 0, 0
218
- if not isinstance(tokenizer, mlx_lm.tokenizer_utils.TokenizerWrapper):
219
- tokenizer = mlx_lm.tokenizer_utils.TokenizerWrapper(tokenizer)
220
- detokenizer = tokenizer.detokenizer
221
- if isinstance(prompt, str):
222
- add_special_tokens = tokenizer.bos_token is None or not prompt.startswith(tokenizer.bos_token)
223
- prompt_s = prompt
224
- prompt = tokenizer.encode(prompt, add_special_tokens=add_special_tokens)
225
- else:
226
- prompt_s = tokenizer.decode(prompt)
227
- stream_logger.debug(dmca(prompt_s))
228
- common_len = 0
229
- if prompt_cache.get('cache', None):
230
- for p, h in zip(prompt, prompt_cache['hx']):
231
- if p == h:
232
- common_len += 1
233
- else:
234
- break
235
- else:
236
- prompt_cache['hx'] = []
237
- prompt_cache['cache'] = mlx_lm.models.cache.make_prompt_cache(model)
238
- trim_len = len(prompt_cache['hx']) - common_len
239
- mlx_lm.models.cache.trim_prompt_cache(prompt_cache['cache'], trim_len)
240
- token_gen = generate_step(
241
- mx.array(prompt[common_len:]),
242
- model,
243
- prompt_cache=prompt_cache['cache'],
244
- max_tokens=max_tokens,
245
- **kwargs,
246
- )
247
- text = ""
248
- tic_non = time.perf_counter()
249
- gens = []
250
- for token, _ in token_gen:
251
- gens.append(token)
252
- if token in tokenizer.eos_token_ids:
253
- break
254
- detokenizer.add_token(token)
255
- seg = detokenizer.last_segment
256
- stream_logger.debug(seg)
257
- text += seg
258
- if len(gens) == 1:
259
- tic_inp = time.perf_counter()
260
- if prompt_cache.get('file_name'):
261
- _fn = prompt_cache.pop('file_name')
262
- mlx_lm.models.cache.save_prompt_cache(_fn, prompt_cache['cache'], metadata=dict(model_name=prompt_cache['model_name'], hx=json.dumps(prompt)))
263
- if len(gens) >= max_tokens:
264
- break
265
- tic_out = time.perf_counter()
266
- detokenizer.finalize()
267
- text += detokenizer.last_segment
268
- prompt_cache['hx'] = prompt+gens
269
- trace_logger.debug(f'G {common_len} {trim_len}\n=== TPS ===\n- Processed {len(prompt)} input tokens in {tic_inp-tic_non:.0f} seconds ({len(prompt)/(tic_inp-tic_non):.0f} tokens per second)\n- Generated {len(gens)} new tokens in {tic_out-tic_inp:.0f} seconds ({len(gens)/(tic_out-tic_inp):.0f} tokens per second)\n\n=== INP ===\n{prompt_s}\n=== OUT ===\n{text}')
270
- return text, len(prompt), len(gens)
271
-
272
290
  def make_handler(model, tokenizer, system, names, skips, parse_think=True):
273
291
  class Handler(BaseHTTPRequestHandler):
274
292
  def log_message(self, fmt, *args):
@@ -299,9 +317,9 @@ def make_handler(model, tokenizer, system, names, skips, parse_think=True):
299
317
  return
300
318
  n = int(self.headers.get("Content-Length", 0))
301
319
  body = json.loads(self.rfile.read(n))
302
- prompt = encode(body, tokenizer, system, names, skips)
320
+ prompt, pref = encode(body, tokenizer, system, names, skips)
303
321
  with gen_lock:
304
- raw, in_tokens, out_tokens = generate(model, tokenizer, prompt=prompt, max_tokens=body.get("max_tokens", 8192))
322
+ raw, in_tokens, out_tokens = generate(model, tokenizer, pref=pref, prompt=prompt, max_tokens=body.get("max_tokens", 8192))
305
323
  blocks, stop_reason = decode(raw, tokenizer, parse_think=parse_think)
306
324
  msg_id = f"msg_{uuid.uuid4().hex}"
307
325
  sse = blocks_to_sse(blocks, msg_id, in_tokens, out_tokens, stop_reason)
@@ -317,18 +335,33 @@ def make_handler(model, tokenizer, system, names, skips, parse_think=True):
317
335
  pass
318
336
  return Handler
319
337
 
338
def load_dict_cache(cache_path):
    """Load a saved prompt cache from ``cache_path`` into the global ``dict_cache``.

    The file's metadata supplies ``model_name`` (default ``""``) and ``hx``,
    a JSON-encoded token list (default ``[]``). The ``cache``, ``hx`` and
    ``model_name`` entries of ``dict_cache`` are overwritten; every other
    entry is kept.
    """
    global dict_cache
    cache, metadata = mlx_lm.models.cache.load_prompt_cache(cache_path, return_metadata=True)
    mx.eval(cache)
    model_name = metadata.pop("model_name", "")
    tokens = json.loads(metadata.pop("hx", "[]"))
    # Merge instead of rebinding to a fresh dict: ``main()`` stores
    # ``cache_dir`` in ``dict_cache`` and ``generate()`` reads it on later
    # requests, so replacing the dict would make those lookups fail.
    dict_cache |= dict(cache=cache, hx=tokens, model_name=model_name)
346
+
347
def save_dict_cache(cache_path, metadata, prompt_cache):
    """Write ``prompt_cache`` to ``cache_path`` together with ``metadata``.

    Thin wrapper over ``mlx_lm.models.cache.save_prompt_cache``. The argument
    order (path, metadata, cache) lets callers pre-bind the first two with
    ``functools.partial`` and supply the cache later.
    """
    write_cache = mlx_lm.models.cache.save_prompt_cache
    write_cache(cache_path, prompt_cache, metadata=metadata)
349
+
320
350
  def main():
321
351
  parser = argparse.ArgumentParser()
322
352
  parser.add_argument("--model", default="mlx-community/Qwen3.5-4B-OptiQ-4bit")
323
353
  # parser.add_argument("--model", default="mlx-community/Qwen3.5-2B-OptiQ-4bit")
324
354
  # parser.add_argument("--model", default="mlx-community/Qwen3.5-0.8B-MLX-bf16")
325
- parser.add_argument("--system", type=str, default='# Env\n{env}')
355
+ parser.add_argument("--system", type=str, default='')
356
+ # parser.add_argument("--system", type=str, default='# Env\n{env}')
326
357
  # parser.add_argument("--system", type=str, default=None)
327
- parser.add_argument("--cache", type=str, default='cache/cache.safetensors')
358
+ # parser.add_argument("--cache", type=str, default='cache/cache.safetensors')
359
+ parser.add_argument("--cache", type=str, default='cache')
360
+ # parser.add_argument("--names", nargs="+", default=[])
328
361
  parser.add_argument("--names", nargs="+", default=['Read','Edit','Write','Grep','Glob','Bash','Agent','Skill'])
329
362
  # parser.add_argument("--names", nargs="+", default=None)
330
363
  parser.add_argument("--skips", nargs="+", default=[
331
- r'(?m)^\[SUGGESTION MODE[\s\S]*'
364
+ r'(?m)^\[SUGGESTION MODE[\s\S]*',
332
365
  r'(?m)^<system-reminder>[\s\S]*?^</system-reminder>\s*',
333
366
  ])
334
367
  parser.add_argument("--port", type=int, default=8000)
@@ -336,20 +369,9 @@ def main():
336
369
  parser.add_argument("--home", default=tempfile.mkdtemp())
337
370
  parser.add_argument("--work", default=os.getcwd())
338
371
  args, claude_args = parser.parse_known_args()
339
- global prompt_cache
340
- if os.path.exists(args.cache):
341
- cache, metadata = mlx_lm.models.cache.load_prompt_cache(args.cache, return_metadata=True)
342
- mx.eval(cache)
343
- model_name = metadata.pop("model_name", "")
344
- tokens_str = metadata.pop("hx", "[]")
345
- tokens = json.loads(tokens_str)
346
- prompt_cache = dict(cache=cache, hx=tokens, model_name=model_name)
347
- if prompt_cache.get('model_name') != args.model:
348
- prompt_cache = dict(model_name=args.model)
349
- else:
350
- Path(args.cache).parent.mkdir(parents=True, exist_ok=True)
351
- prompt_cache = dict(model_name=args.model)
352
- prompt_cache['file_name']=args.cache
372
+ Path(args.cache).mkdir(parents=True, exist_ok=True)
373
+ global dict_cache
374
+ dict_cache = dict(model_name=args.model, cache_dir = args.cache)
353
375
  model, tokenizer = mlx_lm.load(args.model)
354
376
  server = HTTPServer((args.host, args.port), make_handler(model, tokenizer, args.system, args.names, args.skips))
355
377
  threading.Thread(target=server.serve_forever, daemon=True).start()
@@ -369,5 +391,228 @@ def main():
369
391
  mirror_workspace(args.work, workspace)
370
392
  sys.exit(subprocess.run(["claude"] + claude_args, env=env, cwd=workspace).returncode)
371
393
 
394
def _supports_input_embeddings(model) -> bool:
    """Report whether ``model.__call__`` accepts an ``input_embeddings`` argument."""
    import inspect  # local import keeps this block self-contained

    try:
        return "input_embeddings" in inspect.signature(model.__call__).parameters
    except (TypeError, ValueError):
        # Signature cannot be introspected; treat as unsupported.
        return False


def generate_step(
    prompt: mx.array,
    model: nn.Module,
    *,
    max_tokens: int = 256,
    sampler: Optional[Callable[[mx.array], mx.array]] = None,
    logits_processors: Optional[List[Callable[[mx.array, mx.array], mx.array]]] = None,
    max_kv_size: Optional[int] = None,
    prompt_cache: Optional[Any] = None,
    prefill_step_size: int = 2048,
    kv_bits: Optional[int] = None,
    kv_group_size: int = 64,
    quantized_kv_start: int = 0,
    prompt_progress_callback: Optional[Callable[[int, int], None]] = None,
    input_embeddings: Optional[mx.array] = None,
    save_at: int = -1,
    save_fn=None,
) -> Generator[Tuple[mx.array, mx.array], None, None]:
    """Prefill ``prompt`` into the cache, then yield sampled tokens one at a time.

    The prompt is fed to ``model`` in chunks of at most ``prefill_step_size``
    tokens, leaving the final token for the first sampling step. If
    ``save_at`` is positive, a chunk is cut short so that prefill stops
    exactly at ``save_at`` processed tokens, and ``save_fn(prompt_cache)`` is
    called at that point.

    Args:
        prompt: Token ids to process.
        model: Callable model, invoked as ``model(tokens, cache=...)``.
        max_tokens: Maximum number of tokens to yield.
        sampler: Maps log-probabilities to a token; defaults to argmax.
        logits_processors: Callables ``(tokens_so_far, logits) -> logits``.
        max_kv_size: Forwarded to ``make_prompt_cache`` when no cache is given.
        prompt_cache: Existing cache to update in place; created if ``None``.
        prefill_step_size: Upper bound on tokens per prefill chunk.
        kv_bits, kv_group_size, quantized_kv_start: Forwarded to
            ``maybe_quantize_kv_cache`` after every model call.
        prompt_progress_callback: Called with ``(processed, total)``.
        input_embeddings: Optional embeddings passed alongside the tokens.
        save_at: Number of processed prompt tokens at which to call ``save_fn``.
        save_fn: Callable receiving the prompt cache; skipped when ``None``.

    Yields:
        ``(token, logprobs)`` where ``token`` is ``y.item()`` (a Python
        scalar, despite the ``mx.array`` in the annotation).

    Raises:
        ValueError: If embeddings are given but unsupported by the model or
            mismatched in length, or if neither prompt nor embeddings is given.
    """
    if input_embeddings is not None:
        if not _supports_input_embeddings(model):
            raise ValueError("Model does not support input embeddings.")
        elif len(prompt) > 0 and len(prompt) != len(input_embeddings):
            raise ValueError(f"{len(input_embeddings)=} {len(prompt)=}")
    elif len(prompt) == 0:
        raise ValueError("Either input_embeddings or prompt (or both) must be provided.")

    tokens = None

    if prompt_cache is None:
        # Fully qualified, matching how the rest of the file reaches the
        # cache helpers; a bare ``cache`` name is not defined in this module.
        prompt_cache = mlx_lm.models.cache.make_prompt_cache(
            model,
            max_kv_size=max_kv_size,
        )

    prompt_progress_callback = prompt_progress_callback or (lambda *_: None)

    quantize_cache_fn = functools.partial(
        maybe_quantize_kv_cache,
        quantized_kv_start=quantized_kv_start,
        kv_group_size=kv_group_size,
        kv_bits=kv_bits,
    )

    sampler = sampler or (lambda x: mx.argmax(x, axis=-1))

    def _model_call(input_tokens: mx.array, input_embeddings: Optional[mx.array]):
        # Only pass ``input_embeddings`` when present so models without that
        # parameter keep working.
        if input_embeddings is not None:
            return model(
                input_tokens, cache=prompt_cache, input_embeddings=input_embeddings
            )
        else:
            return model(input_tokens, cache=prompt_cache)

    def _step(input_tokens: mx.array, input_embeddings: Optional[mx.array] = None):
        """Run one forward pass and sample from the last position."""
        nonlocal tokens

        with mx.stream(generation_stream):
            logits = _model_call(
                input_tokens=input_tokens[None],
                input_embeddings=(
                    input_embeddings[None] if input_embeddings is not None else None
                ),
            )

            logits = logits[:, -1, :]

            if logits_processors and len(input_tokens) > 0:
                # Processors receive the running token history.
                tokens = (
                    mx.concat([tokens, input_tokens])
                    if tokens is not None
                    else input_tokens
                )
                for processor in logits_processors:
                    logits = processor(tokens, logits)

            quantize_cache_fn(prompt_cache)
            logprobs = logits - mx.logsumexp(logits, keepdims=True)
            sampled = sampler(logprobs)
            return sampled, logprobs.squeeze(0)

    with mx.stream(generation_stream):
        total_prompt_tokens = (
            len(input_embeddings) if input_embeddings is not None else len(prompt)
        )
        prompt_processed_tokens = 0
        prompt_progress_callback(prompt_processed_tokens, total_prompt_tokens)
        # Prefill everything except the last token, which ``_step`` consumes.
        while total_prompt_tokens - prompt_processed_tokens > 1:
            remaining = (total_prompt_tokens - prompt_processed_tokens) - 1
            n_to_process = min(prefill_step_size, remaining)
            if prompt_processed_tokens < save_at:
                # Do not run past the checkpoint position in a single chunk.
                n_to_process = min(n_to_process, save_at - prompt_processed_tokens)

            _model_call(
                input_tokens=prompt[:n_to_process][None],
                input_embeddings=(
                    input_embeddings[:n_to_process][None]
                    if input_embeddings is not None
                    else None
                ),
            )
            quantize_cache_fn(prompt_cache)
            mx.eval([c.state for c in prompt_cache])
            prompt_processed_tokens += n_to_process
            prompt_progress_callback(prompt_processed_tokens, total_prompt_tokens)
            prompt = prompt[n_to_process:]
            input_embeddings = (
                input_embeddings[n_to_process:]
                if input_embeddings is not None
                else input_embeddings
            )
            mx.clear_cache()
            if save_fn is not None and prompt_processed_tokens == save_at:
                save_fn(prompt_cache)

        y, logprobs = _step(input_tokens=prompt, input_embeddings=input_embeddings)

    mx.async_eval(y, logprobs)
    n = 0
    while True:
        if n != max_tokens:
            # Schedule the next step before yielding the current token.
            next_y, next_logprobs = _step(y)
            mx.async_eval(next_y, next_logprobs)
        if n == 0:
            mx.eval(y)
            prompt_progress_callback(total_prompt_tokens, total_prompt_tokens)
        if n == max_tokens:
            break
        yield y.item(), logprobs
        if n % 256 == 0:
            mx.clear_cache()
        y, logprobs = next_y, next_logprobs
        n += 1
526
+
527
+ def generate(model, tokenizer, prompt, pref, hook=None, max_tokens=256, helper_max_tokens=64, **kwargs):
# Generate a completion for `prompt`, reusing the module-level `dict_cache`
# (prompt cache plus the token history `hx`) between calls.
# Returns (generated text, number of prompt tokens, number of generated tokens),
# or ('', 0, 0) when `prompt` is None.
# `pref` is a second prompt string; the shared token prefix of `prompt` and
# `pref` decides where a cache checkpoint may be saved (`save_at`).
# `hook` and `helper_max_tokens` are accepted but not used in this body.
# Extra `kwargs` are forwarded to `generate_step`.
# NOTE(review): indentation is not recoverable from this listing; verify the
# nesting of the cache-reuse branches against the released file before
# changing them.
528
+ global dict_cache
529
+ if prompt is None:
530
+ return '', 0, 0
531
+ if not isinstance(tokenizer, mlx_lm.tokenizer_utils.TokenizerWrapper):
532
+ tokenizer = mlx_lm.tokenizer_utils.TokenizerWrapper(tokenizer)
533
+ detokenizer = tokenizer.detokenizer
534
+ if isinstance(prompt, str):
# Only add special tokens when the text does not already start with BOS.
535
+ add_special_tokens = tokenizer.bos_token is None or not prompt.startswith(tokenizer.bos_token)
536
+ prompt_s = prompt
537
+ prompt = tokenizer.encode(prompt, add_special_tokens=add_special_tokens)
# NOTE(review): `encode()` in this file can return -1 instead of a string for
# `pref`; that value would reach `tokenizer.encode` here — confirm it is handled.
538
+ _pref = tokenizer.encode(pref, add_special_tokens=add_special_tokens)
# Checkpoint position = number of leading tokens shared by prompt and pref.
539
+ save_at = get_common_len(prompt, _pref)
540
+ else:
541
+ prompt_s = tokenizer.decode(prompt)
542
+ save_at = -1 # □ for now
543
+ stream_logger.debug(dmca(prompt_s))
544
+ text = ''
545
+ gens = []
546
+ common_len = 0
547
+ hx_len = None
548
+ trim_len = None
549
+ save_fn = None
# No in-memory cache yet: load a checkpoint from disk if one exists for this
# model name / save_at / prefix hash, otherwise start an empty cache.
550
+ if not dict_cache.get('cache'):
551
+ ckpt_path = Path(dict_cache['cache_dir'])/f'{"".join(c for c in dict_cache["model_name"] if c.isalnum())}_{save_at}_{hash_tokens(prompt[:save_at])}.safetensors'
552
+ trace_logger.debug(ckpt_path.resolve())
553
+ trace_logger.debug(ckpt_path.absolute())
554
+ if os.path.exists(ckpt_path):
# NOTE(review): `load_dict_cache` rebinds `dict_cache` to a dict without a
# 'cache_dir' key; later `dict_cache['cache_dir']` lookups may raise KeyError.
555
+ load_dict_cache(ckpt_path)
556
+ else:
557
+ dict_cache |= dict(cache=mlx_lm.models.cache.make_prompt_cache(model), hx=[])
# save_fn later receives the prompt cache and writes it to ckpt_path.
558
+ save_fn = functools.partial(save_dict_cache, ckpt_path, dict(model_name=dict_cache['model_name'], hx=json.dumps(prompt[:save_at+1])))
559
+
560
+ if (hx := dict_cache.get('hx')):
# Compare the new prompt against the recorded history minus its last token.
561
+ _hx = hx[:-1]
562
+ common_len = get_common_len(prompt, _hx)
563
+ hx_len = len(_hx)
564
+ trim_len = hx_len - common_len
565
+ if trim_len > 0:
# Trim the cache when every entry reports it is trimmable; otherwise look
# for a checkpoint file to load instead.
566
+ if all(c.is_trimmable() for c in dict_cache['cache']):
567
+ mlx_lm.models.cache.trim_prompt_cache(dict_cache['cache'], trim_len)
568
+ else:
569
+ ckpt_path = Path(dict_cache['cache_dir'])/f'{"".join(c for c in dict_cache["model_name"] if c.isalnum())}_{save_at}_{hash_tokens(prompt[:save_at])}.safetensors'
570
+ if os.path.exists(ckpt_path):
571
+ load_dict_cache(ckpt_path)
572
+ common_len = save_at
# A checkpoint is only scheduled when the save position lies beyond the
# reused prefix and the cache is not fully trimmable; otherwise saving is
# disabled by resetting save_at to -1.
573
+ if save_at > common_len and not all(c.is_trimmable() for c in dict_cache['cache']):
574
+ ckpt_path = Path(dict_cache['cache_dir'])/f'{"".join(c for c in dict_cache["model_name"] if c.isalnum())}_{save_at}_{hash_tokens(prompt[:save_at])}.safetensors'
575
+ save_fn = functools.partial(save_dict_cache, ckpt_path, dict(model_name=dict_cache['model_name'], hx=json.dumps(prompt[:save_at+1])))
576
+ else:
577
+ save_at = -1
578
+
# The whole prompt matched the history: feed the token recorded right after
# it and count that token as already generated.
579
+ if common_len==len(prompt):
580
+ _last_gen = dict_cache['hx'][common_len]
581
+ prompt_arr = mx.array([_last_gen])
582
+ gens.append(_last_gen)
# NOTE(review): the loop below calls `detokenizer.add_token`; confirm that
# `add` exists on the detokenizer.
583
+ detokenizer.add(_last_gen)
584
+ else:
585
+ prompt_arr = mx.array(prompt[common_len:])
586
+
587
+ trace_logger.debug(f'{save_at=} {common_len=}')
# save_at is passed relative to the tokens actually fed to generate_step.
588
+ token_gen = generate_step(
589
+ prompt_arr,
590
+ model,
591
+ prompt_cache=dict_cache['cache'],
592
+ max_tokens=max_tokens,
593
+ save_at=save_at-common_len,
594
+ save_fn=save_fn,
595
+ **kwargs,
596
+ )
597
+ tic_non = time.perf_counter()
598
+ for token, _ in token_gen:
599
+ gens.append(token)
# Stop on EOS; the EOS token is kept in gens but not detokenized.
600
+ if token in tokenizer.eos_token_ids:
601
+ break
602
+ detokenizer.add_token(token)
603
+ seg = detokenizer.last_segment
604
+ stream_logger.debug(seg)
605
+ text += seg
# NOTE(review): tic_inp is assigned only here. If the first token is EOS, or
# gens was pre-seeded above, the trace line below reads an unassigned name.
606
+ if len(gens) == 1:
607
+ tic_inp = time.perf_counter()
608
+ if len(gens) >= max_tokens:
609
+ break
610
+ tic_out = time.perf_counter()
611
+ detokenizer.finalize()
612
+ text += detokenizer.last_segment
# Record prompt + generated tokens as the history for the next call.
613
+ dict_cache['hx'] = prompt+gens
614
+ trace_logger.debug(f'G {hx_len} {len(prompt)} {common_len} {trim_len} {len(gens)}\n=== TPS ===\n- Processed {len(prompt)} input tokens in {tic_inp-tic_non:.0f} seconds ({len(prompt)/(tic_inp-tic_non):.0f} tokens per second)\n- Generated {len(gens)} new tokens in {tic_out-tic_inp:.0f} seconds ({len(gens)/(tic_out-tic_inp):.0f} tokens per second)\n\n=== INP ===\n{dmca(prompt_s)}\n=== OUT ===\n{text}')
615
+ return text, len(prompt), len(gens)
616
+
372
617
  if __name__ == "__main__":
373
618
  main()
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.4
2
2
  Name: mlx-code
3
- Version: 0.0.2a0
3
+ Version: 0.0.2a2
4
4
  Summary: Local Claude Code for Mac
5
5
  Home-page: https://github.com/JosefAlbers/mlx-code
6
6
  Author: J Joe
@@ -52,19 +52,7 @@ mlx-code [options] [-- claude options]
52
52
  | `--work` | `$CWD` | Working directory mirrored into the Claude session |
53
53
  | `--home` | temp dir | Home directory for the Claude process |
54
54
 
55
- Any extra arguments after `--` are forwarded to the `claude` CLI:
56
-
57
- | Command | What it does | Example |
58
- |--------|--------------|--------|
59
- | `mlx-code` | Start interactive mode | `mlx-code` |
60
- | `mlx-code "task"` | Run a one-time task | `mlx-code "fix the build error"` |
61
- | `mlx-code -p "query"` | Run one-off query, then exit | `mlx-code -p "explain this function"` |
62
- | `mlx-code -c` | Continue most recent conversation in current directory | `mlx-code -c` |
63
- | `mlx-code -r` | Resume a previous conversation | `mlx-code -r` |
64
- | `mlx-code commit` | Create a Git commit | `mlx-code commit` |
65
- | `/clear` | Clear conversation history | `/clear` |
66
- | `/help` | Show available commands | `/help` |
67
- | `exit` or `Ctrl+C` | Exit Claude Code | `exit` |
55
+ Any extra arguments after `--` are forwarded to the `claude` CLI.
68
56
 
69
57
  ### Licence
70
58
 
@@ -6,7 +6,7 @@ setup(
6
6
  author_email="albersj66@gmail.com",
7
7
  author="J Joe",
8
8
  license="Apache-2.0",
9
- version="0.0.2a0",
9
+ version="0.0.2a2",
10
10
  readme="README.md",
11
11
  description="Local Claude Code for Mac",
12
12
  long_description=open("README.md").read(),
File without changes
File without changes