mlx-code 0.0.2a1__tar.gz → 0.0.2a2__tar.gz

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.4
2
2
  Name: mlx-code
3
- Version: 0.0.2a1
3
+ Version: 0.0.2a2
4
4
  Summary: Local Claude Code for Mac
5
5
  Home-page: https://github.com/JosefAlbers/mlx-code
6
6
  Author: J Joe
@@ -52,19 +52,7 @@ mlx-code [options] [-- claude options]
52
52
  | `--work` | `$CWD` | Working directory mirrored into the Claude session |
53
53
  | `--home` | temp dir | Home directory for the Claude process |
54
54
 
55
- Any extra arguments after `--` are forwarded to the `claude` CLI:
56
-
57
- | Command | What it does | Example |
58
- |--------|--------------|--------|
59
- | `mlx-code` | Start interactive mode | `mlx-code` |
60
- | `mlx-code "task"` | Run a one-time task | `mlx-code "fix the build error"` |
61
- | `mlx-code -p "query"` | Run one-off query, then exit | `mlx-code -p "explain this function"` |
62
- | `mlx-code -c` | Continue most recent conversation in current directory | `mlx-code -c` |
63
- | `mlx-code -r` | Resume a previous conversation | `mlx-code -r` |
64
- | `mlx-code commit` | Create a Git commit | `mlx-code commit` |
65
- | `/clear` | Clear conversation history | `/clear` |
66
- | `/help` | Show available commands | `/help` |
67
- | `exit` or `Ctrl+C` | Exit Claude Code | `exit` |
55
+ Any extra arguments after `--` are forwarded to the `claude` CLI.
68
56
 
69
57
  ### Licence
70
58
 
@@ -28,19 +28,7 @@ mlx-code [options] [-- claude options]
28
28
  | `--work` | `$CWD` | Working directory mirrored into the Claude session |
29
29
  | `--home` | temp dir | Home directory for the Claude process |
30
30
 
31
- Any extra arguments after `--` are forwarded to the `claude` CLI:
32
-
33
- | Command | What it does | Example |
34
- |--------|--------------|--------|
35
- | `mlx-code` | Start interactive mode | `mlx-code` |
36
- | `mlx-code "task"` | Run a one-time task | `mlx-code "fix the build error"` |
37
- | `mlx-code -p "query"` | Run one-off query, then exit | `mlx-code -p "explain this function"` |
38
- | `mlx-code -c` | Continue most recent conversation in current directory | `mlx-code -c` |
39
- | `mlx-code -r` | Resume a previous conversation | `mlx-code -r` |
40
- | `mlx-code commit` | Create a Git commit | `mlx-code commit` |
41
- | `/clear` | Clear conversation history | `/clear` |
42
- | `/help` | Show available commands | `/help` |
43
- | `exit` or `Ctrl+C` | Exit Claude Code | `exit` |
31
+ Any extra arguments after `--` are forwarded to the `claude` CLI.
44
32
 
45
33
  ### Licence
46
34
 
@@ -1,3 +1,4 @@
1
+ # {{{
1
2
  # Copyright 2026 J Joe
2
3
  # Licensed under the Apache License, Version 2.0 (the "License");
3
4
  # you may not use this file except in compliance with the License.
@@ -25,8 +26,22 @@ from http.server import BaseHTTPRequestHandler, HTTPServer
25
26
  from pathlib import Path
26
27
  import mlx.core as mx
27
28
  import mlx_lm
28
- from mlx_lm.generate import generate_step
29
+ import numpy as np
30
+ import hashlib
31
+ import contextlib
32
+ import functools
33
+ import mlx.nn as nn
34
+ from typing import (
35
+ Any,
36
+ Callable,
37
+ Generator,
38
+ List,
39
+ Optional,
40
+ Tuple,
41
+ Union,
42
+ )
29
43
 
44
+ generation_stream = mx.new_stream(mx.default_device())
30
45
  stream_logger = logging.getLogger("stream")
31
46
  stream_logger.setLevel(logging.DEBUG)
32
47
  s_handler = logging.FileHandler("mlx_stream.log", mode='w')
@@ -39,7 +54,54 @@ t_handler = logging.FileHandler("mlx_trace.log", mode='w')
39
54
  t_handler.setFormatter(logging.Formatter("【%(message)s\n】\n"))
40
55
  trace_logger.addHandler(t_handler)
41
56
  gen_lock = threading.Lock()
