escribano 0.4.5 → 0.5.0

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -1,41 +1,56 @@
1
1
  #!/usr/bin/env python3
2
2
  """
3
- MLX-VLM Bridge for Escribano
3
+ MLX Bridge for Escribano
4
4
 
5
- A Unix domain socket server that provides interleaved VLM batch processing.
5
+ A Unix domain socket server that provides VLM and/or LLM inference.
6
6
  Communicates with TypeScript via NDJSON (newline-delimited JSON).
7
7
 
8
8
  Usage:
9
- python3 scripts/mlx_bridge.py
9
+ python3 scripts/mlx_bridge.py --mode vlm # VLM-only (frame analysis)
10
+ python3 scripts/mlx_bridge.py --mode llm # LLM-only (text generation)
10
11
 
11
12
  Environment Variables:
12
- ESCRIBANO_VLM_MODEL - MLX model name (default: mlx-community/Qwen3-VL-2B-Instruct-4bit)
13
+ ESCRIBANO_VLM_MODEL - MLX VLM model name (default: mlx-community/Qwen3-VL-2B-Instruct-4bit)
13
14
  ESCRIBANO_VLM_BATCH_SIZE - Frames per batch (default: 2)
14
- ESCRIBANO_VLM_MAX_TOKENS - Token budget per batch (default: 2000)
15
+ ESCRIBANO_VLM_MAX_TOKENS - Token budget per batch (default: 4000)
15
16
  ESCRIBANO_MLX_SOCKET_PATH - Unix socket path (default: /tmp/escribano-mlx.sock)
16
17
  ESCRIBANO_VERBOSE - Enable verbose logging (default: false)
17
18
  """
18
19
 
20
+ import argparse
19
21
  import json
20
22
  import os
21
23
  import re
22
24
  import signal
23
25
  import socket
26
+ import sqlite3
24
27
  import sys
25
28
  import time
26
29
  from pathlib import Path
27
- from typing import Any
30
+ from typing import Any, Literal
28
31
 
29
32
  # Configuration from environment (all defaults come from TypeScript config.ts)
30
33
  MODEL_NAME = os.environ.get(
31
34
  "ESCRIBANO_VLM_MODEL", "mlx-community/Qwen3-VL-2B-Instruct-4bit"
32
35
  )
33
36
  BATCH_SIZE = int(os.environ.get("ESCRIBANO_VLM_BATCH_SIZE", "2"))
34
- MAX_TOKENS = int(os.environ.get("ESCRIBANO_VLM_MAX_TOKENS", "2000"))
37
+ MAX_TOKENS_VLM = int(os.environ.get("ESCRIBANO_VLM_MAX_TOKENS", "4000"))
38
+
35
39
  SOCKET_PATH = os.environ.get("ESCRIBANO_MLX_SOCKET_PATH", "/tmp/escribano-mlx.sock")
36
40
  VERBOSE = os.environ.get("ESCRIBANO_VERBOSE", "false").lower() == "true"
37
41
  TEMPERATURE = 0.3
38
42
 
43
+ # Debug logging configuration
44
+ DB_PATH = os.environ.get("ESCRIBANO_DB_PATH", "")
45
+ DEBUG_LLM = os.environ.get("ESCRIBANO_DEBUG_LLM", "false").lower() == "true"
46
+
47
+ # Bridge mode (set via --mode flag)
48
+ BridgeMode = Literal["vlm", "llm"]
49
+ BRIDGE_MODE: BridgeMode = "vlm"
50
+
51
+ # Shutdown flag for graceful exit
52
+ shutting_down = False
53
+
39
54
 
40
55
  def find_project_root() -> Path:
41
56
  """Find the project root by walking up from the script location."""
@@ -52,9 +67,12 @@ def load_vlm_prompt(batch_size: int) -> str:
52
67
  """Load and template the VLM prompt from prompts/vlm-batch.md."""
53
68
  project_root = find_project_root()
54
69
  prompt_file = project_root / "prompts" / "vlm-batch.md"
55
-
70
+
56
71
  if not prompt_file.exists():
57
- log(f"Warning: prompt file not found at {prompt_file}, using inline prompt", "info")
72
+ log(
73
+ f"Warning: prompt file not found at {prompt_file}, using inline prompt",
74
+ "info",
75
+ )
58
76
  # Fallback inline prompt (old behavior)
59
77
  return f"""Analyze these {batch_size} screenshots from a screen recording.
60
78
 
@@ -68,7 +86,7 @@ Output in this exact format for each frame:
68
86
  Frame 1: description: ... | activity: ... | apps: [...] | topics: [...]
69
87
  Frame 2: description: ... | activity: ... | apps: [...] | topics: [...]
70
88
  ...and so on for all {batch_size} frames."""
71
-
89
+
72
90
  try:
73
91
  content = prompt_file.read_text(encoding="utf-8")
