@zuvia-software-solutions/code-mapper 2.4.0 → 2.5.0

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -1,604 +0,0 @@
1
- #!/usr/bin/env python3
2
- """
3
- MLX-accelerated code embedder for Apple Silicon.
4
-
5
- TWO MODES:
6
- 1. Batch mode (main use): reads nodes directly from SQLite, embeds, writes back.
7
- No IPC overhead — everything happens in one process.
8
- Usage: python3 mlx-embedder.py batch <db_path> [--dims 256] [--max-tokens 2048]
9
-
10
- 2. Interactive mode (for MCP query embedding): reads JSON from stdin.
11
- Usage: python3 mlx-embedder.py [interactive]
12
-
13
- Model: Jina Embeddings v5 Text Small Retrieval (677M params, Qwen3-0.6B backbone)
14
- Optimized with int4 quantization (Linear) + int6 quantization (Embedding).
15
- """
16
-
17
- import sys
18
- import os
19
- import json
20
- import time
21
- import struct
22
- import hashlib
23
-
24
- os.environ["TOKENIZERS_PARALLELISM"] = "false"
25
-
26
- import mlx.core as mx
27
- import mlx.nn as nn
28
- from tokenizers import Tokenizer
29
-
30
- MODEL_DIR = os.path.dirname(os.path.abspath(__file__)) + "/jina-code-0.5b-mlx"
31
-
32
-
33
-
34
-
35
-
36
-
37
def ensure_model_downloaded():
    """Fetch tokenizer files and model weights from HuggingFace on first run.

    No-op when model.safetensors is already present in MODEL_DIR. Weights are
    converted from bf16 to fp16 and every key is prefixed with 'model.' before
    being written back out locally. Progress is reported as JSON lines on stdout.

    Raises:
        RuntimeError: if huggingface_hub is not installed.
    """
    weights_path = os.path.join(MODEL_DIR, "model.safetensors")
    if os.path.exists(weights_path):
        return  # weights already downloaded and converted

    os.makedirs(MODEL_DIR, exist_ok=True)
    print(json.dumps({"phase": "downloading", "message": "Downloading Jina Code 0.5B embedding model (~940MB, first time only)..."}), flush=True)

    try:
        import shutil

        from huggingface_hub import hf_hub_download

        repo = "jinaai/jina-code-embeddings-0.5b"
        # The v5 MLX repo ships the same Qwen tokenizer *with* tokenizer.json;
        # vocab/merges come from the 0.5B code model itself.
        tokenizer_repo = "jinaai/jina-embeddings-v5-text-small-retrieval-mlx"

        def fetch_missing(source_repo, filenames):
            # Copy each listed file into MODEL_DIR unless it is already there.
            for fname in filenames:
                dest = os.path.join(MODEL_DIR, fname)
                if not os.path.exists(dest):
                    shutil.copy(hf_hub_download(source_repo, fname), dest)

        fetch_missing(tokenizer_repo, ["tokenizer.json"])
        fetch_missing(repo, ["vocab.json", "merges.txt"])

        # Convert weights: bf16 → fp16, and add the 'model.' key prefix the
        # local model wrapper expects.
        print(json.dumps({"phase": "converting", "message": "Converting model weights (bf16 → fp16)..."}), flush=True)
        raw_weights = mx.load(hf_hub_download(repo, "model.safetensors"))
        converted = {}
        for key, tensor in raw_weights.items():
            if tensor.dtype == mx.bfloat16:
                tensor = tensor.astype(mx.float16)
            converted["model." + key] = tensor
        mx.save_safetensors(weights_path, converted)

        print(json.dumps({"phase": "downloaded", "message": f"Model ready ({len(converted)} weights converted)"}), flush=True)

    except ImportError:
        raise RuntimeError(
            "Model weights not found. Install huggingface_hub to auto-download:\n"
            " pip3 install huggingface_hub\n"
            "Or manually download from: https://huggingface.co/jinaai/jina-code-embeddings-0.5b"
        )