42
- prompt_cache = {}
57
+ dict_cache = {}
58
+
59
def hash_tokens(tokens):
    """Return a short, deterministic hex fingerprint of a token-id sequence.

    The ids are packed as unsigned 32-bit integers and hashed with BLAKE2b
    using an 8-byte digest, so the result is a 16-character hex string.
    It is used to name KV-cache checkpoint files.
    """
    packed = np.asarray(tokens, dtype=np.uint32).tobytes()
    digest = hashlib.blake2b(packed, digest_size=8)
    return digest.hexdigest()
62
+
63
def get_common_len(a, b):
    """Return how many leading items *a* and *b* share.

    Items are compared pairwise from the start and counting stops at the
    first mismatch or at the end of the shorter sequence.
    """
    length = 0
    for left, right in zip(a, b):
        if left != right:
            break
        length += 1
    return length
71
+
72
@contextlib.contextmanager
def wired_limit(model: nn.Module, streams: Optional[List[mx.Stream]] = None):
    """Temporarily raise the MLX wired-memory limit while *model* is in use.

    On Metal, the wired limit is set to the device's
    ``max_recommended_working_set_size`` for the duration of the ``with`` body
    and restored afterwards. When the model's parameters exceed 90% of that
    size, their sizes (MiB) are printed. On exit, the given *streams* (or the
    default stream when *streams* is None) are synchronized before the limit
    is restored. On non-Metal backends this is a no-op context manager.
    """
    if not mx.metal.is_available():
        yield
        return

    # tree_reduce is not imported at module level; without this local import
    # the bare name raised NameError on every Metal machine.
    from mlx.utils import tree_reduce

    model_bytes = tree_reduce(
        lambda acc, x: acc + x.nbytes if isinstance(x, mx.array) else acc, model, 0
    )
    max_rec_size = mx.device_info()["max_recommended_working_set_size"]
    if model_bytes > 0.9 * max_rec_size:
        model_mb = model_bytes // 2**20
        max_rec_mb = max_rec_size // 2**20
        print(f"{model_mb=} {max_rec_mb=}")
    old_limit = mx.set_wired_limit(max_rec_size)
    try:
        yield
    finally:
        # Make sure queued work has finished before shrinking the limit again.
        if streams is not None:
            for s in streams:
                mx.synchronize(s)
        else:
            mx.synchronize()
        mx.set_wired_limit(old_limit)
98
+
99
def maybe_quantize_kv_cache(prompt_cache, quantized_kv_start, kv_group_size, kv_bits):
    """Quantize eligible KV-cache entries in place.

    An entry is replaced by ``entry.to_quantized(group_size=kv_group_size,
    bits=kv_bits)`` when it has a ``to_quantized`` method and its ``offset``
    is at least *quantized_kv_start*. Nothing happens when *kv_bits* is None.
    Always returns None.
    """
    if kv_bits is not None:
        for slot, entry in enumerate(prompt_cache):
            eligible = hasattr(entry, "to_quantized") and entry.offset >= quantized_kv_start
            if eligible:
                prompt_cache[slot] = entry.to_quantized(group_size=kv_group_size, bits=kv_bits)
43
105
 
44
106
  def parse_tool(tools, names):
45
107
  qwen_tools = []
@@ -121,11 +183,17 @@ def encode(body, tokenizer, system, names, skips):
121
183
  if parts:
122
184
  msgs.append({"role": role}|parts)
123
185
  if not msgs[-1].get('content', '').strip():
124
- return None
125
- return tokenizer.apply_chat_template(msgs, tools = parse_tool(body.get("tools", []), names), tokenize=False, add_generation_prompt=True)
186
+ return None, -1
187
+ apply_chat_template = lambda x: tokenizer.apply_chat_template(x, tools = parse_tool(body.get("tools", []), names), tokenize=False, add_generation_prompt=True)
188
+ full = apply_chat_template(msgs)
189
+ last_user_idx = max((i for i, m in enumerate(msgs) if m.get("role") == "user"), default=None)
190
+ if last_user_idx is None:
191
+ return full, -1
192
+ p_msgs = msgs[:last_user_idx] + [dict(role='user', content='h' if msgs[last_user_idx]['content'][0] != 'h' else 'i')]
193
+ pref = apply_chat_template(p_msgs)
194
+ return full, pref
126
195
 
127
196
  def decode(raw_text, tokenizer, parse_think, single_think=False):
128
- # think_id = tokenizer.convert_tokens_to_ids("<think>")
129
197
  def escape(text):