74
92
  # Replace template variable
@@ -90,11 +108,16 @@ Frame 1: description: ... | activity: ... | apps: [...] | topics: [...]
90
108
  Frame 2: description: ... | activity: ... | apps: [...] | topics: [...]
91
109
  ...and so on for all {batch_size} frames."""
92
110
 
111
+
93
112
  # Global state
94
113
  model = None
95
114
  processor = None
96
115
  config = None
116
+ llm_model = None
117
+ llm_tokenizer = None
118
+ llm_loaded_model_name = None
97
119
  server_socket = None
120
+ debug_db_conn = None
98
121
 
99
122
 
100
123
  def log(message: str, level: str = "info") -> None:
@@ -107,9 +130,133 @@ def log(message: str, level: str = "info") -> None:
107
130
  print(f"{prefix} {message}", file=sys.stderr, flush=True)
108
131
 
109
132
 
133
def get_debug_db() -> sqlite3.Connection | None:
    """Return the shared debug-log SQLite connection, opening it on first use.

    Returns None when debug logging is disabled (DEBUG_LLM is false) or no
    DB_PATH is configured. Also returns None when opening the database fails;
    the failure is logged and the open is attempted again on the next call.
    """
    global debug_db_conn

    # Both the flag and a database path are required for debug logging.
    if not (DEBUG_LLM and DB_PATH):
        return None

    # Reuse the connection once it has been opened.
    if debug_db_conn is not None:
        return debug_db_conn

    try:
        debug_db_conn = sqlite3.connect(DB_PATH)
        log(f"Connected to debug database: {DB_PATH}", "debug")
    except Exception as err:
        log(f"Failed to connect to debug database: {err}", "error")
    return debug_db_conn
145
+
146
+
147
def log_llm_call(data: dict) -> None:
    """Log LLM call to debug table (best-effort).

    Inserts one row into ``llm_debug_log`` when get_debug_db() returns a
    connection; otherwise does nothing. Any error while building or writing
    the row is logged and swallowed, so debug logging never fails the request
    being served.

    Args:
        data: Row values. ``id`` is required. ``recording_id``,
            ``artifact_id``, ``prompt`` and ``result`` default to NULL.
            ``call_type`` defaults to ``"unknown"`` and ``metadata`` to ``{}``.
            ``metadata`` is stored as JSON. A ``prompt``/``result`` that is not
            a plain scalar (for example a chat ``messages`` list) is also
            stored as JSON text.
    """
    db = get_debug_db()
    if not db:
        return

    def to_sql_value(value: Any) -> Any:
        # sqlite3 can only bind None/int/float/str/bytes. The generate_text
        # handler passes the chat ``messages`` list as ``prompt`` when no
        # rawPrompt is given, which would otherwise make the insert fail.
        if value is None or isinstance(value, (str, bytes, int, float)):
            return value
        return json.dumps(value, ensure_ascii=False)

    try:
        db.execute(
            """
            INSERT INTO llm_debug_log (
                id, recording_id, artifact_id, call_type, prompt, result, metadata
            ) VALUES (?, ?, ?, ?, ?, ?, ?)
            """,
            (
                data["id"],
                data.get("recording_id"),
                data.get("artifact_id"),
                data.get("call_type", "unknown"),
                to_sql_value(data.get("prompt")),
                to_sql_value(data.get("result")),
                json.dumps(data.get("metadata", {})),
            ),
        )
        db.commit()
        log(f"Logged LLM call to debug table: {data['id']}", "debug")
    except Exception as e:
        log(f"Failed to log LLM call (non-fatal): {e}", "error")
175
+
176
+
177
def load_llm_model(model_name: str) -> tuple[Any, Any]:
    """Load an MLX text-only LLM and its tokenizer via ``mlx_lm.load``.

    Args:
        model_name: Model identifier handed unchanged to ``mlx_lm.load``.

    Returns:
        The ``(model, tokenizer)`` pair returned by ``mlx_lm.load``.

    Raises:
        ValueError: If ``model_name`` is empty, None or whitespace-only.
        ImportError: If ``mlx_lm`` cannot be imported. Installation hints are
            logged first.
        Exception: Any other failure while loading is logged and re-raised.
    """
    # The "load_llm" request handler forwards params.get("model", ""), so a
    # missing model name arrives here as "". Fail fast with an actionable
    # message instead of handing an empty identifier to mlx_lm.
    if not model_name or not model_name.strip():
        raise ValueError("No LLM model name provided (missing 'model' parameter)")

    log(f"Loading LLM model: {model_name}")
    log("This may take 30-120 seconds on first run or after memory clear...")
    start = time.time()

    try:
        log("Importing mlx_lm...", "debug")
        from mlx_lm import load
        import mlx_lm

        log("Loading model weights into memory (this takes the longest)...", "debug")
        model_obj, tokenizer_obj = load(model_name)

        duration = time.time() - start
        log(f"LLM model loaded in {duration:.1f}s")
        # getattr: a package build without __version__ must not turn a
        # successful load into a reported failure (the except below re-raises).
        log(f"mlx_lm version: {getattr(mlx_lm, '__version__', 'unknown')}")

        return model_obj, tokenizer_obj
    except ImportError as e:
        log(f"Failed to import mlx_lm: {e}", "error")
        log(f"Python used: {sys.executable}", "error")
        custom_python = os.environ.get("ESCRIBANO_PYTHON_PATH")
        if custom_python:
            # With a user-supplied interpreter, installing mlx-lm is the
            # user's responsibility; point them at the exact interpreter.
            log(
                "ESCRIBANO_PYTHON_PATH is set, so Escribano does not auto-install mlx-lm "
                "into this Python environment.",
                "error",
            )
            log(
                f"Make sure mlx-lm is installed for that Python "
                f"(e.g. `{custom_python} -m pip install mlx-lm`), "
                "or unset ESCRIBANO_PYTHON_PATH to let Escribano manage its own Python.",
                "error",
            )
        raise
    except Exception as e:
        log(f"Failed to load LLM model: {e}", "error")
        raise
219
+
220
+
221
def _release_mlx_memory() -> None:
    """Run a garbage-collection pass, then ask MLX to drop its buffer cache."""
    import gc

    import mlx.core as mx

    gc.collect()
    # Recent MLX releases expose clear_cache() on mlx.core and deprecate the
    # mx.metal variant; fall back to the latter on older releases.
    clear_cache = getattr(mx, "clear_cache", None) or mx.metal.clear_cache
    clear_cache()  # Apple Silicon memory cleanup


def unload_vlm() -> None:
    """Free VLM memory before loading LLM.

    Drops the module-level ``model``/``processor``/``config`` references, then
    releases MLX memory. Best-effort: errors are logged, not raised.
    """
    global model, processor, config
    log("Unloading VLM model to free memory", "debug")
    # Clear references before touching MLX so they are dropped even if
    # importing mlx or clearing its cache fails.
    model = None
    processor = None
    config = None
    try:
        _release_mlx_memory()
        log("VLM unloaded successfully", "debug")
    except Exception as e:
        log(f"Error unloading VLM: {e}", "error")


def unload_llm() -> None:
    """Free LLM memory after generation.

    Drops the module-level ``llm_model``/``llm_tokenizer`` references and
    forgets ``llm_loaded_model_name``, then releases MLX memory. Best-effort:
    errors are logged, not raised.
    """
    global llm_model, llm_tokenizer, llm_loaded_model_name
    log("Unloading LLM model to free memory", "debug")
    # Clear references before touching MLX so they are dropped even if
    # importing mlx or clearing its cache fails.
    llm_model = None
    llm_tokenizer = None
    llm_loaded_model_name = None
    try:
        _release_mlx_memory()
        log("LLM unloaded successfully", "debug")
    except Exception as e:
        log(f"Error unloading LLM: {e}", "error")
255
+
256
+
110
257
  def cleanup() -> None:
111
- """Clean up socket file on exit."""
112
- global server_socket
258
+ """Clean up socket file and debug database on exit."""
259
+ global server_socket, debug_db_conn
113
260
  if server_socket:
114
261
  try:
115
262
  server_socket.close()
@@ -121,11 +268,19 @@ def cleanup() -> None:
121
268
  log(f"Removed socket: {SOCKET_PATH}", "debug")
122
269
  except Exception as e:
123
270
  log(f"Failed to remove socket: {e}", "error")
271
+ if debug_db_conn:
272
+ try:
273
+ debug_db_conn.close()
274
+ except Exception:
275
+ pass
276
+ debug_db_conn = None
124
277
 
125
278
 
126
279
  def signal_handler(signum: int, frame: Any) -> None:
127
280
  """Handle shutdown signals."""
281
+ global shutting_down
128
282
  log(f"Received signal {signum}, shutting down...")
283
+ shutting_down = True
129
284
  cleanup()
130
285
  sys.exit(0)
131
286
 
@@ -143,7 +298,7 @@ def load_model() -> tuple[Any, Any, Any]:
143
298
 
144
299
  log("Loading model weights into memory (this takes the longest)...", "debug")
145
300
  model_obj, processor_obj = load(MODEL_NAME)
146
-
301
+
147
302
  log("Loading model config...", "debug")
148
303
  config_obj = load_config(MODEL_NAME)
149
304
 
@@ -189,7 +344,10 @@ def send_response(conn: socket.socket, obj: dict) -> None:
189
344
  try:
190
345
  data = json.dumps(obj) + "\n"
191
346
  conn.sendall(data.encode("utf-8"))
192
- log(f"Sent response: {obj.get('id', '?')} batch={obj.get('batch', '?')}", "debug")
347
+ log(
348
+ f"Sent response: {obj.get('id', '?')} batch={obj.get('batch', '?')}",
349
+ "debug",
350
+ )
193
351
  except Exception as e:
194
352
  log(f"Failed to send response: {e}", "error")
195
353
 
@@ -221,7 +379,9 @@ def parse_vlm_response(content: str) -> dict:
221
379
  result["description"] = match[1].strip()
222
380
  result["activity"] = match[2].strip()
223
381
  result["apps"] = list(set(s.strip() for s in apps_str.split(",") if s.strip()))
224
- result["topics"] = list(set(s.strip() for s in topics_str.split(",") if s.strip()))
382
+ result["topics"] = list(
383
+ set(s.strip() for s in topics_str.split(",") if s.strip())
384
+ )
225
385
  else:
226
386
  # Fallback: use content as description
227
387
  result["description"] = content.strip()
@@ -251,11 +411,13 @@ def process_interleaved_batch(
251
411
  timestamp = frame.get("timestamp", "unknown")
252
412
 
253
413
  # Add text label
254
- content.append({"type": "text", "text": f"Frame {frame_num} (timestamp: {timestamp}s):"})
414
+ content.append(
415
+ {"type": "text", "text": f"Frame {frame_num} (timestamp: {timestamp}s):"}
416
+ )
255
417
  # Add image placeholder
256
418
  content.append({"type": "image"})
257
419
 
258
- # Add final prompt with instructions (loaded from prompts/vlm-batch.md)
420
+ # Add final prompt with instructions (loaded from prompts/vlm-batch.md)
259
421
  final_prompt = load_vlm_prompt(len(batch))
260
422
  content.append({"type": "text", "text": final_prompt})
261
423
 
@@ -266,7 +428,7 @@ def process_interleaved_batch(
266
428
  prompt = get_chat_template(processor_obj, messages, add_generation_prompt=True)
267
429
 
268
430
  t_generate_start = time.time()
269
-
431
+
270
432
  # Generate with multiple images
271
433
  output = generate(
272
434
  model_obj,
@@ -274,7 +436,7 @@ def process_interleaved_batch(
274
436
  prompt,
275
437
  image=[f["imagePath"] for f in batch],
276
438
  temperature=TEMPERATURE,
277
- max_tokens=MAX_TOKENS,
439
+ max_tokens=MAX_TOKENS_VLM,
278
440
  verbose=VERBOSE,
279
441
  )
280
442
 
@@ -291,7 +453,7 @@ def process_interleaved_batch(
291
453
 
292
454
  # Parse results for each frame
293
455
  results = parse_interleaved_output(content_text, batch)
294
-
456
+
295
457
  t_parse_end = time.time()
296
458
  parse_time = t_parse_end - t_generate_end
297
459
  total_time = t_parse_end - t_batch_start
@@ -311,11 +473,28 @@ def process_interleaved_batch(
311
473
 
312
474
  # Log detailed stats if verbose
313
475
  if VERBOSE:
314
- log(f" Prompt: {stats['prompt_tokens']} tokens @ {stats['prompt_tps']:.1f} tok/s", "debug")
315
- log(f" Gen: {stats['generation_tokens']} tokens @ {stats['generation_tps']:.1f} tok/s", "debug")
316
- prefill_s = stats['prompt_tokens'] / stats['prompt_tps'] if stats['prompt_tps'] > 0 else 0
317
- gen_s = stats['generation_tokens'] / stats['generation_tps'] if stats['generation_tps'] > 0 else 0
318
- log(f" Time: {generate_time:.2f}s (prefill: {prefill_s:.2f}s, gen: {gen_s:.2f}s)", "debug")
476
+ log(
477
+ f" Prompt: {stats['prompt_tokens']} tokens @ {stats['prompt_tps']:.1f} tok/s",
478
+ "debug",
479
+ )
480
+ log(
481
+ f" Gen: {stats['generation_tokens']} tokens @ {stats['generation_tps']:.1f} tok/s",
482
+ "debug",
483
+ )
484
+ prefill_s = (
485
+ stats["prompt_tokens"] / stats["prompt_tps"]
486
+ if stats["prompt_tps"] > 0
487
+ else 0
488
+ )
489
+ gen_s = (
490
+ stats["generation_tokens"] / stats["generation_tps"]
491
+ if stats["generation_tps"] > 0
492
+ else 0
493
+ )
494
+ log(
495
+ f" Time: {generate_time:.2f}s (prefill: {prefill_s:.2f}s, gen: {gen_s:.2f}s)",
496
+ "debug",
497
+ )
319
498
  log(f" Peak memory: {stats['peak_memory_gb']:.2f} GB", "debug")
320
499
  log(f" Batch total: {total_time:.2f}s", "debug")
321
500
 
@@ -337,40 +516,77 @@ def parse_interleaved_output(text: str, batch: list[dict]) -> list[dict]:
337
516
  apps_str = re.sub(r"^\[|\]$", "", match[3].strip())
338
517
  topics_str = re.sub(r"^\[|\]$", "", match[4].strip())
339
518
 
340
- results.append({
341
- "index": frame.get("index", frame_num - 1),
342
- "timestamp": frame["timestamp"],
343
- "imagePath": frame["imagePath"],
344
- "description": match[1].strip(),
345
- "activity": match[2].strip(),
346
- "apps": [s.strip() for s in apps_str.split(",") if s.strip()],
347
- "topics": [s.strip() for s in topics_str.split(",") if s.strip()],
348
- })
519
+ results.append(
520
+ {
521
+ "index": frame.get("index", frame_num - 1),
522
+ "timestamp": frame["timestamp"],
523
+ "imagePath": frame["imagePath"],
524
+ "description": match[1].strip(),
525
+ "activity": match[2].strip(),
526
+ "apps": [s.strip() for s in apps_str.split(",") if s.strip()],
527
+ "topics": [s.strip() for s in topics_str.split(",") if s.strip()],
528
+ }
529
+ )
349
530
  else:
350
- results.append({
351
- "index": frame.get("index", frame_num - 1),
352
- "timestamp": frame["timestamp"],
353
- "imagePath": frame["imagePath"],
354
- "description": f"Failed to parse Frame {frame_num}",
355
- "activity": "unknown",
356
- "apps": [],
357
- "topics": [],
358
- "raw_response": text,
359
- })
531
+ results.append(
532
+ {
533
+ "index": frame.get("index", frame_num - 1),
534
+ "timestamp": frame["timestamp"],
535
+ "imagePath": frame["imagePath"],
536
+ "description": f"Failed to parse Frame {frame_num}",
537
+ "activity": "unknown",
538
+ "apps": [],
539
+ "topics": [],
540
+ "raw_response": text,
541
+ }
542
+ )
360
543
 
361
544
  return results
362
545
 
363
546
 
547
def strip_thinking_tags(text: str) -> str:
    """Return *text* with model "thinking" output removed and whitespace trimmed.

    Two shapes are handled:
      1. Complete ``<think>...</think>`` pairs, which are deleted wherever they
         appear (across newlines).
      2. Qwen3.5-style output: thinking text followed by a lone ``</think>``
         with no opening tag. Everything up to and including the first such
         closing tag is discarded,
         e.g. ``"Let me analyze...\\n</think>\\n# Actual answer"`` -> ``"# Actual answer"``.
    """
    remaining = re.sub(r"<think>.*?</think>", "", text, flags=re.DOTALL)
    _, closing_tag, answer = remaining.partition("</think>")
    return (answer if closing_tag else remaining).strip()
563
+
564
+
364
565
  def handle_describe_images(
365
- conn: socket.socket, model_obj: Any, processor_obj: Any, config_obj: Any, params: dict, request_id: int
566
+ conn: socket.socket,
567
+ model_obj: Any,
568
+ processor_obj: Any,
569
+ config_obj: Any,
570
+ params: dict,
571
+ request_id: int,
366
572
  ) -> None:
367
573
  """Handle describe_images request with streaming batch responses."""
574
+ global model, processor, config
575
+
576
+ # Reload model if it was unloaded (lazy reload after unload_vlm)
577
+ if model_obj is None:
578
+ log("VLM model was unloaded, reloading...")
579
+ model, processor, config = load_model()
580
+ model_obj, processor_obj, config_obj = model, processor, config
581
+
368
582
  images = params.get("images", [])
369
583
  batch_size = params.get("batchSize", BATCH_SIZE)
370
584
  total = len(images)
371
585
 
372
586
  if total == 0:
373
- send_response(conn, {"id": request_id, "error": "No images provided", "done": True})
587
+ send_response(
588
+ conn, {"id": request_id, "error": "No images provided", "done": True}
589
+ )
374
590
  return
375
591
 
376
592
  log(f"Processing {total} images in batches of {batch_size}")
@@ -386,9 +602,13 @@ def handle_describe_images(
386
602
  batch_num = batch_idx // batch_size + 1
387
603
 
388
604
  try:
389
- log(f"Processing batch {batch_num}: frames {batch_idx + 1}-{min(batch_idx + batch_size, total)}")
605
+ log(
606
+ f"Processing batch {batch_num}: frames {batch_idx + 1}-{min(batch_idx + batch_size, total)}"
607
+ )
390
608
 
391
- results, stats = process_interleaved_batch(model_obj, processor_obj, config_obj, batch)
609
+ results, stats = process_interleaved_batch(
610
+ model_obj, processor_obj, config_obj, batch
611
+ )
392
612
 
393
613
  # Accumulate stats
394
614
  total_prompt_tokens += stats.get("prompt_tokens", 0)
@@ -397,30 +617,44 @@ def handle_describe_images(
397
617
 
398
618
  # Stream response immediately
399
619
  is_partial = batch_idx + batch_size < total
400
- send_response(conn, {
401
- "id": request_id,
402
- "batch": batch_num,
403
- "results": results,
404
- "stats": stats,
405
- "partial": is_partial,
406
- "progress": {"current": batch_idx + len(batch), "total": total},
407
- })
620
+ send_response(
621
+ conn,
622
+ {
623
+ "id": request_id,
624
+ "batch": batch_num,
625
+ "results": results,
626
+ "stats": stats,
627
+ "partial": is_partial,
628
+ "progress": {"current": batch_idx + len(batch), "total": total},
629
+ },
630
+ )
408
631
 
409
632
  except Exception as e:
410
633
  log(f"Batch {batch_num} failed: {e}", "error")
411
- send_response(conn, {
412
- "id": request_id,
413
- "batch": batch_num,
414
- "error": str(e),
415
- "partial": batch_idx + batch_size < total,
416
- "progress": {"current": batch_idx + len(batch), "total": total},
417
- })
634
+ send_response(
635
+ conn,
636
+ {
637
+ "id": request_id,
638
+ "batch": batch_num,
639
+ "error": str(e),
640
+ "partial": batch_idx + batch_size < total,
641
+ "progress": {"current": batch_idx + len(batch), "total": total},
642
+ },
643
+ )
418
644
 
419
645
  # Log summary stats
420
646
  if total_generate_time > 0:
421
- avg_prompt_tps = total_prompt_tokens / (total_prompt_tokens / 2000) if total_prompt_tokens > 0 else 0
422
- avg_gen_tps = total_gen_tokens / total_generate_time if total_generate_time > 0 else 0
423
- log(f"Total: {total_prompt_tokens} prompt tokens, {total_gen_tokens} gen tokens in {total_generate_time:.1f}s")
647
+ avg_prompt_tps = (
648
+ total_prompt_tokens / (total_prompt_tokens / 2000)
649
+ if total_prompt_tokens > 0
650
+ else 0
651
+ )
652
+ avg_gen_tps = (
653
+ total_gen_tokens / total_generate_time if total_generate_time > 0 else 0
654
+ )
655
+ log(
656
+ f"Total: {total_prompt_tokens} prompt tokens, {total_gen_tokens} gen tokens in {total_generate_time:.1f}s"
657
+ )
424
658
 
425
659
  # Final done signal
426
660
  send_response(conn, {"id": request_id, "done": True})
@@ -438,15 +672,240 @@ def handle_request(
438
672
 
439
673
  log(f"Received request: id={request_id} method={method}", "debug")
440
674
 
675
+ # Validate method compatibility with bridge mode
676
+ if BRIDGE_MODE == "llm" and method == "describe_images":
677
+ send_response(
678
+ conn,
679
+ {
680
+ "id": request_id,
681
+ "error": "describe_images not available in LLM-only mode",
682
+ "done": True,
683
+ },
684
+ )
685
+ return
686
+
687
+ if BRIDGE_MODE == "vlm" and method == "generate_text":
688
+ send_response(
689
+ conn,
690
+ {
691
+ "id": request_id,
692
+ "error": "generate_text not available in VLM-only mode",
693
+ "done": True,
694
+ },
695
+ )
696
+ return
697
+
441
698
  if method == "describe_images":
442
- handle_describe_images(conn, model_obj, processor_obj, config_obj, params, request_id)
699
+ handle_describe_images(
700
+ conn, model_obj, processor_obj, config_obj, params, request_id
701
+ )
702
+ elif method == "load_llm":
703
+ global llm_model, llm_tokenizer, llm_loaded_model_name
704
+ try:
705
+ llm_model, llm_tokenizer = load_llm_model(params.get("model", ""))
706
+ llm_loaded_model_name = params.get("model", "")
707
+ send_response(
708
+ conn, {"id": request_id, "status": "loaded", "done": True}
709
+ )
710
+ except Exception as e:
711
+ send_response(conn, {"id": request_id, "error": str(e), "done": True})
712
+ elif method == "unload_vlm":
713
+ try:
714
+ unload_vlm()
715
+ send_response(
716
+ conn, {"id": request_id, "status": "unloaded", "done": True}
717
+ )
718
+ except Exception as e:
719
+ send_response(conn, {"id": request_id, "error": str(e), "done": True})
720
+ elif method == "unload_llm":
721
+ try:
722
+ unload_llm()
723
+ send_response(
724
+ conn, {"id": request_id, "status": "unloaded", "done": True}
725
+ )
726
+ except Exception as e:
727
+ send_response(conn, {"id": request_id, "error": str(e), "done": True})
728
+ elif method == "generate_text":
729
+ if llm_model is None or llm_tokenizer is None:
730
+ send_response(
731
+ conn,
732
+ {"id": request_id, "error": "LLM model not loaded", "done": True},
733
+ )
734
+ else:
735
+ try:
736
+ from mlx_lm import generate
737
+ from mlx_lm.sample_utils import make_sampler
738
+
739
+ messages = params.get("messages", [])
740
+ raw_prompt = params.get("rawPrompt")
741
+ max_tokens = params.get("maxTokens", 8000)
742
+ think = params.get("think", False)
743
+ temperature = params.get("temperature", 0.7)
744
+
745
+ # Determine prompt source and apply chat template
746
+ if raw_prompt:
747
+ # Apply chat template to raw prompt
748
+ chat_messages = [{"role": "user", "content": raw_prompt}]
749
+ prompt = llm_tokenizer.apply_chat_template(
750
+ chat_messages,
751
+ tokenize=False,
752
+ add_generation_prompt=True,
753
+ chat_template_kwargs={"enable_thinking": think},
754
+ )
755
+ log(
756
+ f"Applied chat template to raw prompt (think={think}, temp={temperature})",
757
+ "debug",
758
+ )
759
+ elif messages:
760
+ # Apply chat template to messages array
761
+ prompt = llm_tokenizer.apply_chat_template(
762
+ messages,
763
+ tokenize=False,
764
+ add_generation_prompt=True,
765
+ chat_template_kwargs={"enable_thinking": think},
766
+ )
767
+ log(
768
+ f"Applied chat template to messages (think={think}, temp={temperature})",
769
+ "debug",
770
+ )
771
+ else:
772
+ send_response(
773
+ conn,
774
+ {
775
+ "id": request_id,
776
+ "error": "No prompt provided (need 'rawPrompt' or 'messages')",
777
+ "done": True,
778
+ },
779
+ )
780
+ return
781
+
782
+ if not prompt:
783
+ send_response(
784
+ conn,
785
+ {
786
+ "id": request_id,
787
+ "error": "Empty prompt after template",
788
+ "done": True,
789
+ },
790
+ )
791
+ return
792
+
793
+ log(
794
+ f"Generating text: max_tokens={max_tokens}, think={think}, temp={temperature}",
795
+ "debug",
796
+ )
797
+ log(f"Prompt length: {len(prompt)} chars", "debug")
798
+ t_start = time.time()
799
+
800
+ # Create sampler with temperature (mlx_lm 0.30.7+ API)
801
+ sampler = make_sampler(temp=temperature)
802
+
803
+ output = generate(
804
+ llm_model,
805
+ llm_tokenizer,
806
+ prompt=prompt,
807
+ max_tokens=max_tokens,
808
+ sampler=sampler,
809
+ verbose=VERBOSE,
810
+ )
811
+
812
+ if hasattr(output, "text"):
813
+ response_text = output.text
814
+ elif isinstance(output, str):
815
+ response_text = output
816
+ else:
817
+ response_text = str(output)
818
+
819
+ # Store raw response for debug logging
820
+ raw_response_text = response_text
821
+
822
+ # Strip thinking tags when think=False (model may still output thinking)
823
+ if not think:
824
+ original_len = len(response_text)
825
+ response_text = strip_thinking_tags(response_text)
826
+ if original_len != len(response_text):
827
+ log(
828
+ f"Stripped thinking: {original_len} → {len(response_text)} chars",
829
+ "debug",
830
+ )
831
+
832
+ t_end = time.time()
833
+ generate_time = t_end - t_start
834
+
835
+ log(f"Generation completed in {generate_time:.2f}s", "debug")
836
+
837
+ # Log to debug table if enabled
838
+ if DEBUG_LLM:
839
+ debug_context = params.get("debugContext", {})
840
+ log_llm_call(
841
+ {
842
+ "id": str(request_id),
843
+ "recording_id": debug_context.get("recordingId"),
844
+ "artifact_id": debug_context.get("artifactId"),
845
+ "call_type": debug_context.get("callType", "unknown"),
846
+ "prompt": raw_prompt
847
+ or (messages if messages else None),
848
+ "result": response_text,
849
+ "metadata": {
850
+ "model": llm_loaded_model_name or "unknown",
851
+ "think_param": 1 if think else 0,
852
+ "temperature": temperature,
853
+ "max_tokens": max_tokens,
854
+ "prompt_after_template": prompt[:500] + "..."
855
+ if len(prompt) > 500
856
+ else prompt,
857
+ "chat_template_kwargs": {"enable_thinking": think},
858
+ "raw_response": raw_response_text,
859
+ "prompt_tokens": getattr(
860
+ output, "prompt_tokens", 0
861
+ ),
862
+ "generation_tokens": getattr(
863
+ output, "generation_tokens", 0
864
+ ),
865
+ "generation_tps": getattr(
866
+ output, "generation_tps", 0.0
867
+ ),
868
+ "generate_time_s": generate_time,
869
+ },
870
+ }
871
+ )
872
+
873
+ send_response(
874
+ conn,
875
+ {
876
+ "id": request_id,
877
+ "text": response_text,
878
+ "stats": {
879
+ "prompt_tokens": getattr(output, "prompt_tokens", 0),
880
+ "generation_tokens": getattr(
881
+ output, "generation_tokens", 0
882
+ ),
883
+ "total_tokens": getattr(output, "total_tokens", 0),
884
+ "generation_tps": getattr(
885
+ output, "generation_tps", 0.0
886
+ ),
887
+ "generate_time_s": generate_time,
888
+ },
889
+ "done": True,
890
+ },
891
+ )
892
+
893
+ except Exception as e:
894
+ log(f"Text generation failed: {e}", "error")
895
+ send_response(
896
+ conn, {"id": request_id, "error": str(e), "done": True}
897
+ )
443
898
  elif method == "shutdown":
899
+ global shutting_down
444
900
  log("Shutdown requested")
445
901
  send_response(conn, {"id": request_id, "status": "shutting_down"})
902
+ shutting_down = True
446
903
  cleanup()
447
904
  sys.exit(0)
448
905
  else:
449
- send_response(conn, {"id": request_id, "error": f"Unknown method: {method}"})
906
+ send_response(
907
+ conn, {"id": request_id, "error": f"Unknown method: {method}"}
908
+ )
450
909
 
451
910
  except json.JSONDecodeError as e:
452
911
  log(f"Invalid JSON: {e}", "error")
@@ -458,7 +917,37 @@ def handle_request(
458
917
 
459
918
  def main() -> None:
460
919
  """Main entry point."""
461
- global model, processor, config, server_socket
920
+ global \
921
+ model, \
922
+ processor, \
923
+ config, \
924
+ server_socket, \
925
+ BRIDGE_MODE, \
926
+ SOCKET_PATH, \
927
+ shutting_down
928
+
929
+ # Log debug configuration at startup
930
+ if DEBUG_LLM:
931
+ log(f"Debug logging enabled (DB_PATH={DB_PATH})")
932
+
933
+ # Parse command-line arguments
934
+ parser = argparse.ArgumentParser(description="MLX Bridge for Escribano")
935
+ parser.add_argument(
936
+ "--mode",
937
+ type=str,
938
+ choices=["vlm", "llm"],
939
+ default="vlm",
940
+ help="Bridge mode: 'vlm' for frame analysis, 'llm' for text generation",
941
+ )
942
+ args = parser.parse_args()
943
+ BRIDGE_MODE = args.mode
944
+
945
+ # Adjust socket path based on mode (VLM and LLM use separate sockets)
946
+ base_socket = SOCKET_PATH.replace(".sock", "")
947
+ if BRIDGE_MODE == "llm":
948
+ SOCKET_PATH = f"{base_socket}-llm.sock"
949
+ else:
950
+ SOCKET_PATH = f"{base_socket}-vlm.sock"
462
951
 
463
952
  # Set up signal handlers
464
953
  signal.signal(signal.SIGTERM, signal_handler)
@@ -467,21 +956,30 @@ def main() -> None:
467
956
  # Clean up any existing socket
468
957
  cleanup()
469
958
 
470
- # Load model
471
- model, processor, config = load_model()
959
+ # Load model based on mode
960
+ if BRIDGE_MODE == "vlm":
961
+ model, processor, config = load_model()
962
+ else:
963
+ # LLM mode: load model lazily on first request
964
+ log("LLM-only mode: model will be loaded on first request")
472
965
 
473
966
  # Create socket
474
967
  server_socket = socket.socket(socket.AF_UNIX, socket.SOCK_STREAM)
475
968
  server_socket.bind(SOCKET_PATH)
476
969
  server_socket.listen(1)
477
970
 
478
- log(f"Listening on {SOCKET_PATH}")
971
+ log(f"Listening on {SOCKET_PATH} (mode: {BRIDGE_MODE})")
479
972
 
480
973
  # Signal ready (for parent process to detect)
481
- print(json.dumps({"status": "ready", "model": MODEL_NAME}), flush=True)
974
+ ready_msg = {
975
+ "status": "ready",
976
+ "model": MODEL_NAME if BRIDGE_MODE == "vlm" else "llm-lazy",
977
+ "mode": BRIDGE_MODE,
978
+ }
979
+ print(json.dumps(ready_msg), flush=True)
482
980
 
483
981
  # Accept connections
484
- while True:
982
+ while not shutting_down:
485
983
  try:
486
984
  conn, _ = server_socket.accept()
487
985
  log("Client connected", "debug")
@@ -512,6 +1010,8 @@ def main() -> None:
512
1010
  log("Client disconnected", "debug")
513
1011
 
514
1012
  except Exception as e:
1013
+ if shutting_down:
1014
+ break
515
1015
  log(f"Accept error: {e}", "error")
516
1016
  continue
517
1017