mlx-code 0.0.2a1__tar.gz → 0.0.2a3__tar.gz

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -1,6 +1,6 @@
  Metadata-Version: 2.4
  Name: mlx-code
- Version: 0.0.2a1
+ Version: 0.0.2a3
  Summary: Local Claude Code for Mac
  Home-page: https://github.com/JosefAlbers/mlx-code
  Author: J Joe
@@ -52,22 +52,12 @@ mlx-code [options] [-- claude options]
  | `--work` | `$CWD` | Working directory mirrored into the Claude session |
  | `--home` | temp dir | Home directory for the Claude process |

- Any extra arguments after `--` are forwarded to the `claude` CLI:
+ Any extra arguments after `--` are forwarded to the `claude` CLI.

- | Command | What it does | Example |
- |--------|--------------|--------|
- | `mlx-code` | Start interactive mode | `mlx-code` |
- | `mlx-code "task"` | Run a one-time task | `mlx-code "fix the build error"` |
- | `mlx-code -p "query"` | Run one-off query, then exit | `mlx-code -p "explain this function"` |
- | `mlx-code -c` | Continue most recent conversation in current directory | `mlx-code -c` |
- | `mlx-code -r` | Resume a previous conversation | `mlx-code -r` |
- | `mlx-code commit` | Create a Git commit | `mlx-code commit` |
- | `/clear` | Clear conversation history | `/clear` |
- | `/help` | Show available commands | `/help` |
- | `exit` or `Ctrl+C` | Exit Claude Code | `exit` |
+ ### Credits
+
+ `pie.py` is based on [pi](https://github.com/badlogic/pi-mono) by Mario Zechner, used under the MIT License.

  ### Licence

  Apache License 2.0 — see LICENSE for details.
-
-
@@ -28,22 +28,12 @@ mlx-code [options] [-- claude options]
  | `--work` | `$CWD` | Working directory mirrored into the Claude session |
  | `--home` | temp dir | Home directory for the Claude process |

- Any extra arguments after `--` are forwarded to the `claude` CLI:
-
- | Command | What it does | Example |
- |--------|--------------|--------|
- | `mlx-code` | Start interactive mode | `mlx-code` |
- | `mlx-code "task"` | Run a one-time task | `mlx-code "fix the build error"` |
- | `mlx-code -p "query"` | Run one-off query, then exit | `mlx-code -p "explain this function"` |
- | `mlx-code -c` | Continue most recent conversation in current directory | `mlx-code -c` |
- | `mlx-code -r` | Resume a previous conversation | `mlx-code -r` |
- | `mlx-code commit` | Create a Git commit | `mlx-code commit` |
- | `/clear` | Clear conversation history | `/clear` |
- | `/help` | Show available commands | `/help` |
- | `exit` or `Ctrl+C` | Exit Claude Code | `exit` |
+ Any extra arguments after `--` are forwarded to the `claude` CLI.

- ### Licence
+ ### Credits

- Apache License 2.0 see LICENSE for details.
+ `pie.py` is based on [pi](https://github.com/badlogic/pi-mono) by Mario Zechner, used under the MIT License.

+ ### Licence

+ Apache License 2.0 — see LICENSE for details.
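
The forwarding described in the README above comes from `argparse.parse_known_args()` in `main.py`: options that mlx-code recognises are consumed, and everything it does not recognise is collected and later handed to the harness subprocess. A minimal standalone sketch of that split (illustrative only; the option names are the ones visible in this diff):

    import argparse

    parser = argparse.ArgumentParser(prog="mlx-code")
    parser.add_argument("--model", default="mlx-community/Qwen3.5-4B-OptiQ-4bit")
    parser.add_argument("--harness", default=None)

    # parse_known_args() returns (known, leftovers) instead of erroring on unknown flags
    args, harness_args = parser.parse_known_args(
        ["--model", "my-model", "--harness", "claude", "-p", "explain this function"]
    )

    print(args.model)     # my-model
    print(harness_args)   # ['-p', 'explain this function'], forwarded to the harness CLI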
@@ -25,8 +25,22 @@ from http.server import BaseHTTPRequestHandler, HTTPServer
  from pathlib import Path
  import mlx.core as mx
  import mlx_lm
- from mlx_lm.generate import generate_step
+ import numpy as np
+ import hashlib
+ import contextlib
+ import functools
+ import mlx.nn as nn
+ from typing import (
+     Any,
+     Callable,
+     Generator,
+     List,
+     Optional,
+     Tuple,
+     Union,
+ )

+ generation_stream = mx.new_stream(mx.default_device())
  stream_logger = logging.getLogger("stream")
  stream_logger.setLevel(logging.DEBUG)
  s_handler = logging.FileHandler("mlx_stream.log", mode='w')
@@ -39,7 +53,54 @@ t_handler = logging.FileHandler("mlx_trace.log", mode='w')
  t_handler.setFormatter(logging.Formatter("【%(message)s\n】\n"))
  trace_logger.addHandler(t_handler)
  gen_lock = threading.Lock()
- prompt_cache = {}
+ dict_cache = {}
+
+ def hash_tokens(tokens):
+     arr = np.array(tokens, dtype=np.uint32)
+     return hashlib.blake2b(arr.tobytes(), digest_size=8).hexdigest()
+
+ def get_common_len(a, b):
+     common_len = 0
+     for p, h in zip(a, b):
+         if p == h:
+             common_len += 1
+         else:
+             break
+     return common_len
+
+ @contextlib.contextmanager
+ def wired_limit(model: nn.Module, streams: Optional[List[mx.Stream]] = None):
+     if not mx.metal.is_available():
+         try:
+             yield
+         finally:
+             pass
+     else:
+         model_bytes = tree_reduce(
+             lambda acc, x: acc + x.nbytes if isinstance(x, mx.array) else acc, model, 0
+         )
+         max_rec_size = mx.device_info()["max_recommended_working_set_size"]
+         if model_bytes > 0.9 * max_rec_size:
+             model_mb = model_bytes // 2**20
+             max_rec_mb = max_rec_size // 2**20
+             print(f"{model_mb=} {max_rec_mb=}")
+         old_limit = mx.set_wired_limit(max_rec_size)
+         try:
+             yield
+         finally:
+             if streams is not None:
+                 for s in streams:
+                     mx.synchronize(s)
+             else:
+                 mx.synchronize()
+             mx.set_wired_limit(old_limit)
+
+ def maybe_quantize_kv_cache(prompt_cache, quantized_kv_start, kv_group_size, kv_bits):
+     if kv_bits is None:
+         return
+     for e, c in enumerate(prompt_cache):
+         if hasattr(c, "to_quantized") and c.offset >= quantized_kv_start:
+             prompt_cache[e] = c.to_quantized(group_size=kv_group_size, bits=kv_bits)

  def parse_tool(tools, names):
      qwen_tools = []
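
The two small helpers added above, `get_common_len` and `hash_tokens`, drive the prompt-cache reuse implemented in `generate()` later in this diff: the first measures how many leading tokens two token sequences share, and the second produces a short stable digest used in the on-disk cache checkpoint filename. A self-contained illustration of what they compute (the token values are made up):

    import hashlib
    import numpy as np

    def hash_tokens(tokens):
        arr = np.array(tokens, dtype=np.uint32)
        return hashlib.blake2b(arr.tobytes(), digest_size=8).hexdigest()

    def get_common_len(a, b):
        common_len = 0
        for p, h in zip(a, b):
            if p == h:
                common_len += 1
            else:
                break
        return common_len

    previous = [1, 2, 3, 4, 5]         # tokens already held in the KV cache
    current = [1, 2, 3, 9, 10, 11]     # tokens of the incoming prompt

    keep = get_common_len(current, previous)   # 3: only the shared prefix is reusable
    trim = len(previous) - keep                # 2: cache entries to trim before reuse
    print(keep, trim, hash_tokens(current[:keep]))  # 3 2 followed by a 16-hex-char digest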
@@ -121,11 +182,17 @@ def encode(body, tokenizer, system, names, skips):
          if parts:
              msgs.append({"role": role}|parts)
      if not msgs[-1].get('content', '').strip():
-         return None
-     return tokenizer.apply_chat_template(msgs, tools = parse_tool(body.get("tools", []), names), tokenize=False, add_generation_prompt=True)
+         return None, ''
+     apply_chat_template = lambda x: tokenizer.apply_chat_template(x, tools = parse_tool(body.get("tools", []), names), tokenize=False, add_generation_prompt=True)
+     full = apply_chat_template(msgs)
+     last_user_idx = max((i for i, m in enumerate(msgs) if m.get("role") == "user"), default=None)
+     if last_user_idx is None:
+         return full, ''
+     p_msgs = msgs[:last_user_idx] + [dict(role='user', content='h' if msgs[last_user_idx]['content'][0] != 'h' else 'i')]
+     pref = apply_chat_template(p_msgs)
+     return full, pref

  def decode(raw_text, tokenizer, parse_think, single_think=False):
-     # think_id = tokenizer.convert_tokens_to_ids("<think>")
      def escape(text):
          def repl(match):
              inner = match.group(1)
@@ -139,7 +206,6 @@ def decode(raw_text, tokenizer, parse_think, single_think=False):
          parts = re.split(r'(<think>.*?</think>)', raw_text, flags=re.DOTALL, maxsplit=1 if single_think else 0)
      else:
          parts = [raw_text]
-
      for part in parts:
          if not part:
              continue
@@ -192,7 +258,7 @@ def blocks_to_sse(blocks: list[dict], msg_id: str, in_tokens: int, out_tokens: i
          elif bt == "tool_use":
              out += event("content_block_start", {"type": "content_block_start", "index": i,
                  "content_block": {"type": "tool_use", "id": block["id"],
-                     "name": block["name"], "input": {}}})
+                     "name": block["name"], "input": {}} })
              out += event("content_block_delta", {"type": "content_block_delta", "index": i,
                  "delta": {"type": "input_json_delta", "partial_json": json.dumps(block["input"])}})
              out += event("content_block_stop", {"type": "content_block_stop", "index": i})
@@ -203,6 +269,8 @@ def blocks_to_sse(blocks: list[dict], msg_id: str, in_tokens: int, out_tokens: i
      return bytes(out)

  def dmca(p_str):
+     if True: #: False for recording
+         return p_str
      symbols = ["▲", "△", "▶", "▷", "▼", "▽", "◀", "◁", "◆", "◇"]
      def mask_text(text):
          return re.sub(r"\S", lambda _: random.choice(symbols), text)
@@ -218,66 +286,6 @@ def dmca(p_str):
          p_str = re.sub(pattern, lambda m: mask_text(m.group(0)), p_str)
      return p_str

- def generate(model, tokenizer, prompt, hook=None, max_tokens=256, helper_max_tokens=64, **kwargs):
-     global prompt_cache
-     if prompt is None:
-         return '', 0, 0
-     if not isinstance(tokenizer, mlx_lm.tokenizer_utils.TokenizerWrapper):
-         tokenizer = mlx_lm.tokenizer_utils.TokenizerWrapper(tokenizer)
-     detokenizer = tokenizer.detokenizer
-     if isinstance(prompt, str):
-         add_special_tokens = tokenizer.bos_token is None or not prompt.startswith(tokenizer.bos_token)
-         prompt_s = prompt
-         prompt = tokenizer.encode(prompt, add_special_tokens=add_special_tokens)
-     else:
-         prompt_s = tokenizer.decode(prompt)
-     stream_logger.debug(dmca(prompt_s))
-     common_len = 0
-     if prompt_cache.get('cache', None):
-         for p, h in zip(prompt, prompt_cache['hx']):
-             if p == h:
-                 common_len += 1
-             else:
-                 break
-         common_len = min(common_len, len(prompt) - 1)
-     else:
-         prompt_cache['hx'] = []
-         prompt_cache['cache'] = mlx_lm.models.cache.make_prompt_cache(model)
-     hx_len = len(prompt_cache['hx'])
-     trim_len = hx_len - common_len
-     mlx_lm.models.cache.trim_prompt_cache(prompt_cache['cache'], trim_len)
-     token_gen = generate_step(
-         mx.array(prompt[common_len:]),
-         model,
-         prompt_cache=prompt_cache['cache'],
-         max_tokens=max_tokens,
-         **kwargs,
-     )
-     text = ""
-     tic_non = time.perf_counter()
-     gens = []
-     for token, _ in token_gen:
-         gens.append(token)
-         if token in tokenizer.eos_token_ids:
-             break
-         detokenizer.add_token(token)
-         seg = detokenizer.last_segment
-         stream_logger.debug(seg)
-         text += seg
-         if len(gens) == 1:
-             tic_inp = time.perf_counter()
-             if prompt_cache.get('file_name'):
-                 _fn = prompt_cache.pop('file_name')
-                 mlx_lm.models.cache.save_prompt_cache(_fn, prompt_cache['cache'], metadata=dict(model_name=prompt_cache['model_name'], hx=json.dumps(prompt+gens)))
-         if len(gens) >= max_tokens:
-             break
-     tic_out = time.perf_counter()
-     detokenizer.finalize()
-     text += detokenizer.last_segment
-     prompt_cache['hx'] = prompt+gens
-     trace_logger.debug(f'G {hx_len} {len(prompt)} {common_len} {trim_len} {len(gens)}\n=== TPS ===\n- Processed {len(prompt)} input tokens in {tic_inp-tic_non:.0f} seconds ({len(prompt)/(tic_inp-tic_non):.0f} tokens per second)\n- Generated {len(gens)} new tokens in {tic_out-tic_inp:.0f} seconds ({len(gens)/(tic_out-tic_inp):.0f} tokens per second)\n\n=== INP ===\n{dmca(prompt_s)}\n=== OUT ===\n{text}')
-     return text, len(prompt), len(gens)
-
  def make_handler(model, tokenizer, system, names, skips, parse_think=True):
      class Handler(BaseHTTPRequestHandler):
          def log_message(self, fmt, *args):
@@ -308,12 +316,13 @@ def make_handler(model, tokenizer, system, names, skips, parse_think=True):
                  return
              n = int(self.headers.get("Content-Length", 0))
              body = json.loads(self.rfile.read(n))
-             prompt = encode(body, tokenizer, system, names, skips)
+             prompt, pref = encode(body, tokenizer, system, names, skips)
              with gen_lock:
-                 raw, in_tokens, out_tokens = generate(model, tokenizer, prompt=prompt, max_tokens=body.get("max_tokens", 8192))
+                 raw, in_tokens, out_tokens = generate(model, tokenizer, pref=pref, prompt=prompt, max_tokens=body.get("max_tokens", 8192))
              blocks, stop_reason = decode(raw, tokenizer, parse_think=parse_think)
              msg_id = f"msg_{uuid.uuid4().hex}"
              sse = blocks_to_sse(blocks, msg_id, in_tokens, out_tokens, stop_reason)
+             trace_logger.debug(sse)
              self.send_response(200)
              self.send_header("Content-Type", "text/event-stream")
              self.send_header("Cache-Control", "no-cache")
@@ -326,15 +335,29 @@ def make_handler(model, tokenizer, system, names, skips, parse_think=True):
                  pass
      return Handler

+ def load_dict_cache(cache_path):
+     cache, metadata = mlx_lm.models.cache.load_prompt_cache(cache_path, return_metadata=True)
+     mx.eval(cache)
+     model_name = metadata.pop("model_name", "")
+     tokens_str = metadata.pop("hx", "[]")
+     tokens = json.loads(tokens_str)
+     global dict_cache
+     dict_cache |= dict(cache=cache, hx=tokens, model_name=model_name)
+
+ def save_dict_cache(cache_path, metadata, prompt_cache):
+     mlx_lm.models.cache.save_prompt_cache(cache_path, prompt_cache, metadata=metadata)
+
  def main():
      parser = argparse.ArgumentParser()
+     parser.add_argument("--harness", default=None)
+     # parser.add_argument("--harness", default="claude")
      parser.add_argument("--model", default="mlx-community/Qwen3.5-4B-OptiQ-4bit")
      # parser.add_argument("--model", default="mlx-community/Qwen3.5-2B-OptiQ-4bit")
      # parser.add_argument("--model", default="mlx-community/Qwen3.5-0.8B-MLX-bf16")
      parser.add_argument("--system", type=str, default='')
      # parser.add_argument("--system", type=str, default='# Env\n{env}')
      # parser.add_argument("--system", type=str, default=None)
-     parser.add_argument("--cache", type=str, default='cache/cache.safetensors')
+     parser.add_argument("--cache", type=str, default='cache')
      # parser.add_argument("--names", nargs="+", default=[])
      parser.add_argument("--names", nargs="+", default=['Read','Edit','Write','Grep','Glob','Bash','Agent','Skill'])
      # parser.add_argument("--names", nargs="+", default=None)
@@ -346,39 +369,265 @@ def main():
      parser.add_argument("--host", default="127.0.0.1")
      parser.add_argument("--home", default=tempfile.mkdtemp())
      parser.add_argument("--work", default=os.getcwd())
-     args, claude_args = parser.parse_known_args()
-     global prompt_cache
-     if os.path.exists(args.cache):
-         cache, metadata = mlx_lm.models.cache.load_prompt_cache(args.cache, return_metadata=True)
-         mx.eval(cache)
-         model_name = metadata.pop("model_name", "")
-         tokens_str = metadata.pop("hx", "[]")
-         tokens = json.loads(tokens_str)
-         prompt_cache = dict(cache=cache, hx=tokens, model_name=model_name)
-         if prompt_cache.get('model_name') != args.model:
-             prompt_cache = dict(model_name=args.model)
-     else:
-         Path(args.cache).parent.mkdir(parents=True, exist_ok=True)
-         prompt_cache = dict(model_name=args.model)
-     prompt_cache['file_name']=args.cache
+     parser.add_argument("--nocc", action="store_true", help="Disable Claude Code subprocess and run server only")
+     args, harness_args = parser.parse_known_args()
+     Path(args.cache).mkdir(parents=True, exist_ok=True)
+     global dict_cache
+     dict_cache = dict(model_name=args.model, cache_dir = args.cache)
      model, tokenizer = mlx_lm.load(args.model)
      server = HTTPServer((args.host, args.port), make_handler(model, tokenizer, args.system, args.names, args.skips))
-     threading.Thread(target=server.serve_forever, daemon=True).start()
-     env = os.environ.copy()
-     env["ANTHROPIC_BASE_URL"] = f"http://{args.host}:{args.port}"
-     env["ANTHROPIC_AUTH_TOKEN"] = "local"
-     env["ANTHROPIC_MODEL"] = args.model
-     env["ANTHROPIC_SMALL_FAST_MODEL"] = args.model
-     env["HOME"] = args.home
-     def mirror_workspace(src: str, dst: str):
-         for root, dirs, files in os.walk(src):
-             rel = os.path.relpath(root, src)
-             os.makedirs(os.path.join(dst, rel), exist_ok=True)
-             for f in files:
-                 os.link(os.path.join(root, f), os.path.join(dst, rel, f))
-     workspace = os.path.join(args.home, "workspace")
-     mirror_workspace(args.work, workspace)
-     sys.exit(subprocess.run(["claude"] + claude_args, env=env, cwd=workspace).returncode)
+     if args.nocc:
+         try:
+             server.serve_forever()
+         except KeyboardInterrupt:
+             print("\nShutting down server...")
+             server.server_close()
+     else:
+         threading.Thread(target=server.serve_forever, daemon=True).start()
+         env = os.environ.copy()
+         env["ANTHROPIC_BASE_URL"] = f"http://{args.host}:{args.port}"
+         env["ANTHROPIC_AUTH_TOKEN"] = "local"
+         env["ANTHROPIC_MODEL"] = args.model
+         env["ANTHROPIC_SMALL_FAST_MODEL"] = args.model
+         env["HOME"] = args.home
+         def mirror_workspace(src: str, dst: str):
+             for root, dirs, files in os.walk(src):
+                 rel = os.path.relpath(root, src)
+                 os.makedirs(os.path.join(dst, rel), exist_ok=True)
+                 for f in files:
+                     os.link(os.path.join(root, f), os.path.join(dst, rel, f))
+         workspace = os.path.join(args.home, "workspace")
+         mirror_workspace(args.work, workspace)
+         if args.harness is None:
+             from pie import run_repl
+             run_repl(base_url=f"http://{args.host}:{args.port}/v1")
+         else:
+             sys.exit(subprocess.run([args.harness] + harness_args, env=env, cwd=workspace).returncode)
+
+ def generate_step(
+     prompt: mx.array,
+     model: nn.Module,
+     *,
+     max_tokens: int = 256,
+     sampler: Optional[Callable[[mx.array], mx.array]] = None,
+     logits_processors: Optional[List[Callable[[mx.array, mx.array], mx.array]]] = None,
+     max_kv_size: Optional[int] = None,
+     prompt_cache: Optional[Any] = None,
+     prefill_step_size: int = 2048,
+     kv_bits: Optional[int] = None,
+     kv_group_size: int = 64,
+     quantized_kv_start: int = 0,
+     prompt_progress_callback: Optional[Callable[[int, int], None]] = None,
+     input_embeddings: Optional[mx.array] = None,
+     save_at: int = -1,
+     save_fn = None,
+ ) -> Generator[Tuple[mx.array, mx.array], None, None]:
+     if input_embeddings is not None:
+         if not does_model_support_input_embeddings(model):
+             raise ValueError("Model does not support input embeddings.")
+         elif len(prompt) > 0 and len(prompt) != len(input_embeddings):
+             raise ValueError(f"{len(input_embeddings)=} {len(prompt)=}")
+     elif len(prompt) == 0:
+         raise ValueError("Either input_embeddings or prompt (or both) must be provided.")
+
+     tokens = None
+
+     if prompt_cache is None:
+         prompt_cache = cache.make_prompt_cache(
+             model,
+             max_kv_size=max_kv_size,
+         )
+
+     prompt_progress_callback = prompt_progress_callback or (lambda *_: None)
+
+     quantize_cache_fn = functools.partial(
+         maybe_quantize_kv_cache,
+         quantized_kv_start=quantized_kv_start,
+         kv_group_size=kv_group_size,
+         kv_bits=kv_bits,
+     )
+
+     sampler = sampler or (lambda x: mx.argmax(x, axis=-1))
+
+     def _model_call(input_tokens: mx.array, input_embeddings: Optional[mx.array]):
+         if input_embeddings is not None:
+             return model(
+                 input_tokens, cache=prompt_cache, input_embeddings=input_embeddings
+             )
+         else:
+             return model(input_tokens, cache=prompt_cache)
+
+     def _step(input_tokens: mx.array, input_embeddings: Optional[mx.array] = None):
+         nonlocal tokens
+
+         with mx.stream(generation_stream):
+             logits = _model_call(
+                 input_tokens=input_tokens[None],
+                 input_embeddings=(
+                     input_embeddings[None] if input_embeddings is not None else None
+                 ),
+             )
+
+             logits = logits[:, -1, :]
+
+             if logits_processors and len(input_tokens) > 0:
+                 tokens = (
+                     mx.concat([tokens, input_tokens])
+                     if tokens is not None
+                     else input_tokens
+                 )
+                 for processor in logits_processors:
+                     logits = processor(tokens, logits)
+
+             quantize_cache_fn(prompt_cache)
+             logprobs = logits - mx.logsumexp(logits, keepdims=True)
+             sampled = sampler(logprobs)
+             return sampled, logprobs.squeeze(0)
+
+     with mx.stream(generation_stream):
+         total_prompt_tokens = (
+             len(input_embeddings) if input_embeddings is not None else len(prompt)
+         )
+         prompt_processed_tokens = 0
+         prompt_progress_callback(prompt_processed_tokens, total_prompt_tokens)
+         while total_prompt_tokens - prompt_processed_tokens > 1:
+             remaining = (total_prompt_tokens - prompt_processed_tokens) - 1
+             n_to_process = min(prefill_step_size, remaining)
+             if prompt_processed_tokens < save_at:
+                 n_to_process = min(n_to_process, save_at - prompt_processed_tokens)
+
+             _model_call(
+                 input_tokens=prompt[:n_to_process][None],
+                 input_embeddings=(
+                     input_embeddings[:n_to_process][None]
+                     if input_embeddings is not None
+                     else None
+                 ),
+             )
+             quantize_cache_fn(prompt_cache)
+             mx.eval([c.state for c in prompt_cache])
+             prompt_processed_tokens += n_to_process
+             prompt_progress_callback(prompt_processed_tokens, total_prompt_tokens)
+             prompt = prompt[n_to_process:]
+             input_embeddings = (
+                 input_embeddings[n_to_process:]
+                 if input_embeddings is not None
+                 else input_embeddings
+             )
+             mx.clear_cache()
+             if save_fn is not None and prompt_processed_tokens == save_at:
+                 save_fn(prompt_cache)
+
+         y, logprobs = _step(input_tokens=prompt, input_embeddings=input_embeddings)
+
+     mx.async_eval(y, logprobs)
+     n = 0
+     while True:
+         if n != max_tokens:
+             next_y, next_logprobs = _step(y)
+             mx.async_eval(next_y, next_logprobs)
+         if n == 0:
+             mx.eval(y)
+             prompt_progress_callback(total_prompt_tokens, total_prompt_tokens)
+         if n == max_tokens:
+             break
+         yield y.item(), logprobs
+         if n % 256 == 0:
+             mx.clear_cache()
+         y, logprobs = next_y, next_logprobs
+         n += 1

+ def generate(model, tokenizer, prompt, pref, hook=None, max_tokens=256, helper_max_tokens=64, **kwargs):
+     global dict_cache
+     if prompt is None:
+         return '', 0, 0
+     if not isinstance(tokenizer, mlx_lm.tokenizer_utils.TokenizerWrapper):
+         tokenizer = mlx_lm.tokenizer_utils.TokenizerWrapper(tokenizer)
+     detokenizer = tokenizer.detokenizer
+     if isinstance(prompt, str):
+         add_special_tokens = tokenizer.bos_token is None or not prompt.startswith(tokenizer.bos_token)
+         prompt_s = prompt
+         prompt = tokenizer.encode(prompt, add_special_tokens=add_special_tokens)
+         _pref = tokenizer.encode(pref, add_special_tokens=add_special_tokens)
+         save_at = get_common_len(prompt, _pref)
+     else:
+         prompt_s = tokenizer.decode(prompt)
+         save_at = -1 # □ for now
+     stream_logger.debug(dmca(prompt_s))
+     text = ''
+     gens = []
+     common_len = 0
+     hx_len = None
+     trim_len = None
+     save_fn = None
+     if not dict_cache.get('cache'):
+         ckpt_path = Path(dict_cache['cache_dir'])/f'{"".join(c for c in dict_cache["model_name"] if c.isalnum())}_{save_at}_{hash_tokens(prompt[:save_at])}.safetensors'
+         if os.path.exists(ckpt_path):
+             load_dict_cache(ckpt_path)
+         else:
+             dict_cache |= dict(cache=mlx_lm.models.cache.make_prompt_cache(model), hx=[])
+             save_fn = functools.partial(save_dict_cache, ckpt_path, dict(model_name=dict_cache['model_name'], hx=json.dumps(prompt[:save_at+1])))
+
+     if (hx := dict_cache.get('hx')):
+         _hx = hx[:-1]
+         common_len = get_common_len(prompt, _hx)
+         hx_len = len(_hx)
+         trim_len = hx_len - common_len
+         if trim_len > 0:
+             if all(c.is_trimmable() for c in dict_cache['cache']):
+                 mlx_lm.models.cache.trim_prompt_cache(dict_cache['cache'], trim_len)
+             else:
+                 ckpt_path = Path(dict_cache['cache_dir'])/f'{"".join(c for c in dict_cache["model_name"] if c.isalnum())}_{save_at}_{hash_tokens(prompt[:save_at])}.safetensors'
+                 if os.path.exists(ckpt_path):
+                     load_dict_cache(ckpt_path)
+                     common_len = save_at
+                 else:
+                     dict_cache |= dict(cache=mlx_lm.models.cache.make_prompt_cache(model), hx=[])
+                     save_fn = functools.partial(save_dict_cache, ckpt_path, dict(model_name=dict_cache['model_name'], hx=json.dumps(prompt[:save_at+1])))
+                     common_len = 0
+
+     if save_at > common_len and not all(c.is_trimmable() for c in dict_cache['cache']):
+         ckpt_path = Path(dict_cache['cache_dir'])/f'{"".join(c for c in dict_cache["model_name"] if c.isalnum())}_{save_at}_{hash_tokens(prompt[:save_at])}.safetensors'
+         save_fn = functools.partial(save_dict_cache, ckpt_path, dict(model_name=dict_cache['model_name'], hx=json.dumps(prompt[:save_at+1])))
+     else:
+         save_at = -1
+
+     if common_len==len(prompt):
+         _last_gen = dict_cache['hx'][common_len]
+         prompt_arr = mx.array([_last_gen])
+         gens.append(_last_gen)
+         detokenizer.add_token(_last_gen)
+     else:
+         prompt_arr = mx.array(prompt[common_len:])
+
+     token_gen = generate_step(
+         prompt_arr,
+         model,
+         prompt_cache=dict_cache['cache'],
+         max_tokens=max_tokens,
+         save_at=save_at-common_len,
+         save_fn=save_fn,
+         **kwargs,
+     )
+     tic_non = time.perf_counter()
+     for token, _ in token_gen:
+         gens.append(token)
+         if token in tokenizer.eos_token_ids:
+             break
+         detokenizer.add_token(token)
+         seg = detokenizer.last_segment
+         stream_logger.debug(seg)
+         text += seg
+         if len(gens) == 1:
+             tic_inp = time.perf_counter()
+         if len(gens) >= max_tokens:
+             break
+     tic_out = time.perf_counter()
+     detokenizer.finalize()
+     text += detokenizer.last_segment
+     dict_cache['hx'] = prompt+gens
+     trace_logger.debug(f'G {hx_len} {len(prompt)} {common_len} {trim_len} {len(gens)}\n=== TPS ===\n- Processed {len(prompt)} input tokens in {tic_inp-tic_non:.0f} seconds ({len(prompt)/(tic_inp-tic_non):.0f} tokens per second)\n- Generated {len(gens)} new tokens in {tic_out-tic_inp:.0f} seconds ({len(gens)/(tic_out-tic_inp):.0f} tokens per second)\n\n=== INP ===\n{dmca(prompt_s)}\n=== OUT ===\n{text}')
+     return text, len(prompt), len(gens)

  if __name__ == "__main__":
      main()
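
Taken together, the new `encode()` and `generate()` implement a two-level prompt cache: `encode()` renders the conversation twice, once verbatim and once with the last user message swapped for a one-character placeholder, and the length of their common token prefix (`save_at`) marks the stable part of the prompt worth checkpointing to disk, while a comparison against the in-memory history (`hx`) decides how much of the live KV cache can be kept. A condensed, self-contained sketch of that decision (token values are made up; only the arithmetic mirrors the code above):

    def get_common_len(a, b):
        # length of the shared leading run of two token sequences, as in main.py
        n = 0
        for x, y in zip(a, b):
            if x != y:
                break
            n += 1
        return n

    def plan_cache_reuse(prompt, pref, hx):
        # save_at: how much of the prompt precedes the latest user message
        save_at = get_common_len(prompt, pref)
        reuse = get_common_len(prompt, hx[:-1]) if hx else 0
        trim = len(hx[:-1]) - reuse if hx else 0
        return save_at, reuse, trim

    full = [10, 11, 12, 13, 14, 15]   # tokens of the full rendered conversation
    pref = [10, 11, 12, 99]           # same conversation, last user turn replaced
    hx = [10, 11, 12, 13, 20, 21]     # prompt+generation tokens from the previous turn

    print(plan_cache_reuse(full, pref, hx))   # (3, 4, 1)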
@@ -1,6 +1,6 @@
  Metadata-Version: 2.4
  Name: mlx-code
- Version: 0.0.2a1
+ Version: 0.0.2a3
  Summary: Local Claude Code for Mac
  Home-page: https://github.com/JosefAlbers/mlx-code
  Author: J Joe
@@ -52,22 +52,12 @@ mlx-code [options] [-- claude options]
  | `--work` | `$CWD` | Working directory mirrored into the Claude session |
  | `--home` | temp dir | Home directory for the Claude process |

- Any extra arguments after `--` are forwarded to the `claude` CLI:
+ Any extra arguments after `--` are forwarded to the `claude` CLI.

- | Command | What it does | Example |
- |--------|--------------|--------|
- | `mlx-code` | Start interactive mode | `mlx-code` |
- | `mlx-code "task"` | Run a one-time task | `mlx-code "fix the build error"` |
- | `mlx-code -p "query"` | Run one-off query, then exit | `mlx-code -p "explain this function"` |
- | `mlx-code -c` | Continue most recent conversation in current directory | `mlx-code -c` |
- | `mlx-code -r` | Resume a previous conversation | `mlx-code -r` |
- | `mlx-code commit` | Create a Git commit | `mlx-code commit` |
- | `/clear` | Clear conversation history | `/clear` |
- | `/help` | Show available commands | `/help` |
- | `exit` or `Ctrl+C` | Exit Claude Code | `exit` |
+ ### Credits
+
+ `pie.py` is based on [pi](https://github.com/badlogic/pi-mono) by Mario Zechner, used under the MIT License.

  ### Licence

  Apache License 2.0 — see LICENSE for details.
-
-
@@ -1,6 +1,7 @@
  LICENSE
  README.md
  main.py
+ pie.py
  setup.py
  mlx_code.egg-info/PKG-INFO
  mlx_code.egg-info/SOURCES.txt