87
-
88
-
89
def load_model():
    """Load the embedding model and tokenizer, quantized for speed.

    Auto-downloads weights on first use via ensure_model_downloaded().

    Returns:
        (model, tokenizer): the quantized MLX model and the HuggingFace Rust
        tokenizer loaded from MODEL_DIR/tokenizer.json.

    Side effects:
        Prepends MODEL_DIR to sys.path so the model.py bundled next to the
        weights can be imported as module "model".
    """
    ensure_model_downloaded()

    # The model architecture ships alongside the weights as MODEL_DIR/model.py.
    sys.path.insert(0, MODEL_DIR)
    import importlib
    model_module = importlib.import_module("model")
    # Support both model class names (v5 = JinaEmbeddingModel, code-0.5b = JinaCodeEmbeddingModel)
    JinaEmbeddingModel = getattr(model_module, "JinaEmbeddingModel", None) or getattr(model_module, "JinaCodeEmbeddingModel")

    with open(os.path.join(MODEL_DIR, "config.json")) as f:
        config = json.load(f)

    model = JinaEmbeddingModel(config)
    weights = mx.load(os.path.join(MODEL_DIR, "model.safetensors"))
    model.load_weights(list(weights.items()))

    # Quantize int4 for Linear layers and int6 for Embedding modules
    # (matches the precisions advertised in the module docstring).
    nn.quantize(model.model, bits=4, group_size=64,
                class_predicate=lambda _, m: isinstance(m, nn.Linear))
    nn.quantize(model.model, bits=6, group_size=64,
                class_predicate=lambda _, m: isinstance(m, nn.Embedding))
    # Materialize quantized parameters now instead of lazily on first forward pass.
    mx.eval(model.parameters())

    tokenizer = Tokenizer.from_file(os.path.join(MODEL_DIR, "tokenizer.json"))
    return model, tokenizer
114
-
115
-
116
def get_batch_size_for_tokens(token_count):
    """Return the batch size to use for sequences of the given token length.

    Shorter sequences get larger batches; the mapping halves the batch size
    each time the token budget doubles.
    """
    # (inclusive upper token bound, batch size), checked in ascending order.
    tiers = (
        (64, 256),
        (128, 128),
        (256, 64),
        (512, 32),
        (1024, 16),
    )
    for bound, batch in tiers:
        if token_count <= bound:
            return batch
    return 8  # anything longer than 1024 tokens