130
198
  def repl(match):
131
199
  inner = match.group(1)
@@ -139,7 +207,6 @@ def decode(raw_text, tokenizer, parse_think, single_think=False):
139
207
  parts = re.split(r'(<think>.*?</think>)', raw_text, flags=re.DOTALL, maxsplit=1 if single_think else 0)
140
208
  else:
141
209
  parts = [raw_text]
142
-
143
210
  for part in parts:
144
211
  if not part:
145
212
  continue
@@ -192,7 +259,7 @@ def blocks_to_sse(blocks: list[dict], msg_id: str, in_tokens: int, out_tokens: i
192
259
  elif bt == "tool_use":
193
260
  out += event("content_block_start", {"type": "content_block_start", "index": i,
194
261
  "content_block": {"type": "tool_use", "id": block["id"],
195
- "name": block["name"], "input": {}}})
262
+ "name": block["name"], "input": {}} })
196
263
  out += event("content_block_delta", {"type": "content_block_delta", "index": i,
197
264
  "delta": {"type": "input_json_delta", "partial_json": json.dumps(block["input"])}})
198
265
  out += event("content_block_stop", {"type": "content_block_stop", "index": i})
@@ -203,6 +270,8 @@ def blocks_to_sse(blocks: list[dict], msg_id: str, in_tokens: int, out_tokens: i
203
270
  return bytes(out)
204
271
 
205
272
  def dmca(p_str):
273
+ if True: #: False for recording
274
+ return p_str
206
275
  symbols = ["▲", "△", "▶", "▷", "▼", "▽", "◀", "◁", "◆", "◇"]
207
276
  def mask_text(text):
208
277
  return re.sub(r"\S", lambda _: random.choice(symbols), text)
@@ -218,66 +287,6 @@ def dmca(p_str):
218
287
  p_str = re.sub(pattern, lambda m: mask_text(m.group(0)), p_str)
219
288
  return p_str
220
289
 
221
- def generate(model, tokenizer, prompt, hook=None, max_tokens=256, helper_max_tokens=64, **kwargs):
222
- global prompt_cache
223
- if prompt is None:
224
- return '', 0, 0
225
- if not isinstance(tokenizer, mlx_lm.tokenizer_utils.TokenizerWrapper):
226
- tokenizer = mlx_lm.tokenizer_utils.TokenizerWrapper(tokenizer)
227
- detokenizer = tokenizer.detokenizer
228
- if isinstance(prompt, str):
229
- add_special_tokens = tokenizer.bos_token is None or not prompt.startswith(tokenizer.bos_token)
230
- prompt_s = prompt
231
- prompt = tokenizer.encode(prompt, add_special_tokens=add_special_tokens)
232
- else:
233
- prompt_s = tokenizer.decode(prompt)
234
- stream_logger.debug(dmca(prompt_s))
235
- common_len = 0
236
- if prompt_cache.get('cache', None):
237
- for p, h in zip(prompt, prompt_cache['hx']):
238
- if p == h:
239
- common_len += 1
240
- else:
241
- break
242
- common_len = min(common_len, len(prompt) - 1)
243
- else:
244
- prompt_cache['hx'] = []
245
- prompt_cache['cache'] = mlx_lm.models.cache.make_prompt_cache(model)
246
- hx_len = len(prompt_cache['hx'])
247
- trim_len = hx_len - common_len
248
- mlx_lm.models.cache.trim_prompt_cache(prompt_cache['cache'], trim_len)
249
- token_gen = generate_step(
250
- mx.array(prompt[common_len:]),
251
- model,
252
- prompt_cache=prompt_cache['cache'],
253
- max_tokens=max_tokens,
254
- **kwargs,
255
- )
256
- text = ""
257
- tic_non = time.perf_counter()
258
- gens = []
259
- for token, _ in token_gen:
260
- gens.append(token)
261
- if token in tokenizer.eos_token_ids:
262
- break
263
- detokenizer.add_token(token)
264
- seg = detokenizer.last_segment
265
- stream_logger.debug(seg)
266
- text += seg
267
- if len(gens) == 1:
268
- tic_inp = time.perf_counter()
269
- if prompt_cache.get('file_name'):
270
- _fn = prompt_cache.pop('file_name')
271
- mlx_lm.models.cache.save_prompt_cache(_fn, prompt_cache['cache'], metadata=dict(model_name=prompt_cache['model_name'], hx=json.dumps(prompt+gens)))
272
- if len(gens) >= max_tokens:
273
- break
274
- tic_out = time.perf_counter()
275
- detokenizer.finalize()
276
- text += detokenizer.last_segment
277
- prompt_cache['hx'] = prompt+gens
278
- trace_logger.debug(f'G {hx_len} {len(prompt)} {common_len} {trim_len} {len(gens)}\n=== TPS ===\n- Processed {len(prompt)} input tokens in {tic_inp-tic_non:.0f} seconds ({len(prompt)/(tic_inp-tic_non):.0f} tokens per second)\n- Generated {len(gens)} new tokens in {tic_out-tic_inp:.0f} seconds ({len(gens)/(tic_out-tic_inp):.0f} tokens per second)\n\n=== INP ===\n{dmca(prompt_s)}\n=== OUT ===\n{text}')
279
- return text, len(prompt), len(gens)
280
-
281
290
  def make_handler(model, tokenizer, system, names, skips, parse_think=True):
282
291
  class Handler(BaseHTTPRequestHandler):
283
292
  def log_message(self, fmt, *args):
@@ -308,9 +317,9 @@ def make_handler(model, tokenizer, system, names, skips, parse_think=True):
308
317
  return
309
318
  n = int(self.headers.get("Content-Length", 0))
310
319
  body = json.loads(self.rfile.read(n))
311
- prompt = encode(body, tokenizer, system, names, skips)
320
+ prompt, pref = encode(body, tokenizer, system, names, skips)
312
321
  with gen_lock:
313
- raw, in_tokens, out_tokens = generate(model, tokenizer, prompt=prompt, max_tokens=body.get("max_tokens", 8192))
322
+ raw, in_tokens, out_tokens = generate(model, tokenizer, pref=pref, prompt=prompt, max_tokens=body.get("max_tokens", 8192))
314
323
  blocks, stop_reason = decode(raw, tokenizer, parse_think=parse_think)
315
324
  msg_id = f"msg_{uuid.uuid4().hex}"
316
325
  sse = blocks_to_sse(blocks, msg_id, in_tokens, out_tokens, stop_reason)
@@ -326,6 +335,18 @@ def make_handler(model, tokenizer, system, names, skips, parse_think=True):
326
335
  pass
327
336
  return Handler
328
337
 
338
def load_dict_cache(cache_path):
    """Load a KV-cache checkpoint from *cache_path* into the global ``dict_cache``.

    The checkpoint metadata provides ``model_name`` and ``hx`` (a JSON list of
    token ids the cache corresponds to). The loaded arrays are evaluated
    eagerly. Keys already in ``dict_cache`` that the checkpoint does not carry
    are preserved. This includes ``cache_dir`` (set in ``main``), which
    ``generate`` needs to build checkpoint paths. Rebuilding the dict from
    scratch used to drop it and caused a KeyError on the next lookup.
    """
    global dict_cache
    cache, metadata = mlx_lm.models.cache.load_prompt_cache(cache_path, return_metadata=True)
    mx.eval(cache)
    # Fall back to the current model name so checkpoint filenames stay stable
    # when the metadata lacks one.
    model_name = metadata.pop("model_name", "") or dict_cache.get("model_name", "")
    tokens = json.loads(metadata.pop("hx", "[]"))
    dict_cache = {**dict_cache, "cache": cache, "hx": tokens, "model_name": model_name}
346
+
347
def save_dict_cache(cache_path, metadata, prompt_cache):
    """Write *prompt_cache* to *cache_path*, attaching *metadata*.

    Thin wrapper over ``mlx_lm.models.cache.save_prompt_cache``. The argument
    order is chosen so that ``functools.partial(save_dict_cache, path, meta)``
    yields a one-argument saver that takes the cache.
    """
    writer = mlx_lm.models.cache.save_prompt_cache
    writer(cache_path, prompt_cache, metadata=metadata)
349
+
329
350
  def main():
330
351
  parser = argparse.ArgumentParser()
331
352
  parser.add_argument("--model", default="mlx-community/Qwen3.5-4B-OptiQ-4bit")
@@ -334,7 +355,8 @@ def main():
334
355
  parser.add_argument("--system", type=str, default='')
335
356
  # parser.add_argument("--system", type=str, default='# Env\n{env}')
336
357
  # parser.add_argument("--system", type=str, default=None)
337
- parser.add_argument("--cache", type=str, default='cache/cache.safetensors')
358
+ # parser.add_argument("--cache", type=str, default='cache/cache.safetensors')
359
+ parser.add_argument("--cache", type=str, default='cache')
338
360
  # parser.add_argument("--names", nargs="+", default=[])
339
361
  parser.add_argument("--names", nargs="+", default=['Read','Edit','Write','Grep','Glob','Bash','Agent','Skill'])
340
362
  # parser.add_argument("--names", nargs="+", default=None)
@@ -347,20 +369,9 @@ def main():
347
369
  parser.add_argument("--home", default=tempfile.mkdtemp())
348
370
  parser.add_argument("--work", default=os.getcwd())
349
371
  args, claude_args = parser.parse_known_args()
350
- global prompt_cache
351
- if os.path.exists(args.cache):
352
- cache, metadata = mlx_lm.models.cache.load_prompt_cache(args.cache, return_metadata=True)
353
- mx.eval(cache)
354
- model_name = metadata.pop("model_name", "")
355
- tokens_str = metadata.pop("hx", "[]")
356
- tokens = json.loads(tokens_str)
357
- prompt_cache = dict(cache=cache, hx=tokens, model_name=model_name)
358
- if prompt_cache.get('model_name') != args.model:
359
- prompt_cache = dict(model_name=args.model)
360
- else:
361
- Path(args.cache).parent.mkdir(parents=True, exist_ok=True)
362
- prompt_cache = dict(model_name=args.model)
363
- prompt_cache['file_name']=args.cache
372
+ Path(args.cache).mkdir(parents=True, exist_ok=True)
373
+ global dict_cache
374
+ dict_cache = dict(model_name=args.model, cache_dir = args.cache)
364
375
  model, tokenizer = mlx_lm.load(args.model)
365
376
  server = HTTPServer((args.host, args.port), make_handler(model, tokenizer, args.system, args.names, args.skips))
366
377
  threading.Thread(target=server.serve_forever, daemon=True).start()
@@ -380,5 +391,228 @@ def main():
380
391
  mirror_workspace(args.work, workspace)
381
392
  sys.exit(subprocess.run(["claude"] + claude_args, env=env, cwd=workspace).returncode)
382
393
 
394
def generate_step(
    prompt: mx.array,
    model: nn.Module,
    *,
    max_tokens: int = 256,
    sampler: Optional[Callable[[mx.array], mx.array]] = None,
    logits_processors: Optional[List[Callable[[mx.array, mx.array], mx.array]]] = None,
    max_kv_size: Optional[int] = None,
    prompt_cache: Optional[Any] = None,
    prefill_step_size: int = 2048,
    kv_bits: Optional[int] = None,
    kv_group_size: int = 64,
    quantized_kv_start: int = 0,
    prompt_progress_callback: Optional[Callable[[int, int], None]] = None,
    input_embeddings: Optional[mx.array] = None,
    save_at: int = -1,
    save_fn=None,
) -> Generator[Tuple[mx.array, mx.array], None, None]:
    """Yield ``(token_id, logprobs)`` for each generated token.

    The prompt is prefilled in chunks of at most *prefill_step_size* tokens,
    then tokens are sampled one at a time. The next step is dispatched
    asynchronously while the current token is yielded.

    Checkpoint extension: when ``0 < save_at < len(prompt)``, prefill chunks
    are cut so that exactly *save_at* tokens have been processed at a chunk
    boundary, and ``save_fn(prompt_cache)`` is called at that point. This is
    used to persist the KV cache of a shared conversation prefix.

    Args:
        prompt: token ids to feed (may be empty only with *input_embeddings*).
        model: the language model; called as ``model(tokens, cache=...)``.
        max_tokens: number of tokens to yield; -1 never stops on count.
        sampler: maps log-probs to a token; defaults to argmax.
        logits_processors: applied in order to the logits of each step.
        max_kv_size: forwarded to ``make_prompt_cache`` when no cache is given.
        prompt_cache: per-layer KV cache, mutated in place.
        prefill_step_size, kv_bits, kv_group_size, quantized_kv_start:
            prefill chunking and optional KV quantization settings.
        prompt_progress_callback: called as ``(processed, total)``.
        input_embeddings: optional embeddings aligned with *prompt*.
        save_at: prompt position (relative to *prompt*) to checkpoint at.
        save_fn: callable receiving the cache at *save_at*; None disables it.

    Yields:
        ``(y.item(), logprobs)``: the sampled token id as a Python scalar and
        the step's log-probabilities.

    Raises:
        ValueError: on unsupported/mismatched *input_embeddings* or an empty
            prompt without embeddings.
    """
    if input_embeddings is not None:
        # Inline capability check; the previously referenced helper
        # `does_model_support_input_embeddings` is not defined in this module.
        import inspect

        try:
            accepts_embeddings = "input_embeddings" in inspect.signature(model.__call__).parameters
        except (TypeError, ValueError):
            accepts_embeddings = False
        if not accepts_embeddings:
            raise ValueError("Model does not support input embeddings.")
        elif len(prompt) > 0 and len(prompt) != len(input_embeddings):
            raise ValueError(f"{len(input_embeddings)=} {len(prompt)=}")
    elif len(prompt) == 0:
        raise ValueError("Either input_embeddings or prompt (or both) must be provided.")

    tokens = None

    if prompt_cache is None:
        # The bare name `cache` is not bound in this module; go through mlx_lm.
        prompt_cache = mlx_lm.models.cache.make_prompt_cache(
            model,
            max_kv_size=max_kv_size,
        )

    prompt_progress_callback = prompt_progress_callback or (lambda *_: None)

    quantize_cache_fn = functools.partial(
        maybe_quantize_kv_cache,
        quantized_kv_start=quantized_kv_start,
        kv_group_size=kv_group_size,
        kv_bits=kv_bits,
    )

    sampler = sampler or (lambda x: mx.argmax(x, axis=-1))

    def _model_call(input_tokens: mx.array, input_embeddings: Optional[mx.array]):
        if input_embeddings is not None:
            return model(
                input_tokens, cache=prompt_cache, input_embeddings=input_embeddings
            )
        else:
            return model(input_tokens, cache=prompt_cache)

    def _step(input_tokens: mx.array, input_embeddings: Optional[mx.array] = None):
        nonlocal tokens

        with mx.stream(generation_stream):
            logits = _model_call(
                input_tokens=input_tokens[None],
                input_embeddings=(
                    input_embeddings[None] if input_embeddings is not None else None
                ),
            )

            logits = logits[:, -1, :]

            if logits_processors and len(input_tokens) > 0:
                tokens = (
                    mx.concat([tokens, input_tokens])
                    if tokens is not None
                    else input_tokens
                )
                for processor in logits_processors:
                    logits = processor(tokens, logits)

            quantize_cache_fn(prompt_cache)
            logprobs = logits - mx.logsumexp(logits, keepdims=True)
            sampled = sampler(logprobs)
            return sampled, logprobs.squeeze(0)

    with mx.stream(generation_stream):
        total_prompt_tokens = (
            len(input_embeddings) if input_embeddings is not None else len(prompt)
        )
        prompt_processed_tokens = 0
        prompt_progress_callback(prompt_processed_tokens, total_prompt_tokens)
        # Keep the last prompt token for _step so it produces the first logits.
        while total_prompt_tokens - prompt_processed_tokens > 1:
            remaining = (total_prompt_tokens - prompt_processed_tokens) - 1
            n_to_process = min(prefill_step_size, remaining)
            if prompt_processed_tokens < save_at:
                # Shorten the chunk so a boundary lands exactly on save_at.
                n_to_process = min(n_to_process, save_at - prompt_processed_tokens)

            _model_call(
                input_tokens=prompt[:n_to_process][None],
                input_embeddings=(
                    input_embeddings[:n_to_process][None]
                    if input_embeddings is not None
                    else None
                ),
            )
            quantize_cache_fn(prompt_cache)
            mx.eval([c.state for c in prompt_cache])
            prompt_processed_tokens += n_to_process
            prompt_progress_callback(prompt_processed_tokens, total_prompt_tokens)
            prompt = prompt[n_to_process:]
            input_embeddings = (
                input_embeddings[n_to_process:]
                if input_embeddings is not None
                else input_embeddings
            )
            mx.clear_cache()
            if save_fn is not None and prompt_processed_tokens == save_at:
                save_fn(prompt_cache)

        y, logprobs = _step(input_tokens=prompt, input_embeddings=input_embeddings)

    mx.async_eval(y, logprobs)
    n = 0
    while True:
        if n != max_tokens:
            # Dispatch the next step before yielding the current token.
            next_y, next_logprobs = _step(y)
            mx.async_eval(next_y, next_logprobs)
        if n == 0:
            mx.eval(y)
            prompt_progress_callback(total_prompt_tokens, total_prompt_tokens)
        if n == max_tokens:
            break
        yield y.item(), logprobs
        if n % 256 == 0:
            mx.clear_cache()
        y, logprobs = next_y, next_logprobs
        n += 1
526
+
527
def generate(model, tokenizer, prompt, pref=None, hook=None, max_tokens=256, helper_max_tokens=64, **kwargs):
    """Generate a completion for *prompt*, reusing the global KV cache.

    The cache in ``dict_cache`` holds the state for ``dict_cache['hx'][:-1]``
    (the previous prompt plus generated tokens). The longest shared prefix with
    the new prompt is kept and only the remainder is prefilled. Trimmable
    caches are rolled back in place. Non-trimmable caches are restored from an
    on-disk checkpoint taken at *save_at* (the end of the shared prefix derived
    from *pref*). If no checkpoint exists, the cache is rebuilt from scratch.

    Args:
        model, tokenizer: the loaded mlx_lm model/tokenizer.
        prompt: rendered prompt string, token-id list, or None.
        pref: rendered prefix string from ``encode``. Any non-string
            (``encode`` returns -1 when there is no user turn) disables
            checkpointing.
        hook, helper_max_tokens: accepted for compatibility; unused here.
        max_tokens: generation budget.
        **kwargs: forwarded to ``generate_step``.

    Returns:
        ``(text, n_prompt_tokens, n_generated_tokens)``; ``('', 0, 0)`` when
        *prompt* is None.
    """
    global dict_cache
    if prompt is None:
        return '', 0, 0
    if not isinstance(tokenizer, mlx_lm.tokenizer_utils.TokenizerWrapper):
        tokenizer = mlx_lm.tokenizer_utils.TokenizerWrapper(tokenizer)
    detokenizer = tokenizer.detokenizer
    save_at = -1
    if isinstance(prompt, str):
        add_special_tokens = tokenizer.bos_token is None or not prompt.startswith(tokenizer.bos_token)
        prompt_s = prompt
        prompt = tokenizer.encode(prompt, add_special_tokens=add_special_tokens)
        # encode() returns pref=-1 when there is no user turn; encoding that
        # would raise, so only a real prefix string yields a checkpoint position.
        if isinstance(pref, str):
            _pref = tokenizer.encode(pref, add_special_tokens=add_special_tokens)
            save_at = get_common_len(prompt, _pref)
    else:
        prompt_s = tokenizer.decode(prompt)
    stream_logger.debug(dmca(prompt_s))

    # Captured up front: load_dict_cache() rebinds dict_cache and the loaded
    # dict may not carry these keys.
    cache_dir = dict_cache['cache_dir']
    model_name = dict_cache['model_name']
    safe_name = "".join(c for c in model_name if c.isalnum())
    ckpt_path = Path(cache_dir) / f'{safe_name}_{save_at}_{hash_tokens(prompt[:save_at])}.safetensors'

    def _load_checkpoint():
        # Load the checkpoint, then restore the keys this function relies on.
        load_dict_cache(ckpt_path)
        dict_cache['cache_dir'] = cache_dir
        dict_cache['model_name'] = model_name

    def _checkpoint_saver():
        # hx stores one token past save_at because the cache covers hx[:-1].
        meta = dict(model_name=model_name, hx=json.dumps(prompt[:save_at + 1]))
        return functools.partial(save_dict_cache, ckpt_path, meta)

    text = ''
    gens = []
    common_len = 0
    hx_len = None
    trim_len = None
    save_fn = None
    if not dict_cache.get('cache'):
        trace_logger.debug(ckpt_path.resolve())
        if save_at > 0 and ckpt_path.exists():
            _load_checkpoint()
        else:
            dict_cache |= dict(cache=mlx_lm.models.cache.make_prompt_cache(model), hx=[])
            save_fn = _checkpoint_saver()

    if (hx := dict_cache.get('hx')):
        _hx = hx[:-1]
        common_len = get_common_len(prompt, _hx)
        hx_len = len(_hx)
        trim_len = hx_len - common_len
        if trim_len > 0:
            if all(c.is_trimmable() for c in dict_cache['cache']):
                mlx_lm.models.cache.trim_prompt_cache(dict_cache['cache'], trim_len)
            elif save_at > 0 and ckpt_path.exists():
                _load_checkpoint()
                common_len = save_at
            else:
                # The cache cannot be rolled back and there is no checkpoint:
                # continuing would run on stale state, so start over.
                dict_cache['cache'] = mlx_lm.models.cache.make_prompt_cache(model)
                common_len = 0
        if save_at > common_len and not all(c.is_trimmable() for c in dict_cache['cache']):
            save_fn = _checkpoint_saver()
        else:
            save_at = -1

    if common_len == len(prompt):
        # The whole prompt is already cached, so there is no prompt token left
        # to produce fresh logits. Replay the token that followed this prefix
        # last time and count it as generated. NOTE(review): this assumes that
        # token is what the model would emit again (holds for argmax sampling
        # when it was itself generated).
        _last_gen = dict_cache['hx'][common_len]
        prompt_arr = mx.array([_last_gen])
        gens.append(_last_gen)
        detokenizer.add_token(_last_gen)
    else:
        prompt_arr = mx.array(prompt[common_len:])

    trace_logger.debug(f'{save_at=} {common_len=}')
    token_gen = generate_step(
        prompt_arr,
        model,
        prompt_cache=dict_cache['cache'],
        max_tokens=max_tokens,
        save_at=save_at - common_len,
        save_fn=save_fn,
        **kwargs,
    )
    tic_non = time.perf_counter()
    tic_inp = None
    for token, _ in token_gen:
        if tic_inp is None:
            # First yielded token marks the end of prompt processing.
            tic_inp = time.perf_counter()
        gens.append(token)
        if token in tokenizer.eos_token_ids:
            break
        detokenizer.add_token(token)
        seg = detokenizer.last_segment
        stream_logger.debug(seg)
        text += seg
        if len(gens) >= max_tokens:
            break
    tic_out = time.perf_counter()
    if tic_inp is None:
        # Nothing was yielded; keep the timing log well-defined.
        tic_inp = tic_out
    detokenizer.finalize()
    text += detokenizer.last_segment
    dict_cache['hx'] = prompt + gens
    prefill_s = max(tic_inp - tic_non, 1e-9)
    decode_s = max(tic_out - tic_inp, 1e-9)
    trace_logger.debug(
        f'G {hx_len} {len(prompt)} {common_len} {trim_len} {len(gens)}\n=== TPS ===\n'
        f'- Processed {len(prompt)} input tokens in {prefill_s:.0f} seconds ({len(prompt)/prefill_s:.0f} tokens per second)\n'
        f'- Generated {len(gens)} new tokens in {decode_s:.0f} seconds ({len(gens)/decode_s:.0f} tokens per second)\n\n'
        f'=== INP ===\n{dmca(prompt_s)}\n=== OUT ===\n{text}'
    )
    return text, len(prompt), len(gens)
616
+
383
617
  if __name__ == "__main__":
384
618
  main()
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.4
2
2
  Name: mlx-code
3
- Version: 0.0.2a1
3
+ Version: 0.0.2a2
4
4
  Summary: Local Claude Code for Mac
5
5
  Home-page: https://github.com/JosefAlbers/mlx-code
6
6
  Author: J Joe
@@ -52,19 +52,7 @@ mlx-code [options] [-- claude options]
52
52
  | `--work` | `$CWD` | Working directory mirrored into the Claude session |
53
53
  | `--home` | temp dir | Home directory for the Claude process |
54
54
 
55
- Any extra arguments after `--` are forwarded to the `claude` CLI:
56
-
57
- | Command | What it does | Example |
58
- |--------|--------------|--------|
59
- | `mlx-code` | Start interactive mode | `mlx-code` |
60
- | `mlx-code "task"` | Run a one-time task | `mlx-code "fix the build error"` |
61
- | `mlx-code -p "query"` | Run one-off query, then exit | `mlx-code -p "explain this function"` |
62
- | `mlx-code -c` | Continue most recent conversation in current directory | `mlx-code -c` |
63
- | `mlx-code -r` | Resume a previous conversation | `mlx-code -r` |
64
- | `mlx-code commit` | Create a Git commit | `mlx-code commit` |
65
- | `/clear` | Clear conversation history | `/clear` |
66
- | `/help` | Show available commands | `/help` |
67
- | `exit` or `Ctrl+C` | Exit Claude Code | `exit` |
55
+ Any extra arguments after `--` are forwarded to the `claude` CLI.
68
56
 
69
57
  ### Licence
70
58
 
@@ -6,7 +6,7 @@ setup(
6
6
  author_email="albersj66@gmail.com",
7
7
  author="J Joe",
8
8
  license="Apache-2.0",
9
- version="0.0.2a1",
9
+ version="0.0.2a2",
10
10
  readme="README.md",
11
11
  description="Local Claude Code for Mac",
12
12
  long_description=open("README.md").read(),
File without changes
File without changes