124
-
125
-
126
def embed_tiered(model, tokenizer, texts, task_type="retrieval.passage", truncate_dim=256, max_tokens=2048):
    """Embed texts with token-aware batching. Tokenizes first, batches by token count.
    Returns embeddings in the ORIGINAL input order.

    Args:
        model: MLX embedding model; called as model(input_ids, attention_mask).
        tokenizer: HuggingFace Rust tokenizer (tokenizers.Tokenizer).
        texts: list of strings to embed; may be empty.
        task_type: "retrieval.query" or "retrieval.passage" — selects the
            instruction prefix prepended to each text.
        truncate_dim: keep only the first truncate_dim dimensions of each vector
            (assumes the model front-loads information in early dims — TODO confirm).
        max_tokens: hard cap on tokens per text; longer encodings are truncated.

    Returns:
        List of L2-normalized embedding vectors (plain Python float lists),
        index-aligned with `texts`.
    """
    if not texts:
        return []

    # Add task prefix — auto-detect based on model type
    # v5 (Qwen3): "Query: " / "Document: "
    # code-0.5b (Qwen2): "Find the most relevant code snippet...\n" / "Candidate code snippet:\n"
    is_code_model = "jina-code" in MODEL_DIR
    if is_code_model:
        prefix_map = {
            "retrieval.query": "Find the most relevant code snippet given the following query:\n",
            "retrieval.passage": "Candidate code snippet:\n",
        }
    else:
        prefix_map = {"retrieval.query": "Query: ", "retrieval.passage": "Document: "}
    prefix = prefix_map.get(task_type, "")
    prefixed = [prefix + t for t in texts] if prefix else texts

    # Tokenize everything in one call (fast — Rust HF tokenizer)
    encodings = tokenizer.encode_batch(prefixed)

    # Sort by token length for minimal padding
    indexed = sorted(range(len(texts)), key=lambda i: len(encodings[i].ids))

    all_embeddings = [None] * len(texts)
    i = 0

    while i < len(indexed):
        # Size the batch from a peeked token count. NOTE(review): this peeks at
        # indexed[i+1] rather than indexed[i]; with the ascending sort the peeked
        # count is >= the first item's, so the chosen batch is equal or smaller
        # (safe) — confirm whether the off-by-one is intentional.
        peek_idx = indexed[min(i + 1, len(indexed) - 1)]
        tok_count = min(len(encodings[peek_idx].ids), max_tokens)
        batch_size = get_batch_size_for_tokens(tok_count)

        batch_indices = []
        batch_encs = []
        while len(batch_encs) < batch_size and i < len(indexed):
            orig_idx = indexed[i]
            batch_indices.append(orig_idx)
            batch_encs.append(encodings[orig_idx])
            i += 1

        # Right-pad every sequence to the batch's longest (capped) length;
        # pad token id 0 with attention mask 0.
        max_len = min(max_tokens, max(len(e.ids) for e in batch_encs))
        input_ids = []
        attention_mask = []
        for enc in batch_encs:
            ids = enc.ids[:max_len]
            mask = enc.attention_mask[:max_len]
            pad = max_len - len(ids)
            if pad > 0:
                ids = ids + [0] * pad
                mask = mask + [0] * pad
            input_ids.append(ids)
            attention_mask.append(mask)

        embs = model(mx.array(input_ids), mx.array(attention_mask))
        if truncate_dim and truncate_dim < embs.shape[1]:
            embs = embs[:, :truncate_dim]
        # L2-normalize rows. NOTE(review): an all-zero embedding would divide by
        # zero here — presumably the model never produces one; verify.
        norms = mx.linalg.norm(embs, axis=1, keepdims=True)
        embs = embs / norms
        mx.eval(embs)

        # Scatter batch results back to the caller's original ordering.
        emb_list = embs.tolist()
        for j, orig_idx in enumerate(batch_indices):
            all_embeddings[orig_idx] = emb_list[j]

    return all_embeddings
193
-
194
-
195
def float_list_to_blob(floats):
    """Serialize a list of floats as packed native float32 bytes
    (readable on the JS side as a Float32Array over the blob)."""
    packer = struct.Struct(f"{len(floats)}f")
    return packer.pack(*floats)
198
-
199
-
200
def md5(text):
    """Hex MD5 digest of *text* (UTF-8); used as a cheap content fingerprint,
    not for security."""
    digest = hashlib.md5(text.encode("utf-8"))
    return digest.hexdigest()
202
-
203
-
204
- # =========================================================================
205
- # BATCH MODE — read from SQLite, embed, write back. Zero IPC.
206
- # =========================================================================
207
-
208
def batch_mode(db_path, dims=256, max_tokens=2048):
    """Embed all embeddable nodes in a code-graph SQLite DB and write vectors back.

    Reads nodes directly from the DB, builds a semantic summary text per node
    (name + first comment + signature + graph context), embeds the unique texts
    with the MLX model, and writes float32 blobs into the `embeddings` table.
    Progress is emitted as JSON lines on stdout.

    Args:
        db_path: path to the SQLite database (nodes/edges/embeddings tables).
        dims: truncate each embedding vector to this many leading dimensions.
        max_tokens: hard cap on tokens per embedding text.

    Fixes vs. previous revision: removed the dead locals `id_set` and `ph`
    (import-map loop), and hoisted the per-row `set(chunk_ids)` rebuild out of
    the file-path collection (was accidental O(rows × chunk) work per chunk).
    """
    import sqlite3

    t0_total = time.time()

    # Load model
    print(json.dumps({"phase": "loading", "message": "Loading MLX model..."}), flush=True)
    model, tokenizer = load_model()
    load_ms = int((time.time() - t0_total) * 1000)
    print(json.dumps({"phase": "loaded", "load_ms": load_ms, "device": str(mx.default_device())}), flush=True)

    # Open database
    db = sqlite3.connect(db_path)
    db.execute("PRAGMA journal_mode=WAL")
    db.execute("PRAGMA synchronous=NORMAL")

    # Ensure textHash column exists (migration)
    try:
        db.execute("SELECT textHash FROM embeddings LIMIT 0")
    except sqlite3.OperationalError:
        db.execute("ALTER TABLE embeddings ADD COLUMN textHash TEXT")

    # Query embeddable nodes — skip test/fixture files (BM25 covers them)
    labels = ('Function', 'Class', 'Method', 'Interface', 'Const', 'Enum', 'TypeAlias', 'Namespace', 'Module', 'Struct')
    placeholders = ','.join('?' * len(labels))
    all_rows = db.execute(
        f"SELECT id, name, label, filePath, content, startLine, endLine, nameExpanded FROM nodes WHERE label IN ({placeholders})",
        labels
    ).fetchall()

    # Filter out test files — they're searchable via BM25 keyword matching
    test_patterns = ('/test/', '/tests/', '/spec/', '/fixtures/', '/__tests__/', '/__mocks__/',
                     '.test.', '.spec.', '_test.', '_spec.')
    rows = [r for r in all_rows if not any(p in (r[3] or '') for p in test_patterns)]
    skipped_tests = len(all_rows) - len(rows)

    print(json.dumps({"phase": "queried", "nodes": len(rows), "skipped_tests": skipped_tests}), flush=True)

    if not rows:
        print(json.dumps({"phase": "done", "embedded": 0, "skipped": 0, "ms": 0}), flush=True)
        db.close()
        return

    # Fetch graph context (callers, callees, module) for richer embedding text
    node_ids = [r[0] for r in rows]

    caller_map = {}
    callee_map = {}
    module_map = {}

    # Chunk the IN clause to avoid SQLite variable limits
    CHUNK = 500
    for ci in range(0, len(node_ids), CHUNK):
        chunk_ids = node_ids[ci:ci+CHUNK]
        ph = ','.join('?' * len(chunk_ids))

        for row in db.execute(f"SELECT e.targetId, n.name FROM edges e JOIN nodes n ON n.id = e.sourceId WHERE e.targetId IN ({ph}) AND e.type = 'CALLS' AND e.confidence >= 0.7 LIMIT {len(chunk_ids)*3}", chunk_ids):
            caller_map.setdefault(row[0], []).append(row[1])

        for row in db.execute(f"SELECT e.sourceId, n.name FROM edges e JOIN nodes n ON n.id = e.targetId WHERE e.sourceId IN ({ph}) AND e.type = 'CALLS' AND e.confidence >= 0.7 LIMIT {len(chunk_ids)*3}", chunk_ids):
            callee_map.setdefault(row[0], []).append(row[1])

        for row in db.execute(f"SELECT e.sourceId, c.heuristicLabel FROM edges e JOIN nodes c ON c.id = e.targetId WHERE e.sourceId IN ({ph}) AND e.type = 'MEMBER_OF' AND c.label = 'Community' LIMIT {len(chunk_ids)}", chunk_ids):
            module_map[row[0]] = row[1]

    # Batch fetch import names per file
    import_map = {}
    id_to_path = {r[0]: r[3] for r in rows}  # O(1) node id -> filePath lookup
    for ci in range(0, len(node_ids), CHUNK):
        chunk_ids = node_ids[ci:ci+CHUNK]
        # Unique file paths for this chunk's nodes (order irrelevant — IN clause)
        unique_files = list({id_to_path[nid] for nid in chunk_ids})
        if unique_files:
            fph = ','.join('?' * len(unique_files))
            for row in db.execute(
                f"SELECT DISTINCT n.filePath, tn.name FROM nodes n JOIN edges e ON e.sourceId = n.id AND e.type = 'IMPORTS' JOIN nodes tn ON tn.id = e.targetId WHERE n.filePath IN ({fph}) LIMIT {len(unique_files)*10}",
                unique_files
            ):
                import_map.setdefault(row[0], []).append(row[1])

    print(json.dumps({"phase": "context", "with_callers": len(caller_map), "with_module": len(module_map), "with_imports": len(import_map)}), flush=True)

    # Get existing text hashes for skip detection
    existing_hashes = {}
    for row in db.execute("SELECT nodeId, textHash FROM embeddings WHERE textHash IS NOT NULL"):
        existing_hashes[row[0]] = row[1]

    # Generate embedding texts + hashes
    # Optimized: semantic summary (name + comment + signature + context)
    # instead of raw code dump. 55% fewer tokens, equal search quality.
    to_embed = []  # (node_id, text, hash)
    skipped = 0

    def extract_first_comment(content):
        """Extract JSDoc/comment as natural language description (max 3 lines)."""
        if not content:
            return ""
        lines = content.split("\n")
        comment_lines = []
        in_block = False
        for l in lines:
            t = l.strip()
            # JSDoc / block comment opener; skip @tag lines
            if t.startswith("/**") or t.startswith("/*"):
                in_block = True
                inner = t.lstrip("/").lstrip("*").strip().rstrip("*/").strip()
                if inner and not inner.startswith("@"):
                    comment_lines.append(inner)
                if "*/" in t:
                    in_block = False
                continue
            if in_block:
                if "*/" in t:
                    in_block = False
                    continue
                inner = t.lstrip("*").strip()
                if inner and not inner.startswith("@"):
                    comment_lines.append(inner)
                if len(comment_lines) >= 3:
                    break
                continue
            # Line comments: // ...
            if t.startswith("//"):
                inner = t[2:].strip()
                if inner:
                    comment_lines.append(inner)
                if len(comment_lines) >= 3:
                    break
                continue
            # Line comments: # ... (but not shebangs)
            if t.startswith("#") and not t.startswith("#!"):
                inner = t[1:].strip()
                if inner:
                    comment_lines.append(inner)
                if len(comment_lines) >= 3:
                    break
                continue
            # First non-comment line after collecting something: stop.
            if comment_lines:
                break
        return " ".join(comment_lines)

    def extract_signature(content, label):
        """Extract code signature without full body."""
        if not content:
            return ""
        lines = content.split("\n")
        if label == "Interface":
            # Interfaces are declaration-only; keep up to 30 lines verbatim.
            return "\n".join(lines[:30]).strip() if len(lines) <= 30 else "\n".join(lines[:30]) + "\n // ..."
        if label == "Class":
            # Keep member declaration lines only, capped at 20.
            sigs = []
            for l in lines[:60]:
                t = l.strip()
                if not t or t.startswith("//") or t.startswith("*") or t.startswith("/*"):
                    continue
                if any(kw in t for kw in ("class ", "private ", "public ", "protected ", "readonly ", "static ", "abstract ")):
                    sigs.append(t)
                    if len(sigs) >= 20:
                        break
            return "\n".join(sigs)
        # Functions etc.: the first few lines carry the signature.
        return "\n".join(lines[:min(8, len(lines))]).strip()

    for row in rows:
        nid, name, label, filePath, content, startLine, endLine, nameExpanded = row
        content = content or ""
        file_name = filePath.rsplit('/', 1)[-1] if filePath else ""

        # Build semantic embedding text
        parts = [f"{label}: {name}"]

        # nameExpanded: natural language bridge (e.g. "checkStaleness" → "check staleness")
        if nameExpanded and nameExpanded != name.lower():
            parts.append(nameExpanded)

        # First comment as natural language description
        comment = extract_first_comment(content)
        if comment:
            parts.append(comment)

        # Directory context (last two path components, capped at 40 chars)
        dir_parts = filePath.rsplit('/', 2)
        dir_context = '/'.join(dir_parts[:-1])[-40:] if '/' in filePath else ''

        # File + module location
        loc = f"File: {file_name}"
        if dir_context:
            loc += f" in {dir_context}"
        module = module_map.get(nid, "")
        if module:
            loc += f" | Module: {module}"
        parts.append(loc)

        # Import context
        file_imports = import_map.get(filePath, [])[:5]
        if file_imports:
            parts.append(f"Imports: {', '.join(file_imports)}")

        # Graph context
        callers = caller_map.get(nid, [])[:5]
        callees = callee_map.get(nid, [])[:5]
        if callers:
            parts.append(f"Called by: {', '.join(callers)}")
        if callees:
            parts.append(f"Calls: {', '.join(callees)}")

        # Code signature (not full body)
        sig = extract_signature(content, label)
        if sig:
            parts.extend(["", sig])

        text = '\n'.join(parts)
        text_hash = md5(text)

        # Skip if hash unchanged
        if existing_hashes.get(nid) == text_hash:
            skipped += 1
            continue

        to_embed.append((nid, text, text_hash))

    print(json.dumps({"phase": "prepared", "to_embed": len(to_embed), "skipped": skipped}), flush=True)

    if not to_embed:
        print(json.dumps({"phase": "done", "embedded": 0, "skipped": skipped, "ms": int((time.time() - t0_total) * 1000)}), flush=True)
        db.close()
        return

    # Deduplicate — embed unique texts only, copy vectors to duplicates.
    # Identical embedding texts produce identical vectors; no quality loss.
    unique_by_hash = {}  # text_hash -> { text, node_ids: [(nid, text_hash)] }
    for nid, text, text_hash in to_embed:
        if text_hash in unique_by_hash:
            unique_by_hash[text_hash]["node_ids"].append((nid, text_hash))
        else:
            unique_by_hash[text_hash] = {"text": text, "node_ids": [(nid, text_hash)]}
    unique_texts = [v["text"] for v in unique_by_hash.values()]
    deduped = len(to_embed) - len(unique_texts)

    # Embed unique texts in streaming fashion — process each batch, write to DB
    # immediately, free GPU memory. Keeps peak memory at ONE batch instead of ALL.
    t0_embed = time.time()
    unique_entries = list(unique_by_hash.values())

    # Tokenize + sort (same as embed_tiered but we handle the loop here)
    is_code_model = "jina-code" in MODEL_DIR
    if is_code_model:
        prefix_map = {"retrieval.query": "Find the most relevant code snippet given the following query:\n", "retrieval.passage": "Candidate code snippet:\n"}
    else:
        prefix_map = {"retrieval.query": "Query: ", "retrieval.passage": "Document: "}
    prefix = prefix_map.get("retrieval.passage", "")
    prefixed = [prefix + e["text"] for e in unique_entries]
    encodings = tokenizer.encode_batch(prefixed)
    indexed = sorted(range(len(unique_entries)), key=lambda i: len(encodings[i].ids))

    embedded_count = 0
    db.execute("BEGIN")

    i = 0
    while i < len(indexed):
        # Size the batch by peeking a token count (same scheme as embed_tiered).
        peek_idx = indexed[min(i + 1, len(indexed) - 1)]
        tok_count = min(len(encodings[peek_idx].ids), max_tokens)
        batch_size = get_batch_size_for_tokens(tok_count)

        batch_indices = []
        batch_encs = []
        while len(batch_encs) < batch_size and i < len(indexed):
            orig_idx = indexed[i]
            batch_indices.append(orig_idx)
            batch_encs.append(encodings[orig_idx])
            i += 1

        # Right-pad every sequence to the batch's longest (capped) length.
        max_len = min(max_tokens, max(len(e.ids) for e in batch_encs))
        input_ids = []
        attention_mask = []
        for enc in batch_encs:
            ids = enc.ids[:max_len]
            mask = enc.attention_mask[:max_len]
            pad = max_len - len(ids)
            if pad > 0:
                ids = ids + [0] * pad
                mask = mask + [0] * pad
            input_ids.append(ids)
            attention_mask.append(mask)

        # Forward pass, truncate dims, L2-normalize
        embs = model(mx.array(input_ids), mx.array(attention_mask))
        if dims and dims < embs.shape[1]:
            embs = embs[:, :dims]
        norms = mx.linalg.norm(embs, axis=1, keepdims=True)
        embs = embs / norms
        mx.eval(embs)

        # Convert to Python + write to DB immediately
        emb_list = embs.tolist()
        del embs  # free MLX GPU memory

        for j, orig_idx in enumerate(batch_indices):
            entry = unique_entries[orig_idx]
            blob = float_list_to_blob(emb_list[j])
            # Copy the same vector to every node that shared this text.
            for nid, th in entry["node_ids"]:
                db.execute("INSERT OR REPLACE INTO embeddings (nodeId, embedding, textHash) VALUES (?, ?, ?)",
                           (nid, blob, th))
            embedded_count += len(entry["node_ids"])

        # Progress
        pct = i * 100 // len(indexed)
        print(json.dumps({"phase": "embedding", "progress": pct, "embedded": embedded_count}), flush=True)

    db.execute("COMMIT")
    embed_ms = int((time.time() - t0_embed) * 1000)
    write_ms = 0  # included in embed_ms now

    print(json.dumps({"phase": "embedded", "count": len(unique_entries), "deduped": deduped, "ms": embed_ms}), flush=True)

    total_ms = int((time.time() - t0_total) * 1000)
    print(json.dumps({
        "phase": "done",
        "embedded": len(to_embed),
        "skipped": skipped,
        "embed_ms": embed_ms,
        "write_ms": write_ms,
        "total_ms": total_ms,
    }), flush=True)

    db.close()
532
-
533
-
534
- # =========================================================================
535
- # INTERACTIVE MODE — stdin/stdout JSON for MCP query embedding
536
- # =========================================================================
537
-
538
def interactive_mode():
    """Serve embedding requests over stdin/stdout as line-delimited JSON.

    Protocol: one JSON object per input line. {"cmd": "ping"} answers with a
    ready status, {"cmd": "quit"} exits; any other object is treated as an
    embedding request with keys "texts", "type" ("query"/"passage"), "dims".
    """
    def emit(payload):
        # Every protocol message is single-line JSON, flushed immediately.
        print(json.dumps(payload), flush=True)

    start = time.time()
    model, tokenizer = load_model()
    emit({
        "status": "ready",
        "model": "jina-v5-text-small-retrieval",
        "device": str(mx.default_device()),
        "load_ms": int((time.time() - start) * 1000),
        "precision": "int4-g64",
    })

    for raw in sys.stdin:
        raw = raw.strip()
        if not raw:
            continue

        try:
            req = json.loads(raw)
        except json.JSONDecodeError:
            emit({"error": "Invalid JSON"})
            continue

        # Control commands take priority over embedding requests.
        if "cmd" in req:
            cmd = req["cmd"]
            if cmd == "ping":
                emit({"status": "ready"})
            elif cmd == "quit":
                break
            continue

        texts = req.get("texts", [])
        dims = req.get("dims", 256)
        is_query = req.get("type", "passage") == "query"
        task_type = "retrieval.query" if is_query else "retrieval.passage"

        started = time.time()
        try:
            vectors = embed_tiered(model, tokenizer, texts, task_type, dims)
            emit({
                "embeddings": vectors,
                "count": len(vectors),
                "dims": dims,
                "ms": int((time.time() - started) * 1000),
            })
        except Exception as e:
            emit({"error": str(e)})
586
-
587
-
588
- # =========================================================================
589
- # MAIN
590
- # =========================================================================
591
-
592
if __name__ == "__main__":
    # CLI: `batch <db_path> [--dims N] [--max-tokens N]` runs the one-shot
    # batch embedder; anything else starts the stdin/stdout protocol loop.
    argv = sys.argv
    if len(argv) >= 3 and argv[1] == "batch":
        db_path = argv[2]
        # Flag defaults, overridden by `--flag <int>` pairs after the db path.
        opts = {"--dims": 256, "--max-tokens": 2048}
        for pos in range(3, len(argv)):
            flag = argv[pos]
            if flag in opts and pos + 1 < len(argv):
                opts[flag] = int(argv[pos + 1])
        batch_mode(db_path, opts["--dims"], opts["--max-tokens"])
    else:
        interactive_mode()