@zuvia-software-solutions/code-mapper 2.3.0 → 2.3.2

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -266,17 +266,24 @@ export const analyzeCommand = async (inputPath, options) => {
266
266
  continue;
267
267
  try {
268
268
  const msg = JSON.parse(line);
269
- if (msg.phase === 'loaded') {
269
+ if (msg.phase === 'downloading' || msg.phase === 'converting') {
270
+ updateBar(90, msg.message);
271
+ }
272
+ else if (msg.phase === 'loaded') {
270
273
  updateBar(91, `Model loaded (${msg.load_ms}ms)`);
271
274
  }
272
275
  else if (msg.phase === 'queried') {
273
- updateBar(92, `Found ${msg.nodes} embeddable nodes`);
276
+ updateBar(92, `Found ${msg.nodes} embeddable nodes${msg.skipped_tests ? ` (${msg.skipped_tests} test files skipped)` : ''}`);
274
277
  }
275
278
  else if (msg.phase === 'prepared') {
276
279
  updateBar(93, `${msg.to_embed} to embed, ${msg.skipped} cached`);
277
280
  }
281
+ else if (msg.phase === 'embedding') {
282
+ const scaled = 93 + Math.round((msg.progress / 100) * 4);
283
+ updateBar(scaled, `Embedding... ${msg.progress}% (${msg.embedded} written)`);
284
+ }
278
285
  else if (msg.phase === 'embedded') {
279
- updateBar(96, `Embedded ${msg.count} nodes (${(msg.ms / 1000).toFixed(1)}s)`);
286
+ updateBar(97, `Embedded ${msg.count} nodes (${(msg.ms / 1000).toFixed(1)}s)`);
280
287
  }
281
288
  else if (msg.phase === 'done') {
282
289
  updateBar(98, `Embeddings complete (${msg.embedded} new, ${msg.skipped} cached)`);
@@ -0,0 +1,73 @@
1
+ {
2
+ "architectures": [
3
+ "Qwen2ForCausalLM"
4
+ ],
5
+ "attention_dropout": 0.0,
6
+ "bos_token_id": 151643,
7
+ "eos_token_id": 151643,
8
+ "hidden_act": "silu",
9
+ "hidden_size": 896,
10
+ "initializer_range": 0.02,
11
+ "intermediate_size": 4864,
12
+ "layer_types": [
13
+ "full_attention",
14
+ "full_attention",
15
+ "full_attention",
16
+ "full_attention",
17
+ "full_attention",
18
+ "full_attention",
19
+ "full_attention",
20
+ "full_attention",
21
+ "full_attention",
22
+ "full_attention",
23
+ "full_attention",
24
+ "full_attention",
25
+ "full_attention",
26
+ "full_attention",
27
+ "full_attention",
28
+ "full_attention",
29
+ "full_attention",
30
+ "full_attention",
31
+ "full_attention",
32
+ "full_attention",
33
+ "full_attention",
34
+ "full_attention",
35
+ "full_attention",
36
+ "full_attention"
37
+ ],
38
+ "matryoshka_dims": [
39
+ 64,
40
+ 128,
41
+ 256,
42
+ 512,
43
+ 896
44
+ ],
45
+ "max_position_embeddings": 32768,
46
+ "max_window_layers": 24,
47
+ "model_type": "qwen2",
48
+ "num_attention_heads": 14,
49
+ "num_hidden_layers": 24,
50
+ "num_key_value_heads": 2,
51
+ "prompt_names": [
52
+ "query",
53
+ "passage"
54
+ ],
55
+ "rms_norm_eps": 1e-06,
56
+ "rope_scaling": null,
57
+ "rope_theta": 1000000.0,
58
+ "sliding_window": null,
59
+ "task_names": [
60
+ "nl2code",
61
+ "qa",
62
+ "code2code",
63
+ "code2nl",
64
+ "code2completion"
65
+ ],
66
+ "tie_word_embeddings": true,
67
+ "tokenizer_class": "Qwen2TokenizerFast",
68
+ "torch_dtype": "bfloat16",
69
+ "transformers_version": "4.53.0",
70
+ "use_cache": true,
71
+ "use_sliding_window": false,
72
+ "vocab_size": 151936
73
+ }
@@ -0,0 +1,127 @@
1
+
2
+ from dataclasses import dataclass
3
+ from typing import Optional, List
4
+ import mlx.core as mx
5
+ import mlx.nn as nn
6
+
7
@dataclass
class ModelArgs:
    """Hyper-parameters consumed by the transformer classes in this file.

    The first eight fields are required; ``rope_theta`` and
    ``tie_word_embeddings`` carry defaults.
    """

    hidden_size: int  # model width; used by Attention, MLP, RMSNorm and the embedding table
    num_hidden_layers: int  # number of TransformerBlock instances Qwen2Model builds
    intermediate_size: int  # hidden width handed to MLP
    num_attention_heads: int  # query head count; Attention derives head_dim as hidden_size // this
    rms_norm_eps: float  # eps passed to every nn.RMSNorm
    vocab_size: int  # number of rows in the token embedding table
    num_key_value_heads: int  # key/value head count used by Attention
    max_position_embeddings: int  # stored only; not read by the model classes visible here
    rope_theta: float = 1000000.0  # base passed to mx.fast.rope
    tie_word_embeddings: bool = True  # stored only; not read by the model classes visible here
20
class Attention(nn.Module):
    """Multi-head self-attention with rotary position embeddings.

    Queries use ``num_attention_heads`` heads while keys and values use
    ``num_key_value_heads`` heads. The q/k/v projections carry a bias; the
    output projection does not.
    """

    def __init__(self, args):
        super().__init__()
        hidden = args.hidden_size
        self.n_heads = args.num_attention_heads
        self.n_kv_heads = args.num_key_value_heads
        self.head_dim = hidden // self.n_heads
        self.scale = self.head_dim ** -0.5
        self.rope_theta = args.rope_theta

        q_width = self.n_heads * self.head_dim
        kv_width = self.n_kv_heads * self.head_dim
        self.q_proj = nn.Linear(hidden, q_width, bias=True)
        self.k_proj = nn.Linear(hidden, kv_width, bias=True)
        self.v_proj = nn.Linear(hidden, kv_width, bias=True)
        self.o_proj = nn.Linear(q_width, hidden, bias=False)

    def _split_heads(self, proj, x, n_heads):
        """Apply ``proj`` to ``x`` and lay the result out as (batch, heads, seq, head_dim)."""
        batch, seq_len, _ = x.shape
        return proj(x).reshape(batch, seq_len, n_heads, self.head_dim).transpose(0, 2, 1, 3)

    def _rotate(self, t):
        """Apply rotary position embeddings over the full head dimension."""
        return mx.fast.rope(t, self.head_dim, traditional=False, base=self.rope_theta, scale=1.0, offset=0)

    def __call__(self, x, mask=None):
        """Attend over ``x`` (batch, seq, hidden); ``mask`` is an optional additive mask."""
        batch, seq_len, _ = x.shape
        q = self._rotate(self._split_heads(self.q_proj, x, self.n_heads))
        k = self._rotate(self._split_heads(self.k_proj, x, self.n_kv_heads))
        v = self._split_heads(self.v_proj, x, self.n_kv_heads)

        # The mask is cast to the query dtype before it reaches the fused kernel.
        attn_mask = None if mask is None else mask.astype(q.dtype)
        out = mx.fast.scaled_dot_product_attention(q, k, v, mask=attn_mask, scale=self.scale)

        # Back to (batch, seq, heads * head_dim) for the output projection.
        merged = out.transpose(0, 2, 1, 3).reshape(batch, seq_len, -1)
        return self.o_proj(merged)
43
+
44
class MLP(nn.Module):
    """Gated feed-forward block computing ``down(silu(gate(x)) * up(x))`` without biases."""

    def __init__(self, dim, hidden):
        super().__init__()
        self.gate_proj = nn.Linear(dim, hidden, bias=False)
        self.down_proj = nn.Linear(hidden, dim, bias=False)
        self.up_proj = nn.Linear(dim, hidden, bias=False)

    def __call__(self, x):
        gated = nn.silu(self.gate_proj(x))
        lifted = self.up_proj(x)
        return self.down_proj(gated * lifted)
52
+
53
class TransformerBlock(nn.Module):
    """One decoder layer: RMSNorm -> attention -> residual, then RMSNorm -> MLP -> residual."""

    def __init__(self, args):
        super().__init__()
        width, eps = args.hidden_size, args.rms_norm_eps
        self.self_attn = Attention(args)
        self.mlp = MLP(width, args.intermediate_size)
        self.input_layernorm = nn.RMSNorm(width, eps=eps)
        self.post_attention_layernorm = nn.RMSNorm(width, eps=eps)

    def __call__(self, x, mask=None):
        # Attention sub-layer; the residual is added to the un-normalised input.
        attn_out = self.self_attn(self.input_layernorm(x), mask)
        h = x + attn_out
        # Feed-forward sub-layer with its own residual.
        mlp_out = self.mlp(self.post_attention_layernorm(h))
        return h + mlp_out
63
+
64
class Qwen2Model(nn.Module):
    """Token embedding, a stack of ``TransformerBlock`` layers, and a final RMSNorm."""

    def __init__(self, args):
        super().__init__()
        self.embed_tokens = nn.Embedding(args.vocab_size, args.hidden_size)
        blocks = []
        for _ in range(args.num_hidden_layers):
            blocks.append(TransformerBlock(args))
        self.layers = blocks
        self.norm = nn.RMSNorm(args.hidden_size, eps=args.rms_norm_eps)

    def __call__(self, inputs, mask=None):
        """Return the normalised hidden state for every position of ``inputs``."""
        hidden = self.embed_tokens(inputs)
        for block in self.layers:
            hidden = block(hidden, mask)
        return self.norm(hidden)
75
+
76
class JinaCodeEmbeddingModel(nn.Module):
    """Embedding model: Qwen2 backbone, last-token pooling, L2 normalisation."""

    def __init__(self, config):
        """Build the backbone from a config mapping (e.g. the parsed config.json).

        ``rope_theta`` is optional and falls back to 1000000.0; every other
        key read below must be present or a ``KeyError`` is raised.
        """
        super().__init__()
        args = ModelArgs(
            hidden_size=config["hidden_size"],
            num_hidden_layers=config["num_hidden_layers"],
            intermediate_size=config["intermediate_size"],
            num_attention_heads=config["num_attention_heads"],
            rms_norm_eps=config["rms_norm_eps"],
            vocab_size=config["vocab_size"],
            num_key_value_heads=config["num_key_value_heads"],
            max_position_embeddings=config["max_position_embeddings"],
            rope_theta=config.get("rope_theta", 1000000.0),
        )
        self.model = Qwen2Model(args)
        self.config = config

    def __call__(self, input_ids, attention_mask=None):
        """Return one L2-normalised embedding per row of ``input_ids``.

        ``input_ids`` is (batch, seq). ``attention_mask``, when given, has the
        same shape with 0 marking padding.
        """
        B, L = input_ids.shape

        # Additive causal mask: 0 on and below the diagonal, -1e4 above it.
        causal = mx.tril(mx.ones((L, L)))
        causal = mx.where(causal == 0, -1e4, 0.0)[None, None, :, :]
        if attention_mask is not None:
            # Padded key positions get an extra -1e4 for every query.
            pad = mx.where(attention_mask == 0, -1e4, 0.0)[:, None, None, :]
            mask = causal + pad
        else:
            mask = causal

        h = self.model(input_ids, mask)

        if attention_mask is not None:
            # Last-token pooling: index (number of real tokens - 1) in each row.
            # NOTE(review): this assumes right padding, which is what encode() produces.
            seq_lens = mx.sum(attention_mask.astype(mx.int32), axis=1) - 1
            embs = h[mx.arange(B), seq_lens]
        else:
            embs = h[:, -1, :]

        norms = mx.linalg.norm(embs, axis=1, keepdims=True)
        return embs / norms

    def encode(self, texts, tokenizer, max_length=8192, truncate_dim=None, task="nl2code", prompt_type="query"):
        """Tokenise ``texts`` and return their embeddings as a (len(texts), dim) array.

        Args:
            texts: strings to embed. An empty input returns an empty (0, dim)
                array instead of raising.
            tokenizer: object exposing ``encode_batch``; each returned encoding
                must provide ``ids`` and ``attention_mask`` lists.
            max_length: sequences longer than this are truncated.
            truncate_dim: if truthy, keep only the first ``truncate_dim``
                components and re-normalise.
            task: instruction family; only "nl2code" has prefixes defined here,
                any other value embeds the raw text without a prefix.
            prompt_type: "query" or "passage" within the task.
        """
        texts = list(texts)
        if not texts:
            # max() below would raise ValueError on an empty batch.
            dim = self.config["hidden_size"]
            if truncate_dim:
                dim = min(dim, truncate_dim)
            return mx.zeros((0, dim), dtype=self.model.embed_tokens.weight.dtype)

        PREFIXES = {"nl2code": {"query": "Find the most relevant code snippet given the following query:\n", "passage": "Candidate code snippet:\n"}}
        prefix = PREFIXES.get(task, {}).get(prompt_type, "")
        if prefix:
            texts = [prefix + t for t in texts]

        encodings = tokenizer.encode_batch(texts)

        # Pad to the longest sequence in the batch, capped at max_length.
        ml = min(max_length, max(len(e.ids) for e in encodings))
        iids, amask = [], []
        for e in encodings:
            ids = e.ids[:ml]
            m = e.attention_mask[:ml]
            p = ml - len(ids)
            if p > 0:
                # Right-pad with token id 0 and mask value 0.
                ids = ids + [0] * p
                m = m + [0] * p
            iids.append(ids)
            amask.append(m)

        embs = self(mx.array(iids), mx.array(amask))
        if truncate_dim:
            # Truncation breaks unit length, so normalise again.
            embs = embs[:, :truncate_dim]
            embs = embs / mx.linalg.norm(embs, axis=1, keepdims=True)
        return embs
@@ -35,27 +35,54 @@ MODEL_DIR = os.path.dirname(os.path.abspath(__file__)) + "/jina-code-0.5b-mlx"
35
35
 
36
36
 
37
37
def ensure_model_downloaded():
    """Download and convert model weights from HuggingFace if not present.

    Does nothing when ``MODEL_DIR/model.safetensors`` already exists.
    Otherwise it fetches the tokenizer and vocab files, downloads the raw
    weights, converts bfloat16 tensors to float16, prefixes every key with
    ``model.`` and saves the result. Progress is reported as one JSON object
    per line on stdout.

    The converted weights are written to a temporary file and moved into
    place with ``os.replace``. An interrupted run therefore never leaves a
    truncated ``model.safetensors`` that the existence check above would
    accept on the next start.

    Raises:
        RuntimeError: if an ImportError occurs (e.g. ``huggingface_hub`` is
            not installed).
    """
    weights_path = os.path.join(MODEL_DIR, "model.safetensors")
    if os.path.exists(weights_path):
        return

    os.makedirs(MODEL_DIR, exist_ok=True)
    print(json.dumps({"phase": "downloading", "message": "Downloading Jina Code 0.5B embedding model (~940MB, first time only)..."}), flush=True)

    try:
        from huggingface_hub import hf_hub_download
        import shutil

        repo = "jinaai/jina-code-embeddings-0.5b"

        # Download tokenizer from the v5 MLX model (same Qwen tokenizer, has tokenizer.json)
        tokenizer_repo = "jinaai/jina-embeddings-v5-text-small-retrieval-mlx"
        for fname in ["tokenizer.json"]:
            dest = os.path.join(MODEL_DIR, fname)
            if not os.path.exists(dest):
                path = hf_hub_download(tokenizer_repo, fname)
                shutil.copy(path, dest)

        # Download vocab/merges from the 0.5B model
        for fname in ["vocab.json", "merges.txt"]:
            dest = os.path.join(MODEL_DIR, fname)
            if not os.path.exists(dest):
                path = hf_hub_download(repo, fname)
                shutil.copy(path, dest)

        # Download and convert weights: bf16 → fp16, add 'model.' key prefix
        print(json.dumps({"phase": "converting", "message": "Converting model weights (bf16 → fp16)..."}), flush=True)
        raw_path = hf_hub_download(repo, "model.safetensors")
        raw_weights = mx.load(raw_path)
        converted = {}
        for k, v in raw_weights.items():
            # NOTE(review): the prefix presumably matches the `model` attribute
            # of the embedding wrapper class — confirm against the loader.
            new_key = "model." + k
            if v.dtype == mx.bfloat16:
                v = v.astype(mx.float16)
            converted[new_key] = v

        # Save next to the final path, then move into place in one step.
        # The temporary name keeps the .safetensors suffix so the saver does
        # not append another one.
        tmp_path = os.path.join(MODEL_DIR, "model.partial.safetensors")
        try:
            mx.save_safetensors(tmp_path, converted)
            os.replace(tmp_path, weights_path)
        finally:
            # Only present if the save or the move failed.
            if os.path.exists(tmp_path):
                os.remove(tmp_path)

        print(json.dumps({"phase": "downloaded", "message": f"Model ready ({len(converted)} weights converted)"}), flush=True)

    except ImportError:
        raise RuntimeError(
            "Model weights not found. Install huggingface_hub to auto-download:\n"
            " pip3 install huggingface_hub\n"
            "Or manually download from: https://huggingface.co/jinaai/jina-code-embeddings-0.5b"
        )
60
87
 
61
88
 
@@ -388,26 +415,81 @@ def batch_mode(db_path, dims=256, max_tokens=2048):
388
415
  unique_texts = [v["text"] for v in unique_by_hash.values()]
389
416
  deduped = len(to_embed) - len(unique_texts)
390
417
 
391
- # Embed only unique texts
418
+ # Embed unique texts in streaming fashion — process each batch, write to DB
419
+ # immediately, free GPU memory. Keeps peak memory at ONE batch instead of ALL.
392
420
  t0_embed = time.time()
393
- embeddings = embed_tiered(model, tokenizer, unique_texts, "retrieval.passage", dims, max_tokens)
394
- embed_ms = int((time.time() - t0_embed) * 1000)
421
+ unique_entries = list(unique_by_hash.values())
395
422
 
396
- print(json.dumps({"phase": "embedded", "count": len(unique_texts), "deduped": deduped, "ms": embed_ms}), flush=True)
423
+ # Tokenize + sort (same as embed_tiered but we handle the loop here)
424
+ is_code_model = "jina-code" in MODEL_DIR
425
+ if is_code_model:
426
+ prefix_map = {"retrieval.query": "Find the most relevant code snippet given the following query:\n", "retrieval.passage": "Candidate code snippet:\n"}
427
+ else:
428
+ prefix_map = {"retrieval.query": "Query: ", "retrieval.passage": "Document: "}
429
+ prefix = prefix_map.get("retrieval.passage", "")
430
+ prefixed = [prefix + e["text"] for e in unique_entries]
431
+ encodings = tokenizer.encode_batch(prefixed)
432
+ indexed = sorted(range(len(unique_entries)), key=lambda i: len(encodings[i].ids))
397
433
 
398
- # Write to database — copy embedding to all nodes sharing the same hash
399
- t0_write = time.time()
434
+ embedded_count = 0
400
435
  db.execute("BEGIN")
401
- for i, (text_hash, entry) in enumerate(unique_by_hash.items()):
402
- emb = embeddings[i]
403
- if emb is None:
404
- continue
405
- blob = float_list_to_blob(emb)
406
- for nid, th in entry["node_ids"]:
407
- db.execute("INSERT OR REPLACE INTO embeddings (nodeId, embedding, textHash) VALUES (?, ?, ?)",
408
- (nid, blob, th))
436
+
437
+ i = 0
438
+ while i < len(indexed):
439
+ peek_idx = indexed[min(i + 1, len(indexed) - 1)]
440
+ tok_count = min(len(encodings[peek_idx].ids), max_tokens)
441
+ batch_size = get_batch_size_for_tokens(tok_count)
442
+
443
+ batch_indices = []
444
+ batch_encs = []
445
+ while len(batch_encs) < batch_size and i < len(indexed):
446
+ orig_idx = indexed[i]
447
+ batch_indices.append(orig_idx)
448
+ batch_encs.append(encodings[orig_idx])
449
+ i += 1
450
+
451
+ max_len = min(max_tokens, max(len(e.ids) for e in batch_encs))
452
+ input_ids = []
453
+ attention_mask = []
454
+ for enc in batch_encs:
455
+ ids = enc.ids[:max_len]
456
+ mask = enc.attention_mask[:max_len]
457
+ pad = max_len - len(ids)
458
+ if pad > 0:
459
+ ids = ids + [0] * pad
460
+ mask = mask + [0] * pad
461
+ input_ids.append(ids)
462
+ attention_mask.append(mask)
463
+
464
+ # Forward pass
465
+ embs = model(mx.array(input_ids), mx.array(attention_mask))
466
+ if dims and dims < embs.shape[1]:
467
+ embs = embs[:, :dims]
468
+ norms = mx.linalg.norm(embs, axis=1, keepdims=True)
469
+ embs = embs / norms
470
+ mx.eval(embs)
471
+
472
+ # Convert to Python + write to DB immediately
473
+ emb_list = embs.tolist()
474
+ del embs # free MLX GPU memory
475
+
476
+ for j, orig_idx in enumerate(batch_indices):
477
+ entry = unique_entries[orig_idx]
478
+ blob = float_list_to_blob(emb_list[j])
479
+ for nid, th in entry["node_ids"]:
480
+ db.execute("INSERT OR REPLACE INTO embeddings (nodeId, embedding, textHash) VALUES (?, ?, ?)",
481
+ (nid, blob, th))
482
+ embedded_count += len(entry["node_ids"])
483
+
484
+ # Progress
485
+ pct = i * 100 // len(indexed)
486
+ print(json.dumps({"phase": "embedding", "progress": pct, "embedded": embedded_count}), flush=True)
487
+
409
488
  db.execute("COMMIT")
410
- write_ms = int((time.time() - t0_write) * 1000)
489
+ embed_ms = int((time.time() - t0_embed) * 1000)
490
+ write_ms = 0 # included in embed_ms now
491
+
492
+ print(json.dumps({"phase": "embedded", "count": len(unique_entries), "deduped": deduped, "ms": embed_ms}), flush=True)
411
493
 
412
494
  total_ms = int((time.time() - t0_total) * 1000)
413
495
  print(json.dumps({
package/package.json CHANGED
@@ -1,6 +1,6 @@
1
1
  {
2
2
  "name": "@zuvia-software-solutions/code-mapper",
3
- "version": "2.3.0",
3
+ "version": "2.3.2",
4
4
  "description": "Graph-powered code intelligence for AI agents. Index any codebase, query via MCP or CLI.",
5
5
  "author": "Abhigyan Patwari",
6
6
  "license": "PolyForm-Noncommercial-1.0.0",
@@ -36,8 +36,8 @@
36
36
  "skills",
37
37
  "vendor",
38
38
  "models/mlx-embedder.py",
39
- "models/jina-v5-small-mlx/model.py",
40
- "models/jina-v5-small-mlx/config.json"
39
+ "models/jina-code-0.5b-mlx/model.py",
40
+ "models/jina-code-0.5b-mlx/config.json"
41
41
  ],
42
42
  "scripts": {
43
43
  "build": "tsc",
@@ -1,19 +0,0 @@
1
- {
2
- "model_type": "qwen3",
3
- "hidden_size": 1024,
4
- "num_hidden_layers": 28,
5
- "intermediate_size": 3072,
6
- "num_attention_heads": 16,
7
- "num_key_value_heads": 8,
8
- "rms_norm_eps": 1e-06,
9
- "vocab_size": 151936,
10
- "max_position_embeddings": 32768,
11
- "rope_theta": 3500000,
12
- "rope_parameters": {
13
- "rope_theta": 3500000,
14
- "rope_type": "default"
15
- },
16
- "head_dim": 128,
17
- "tie_word_embeddings": true,
18
- "rope_scaling": null
19
- }
@@ -1,260 +0,0 @@
1
- """
2
- Jina Embeddings v5 Text Small - MLX Implementation
3
-
4
- Pure MLX port of jina-embeddings-v5-text-small (Qwen3-0.6B backbone).
5
- Zero dependency on PyTorch or transformers.
6
-
7
- Features:
8
- - Causal attention (decoder architecture)
9
- - QKNorm (q_norm/k_norm per head)
10
- - Last-token pooling
11
- - L2 normalization
12
- - Matryoshka embedding dimensions: [32, 64, 128, 256, 512, 768, 1024]
13
- - Max sequence length: 32768 tokens
14
- - Embedding dimension: 1024
15
-
16
- Architecture:
17
- - RoPE (rope_theta from config)
18
- - SwiGLU MLP
19
- - RMSNorm
20
- - QKNorm (RMSNorm on Q/K per head)
21
- - No attention bias
22
- """
23
-
24
- from dataclasses import dataclass
25
- from typing import Any, Dict, Optional, Union
26
-
27
- import mlx.core as mx
28
- import mlx.nn as nn
29
-
30
-
31
- @dataclass
32
- class ModelArgs:
33
- model_type: str
34
- hidden_size: int
35
- num_hidden_layers: int
36
- intermediate_size: int
37
- num_attention_heads: int
38
- rms_norm_eps: float
39
- vocab_size: int
40
- num_key_value_heads: int
41
- max_position_embeddings: int
42
- head_dim: int
43
- tie_word_embeddings: bool
44
- rope_parameters: Optional[Dict[str, Union[float, str]]] = None
45
- rope_theta: Optional[float] = None
46
- rope_scaling: Optional[Dict[str, Union[float, str]]] = None
47
-
48
-
49
- class Attention(nn.Module):
50
- def __init__(self, args: ModelArgs):
51
- super().__init__()
52
-
53
- dim = args.hidden_size
54
- self.n_heads = n_heads = args.num_attention_heads
55
- self.n_kv_heads = n_kv_heads = args.num_key_value_heads
56
-
57
- head_dim = args.head_dim
58
- self.scale = head_dim**-0.5
59
- self.head_dim = head_dim
60
-
61
- self.q_proj = nn.Linear(dim, n_heads * head_dim, bias=False)
62
- self.k_proj = nn.Linear(dim, n_kv_heads * head_dim, bias=False)
63
- self.v_proj = nn.Linear(dim, n_kv_heads * head_dim, bias=False)
64
- self.o_proj = nn.Linear(n_heads * head_dim, dim, bias=False)
65
-
66
- # Qwen3 has QKNorm
67
- self.q_norm = nn.RMSNorm(head_dim, eps=args.rms_norm_eps)
68
- self.k_norm = nn.RMSNorm(head_dim, eps=args.rms_norm_eps)
69
-
70
- # Resolve rope_theta from config
71
- if args.rope_parameters and 'rope_theta' in args.rope_parameters:
72
- rope_theta = float(args.rope_parameters['rope_theta'])
73
- elif args.rope_theta:
74
- rope_theta = float(args.rope_theta)
75
- else:
76
- rope_theta = 10000.0
77
- self.rope_theta = rope_theta
78
-
79
- def __call__(
80
- self,
81
- x: mx.array,
82
- mask: Optional[mx.array] = None,
83
- ) -> mx.array:
84
- B, L, D = x.shape
85
-
86
- queries, keys, values = self.q_proj(x), self.k_proj(x), self.v_proj(x)
87
-
88
- # Reshape and apply QKNorm
89
- queries = self.q_norm(queries.reshape(B, L, self.n_heads, -1)).transpose(0, 2, 1, 3)
90
- keys = self.k_norm(keys.reshape(B, L, self.n_kv_heads, -1)).transpose(0, 2, 1, 3)
91
- values = values.reshape(B, L, self.n_kv_heads, -1).transpose(0, 2, 1, 3)
92
-
93
- # RoPE via mx.fast
94
- queries = mx.fast.rope(queries, self.head_dim, traditional=False, base=self.rope_theta, scale=1.0, offset=0)
95
- keys = mx.fast.rope(keys, self.head_dim, traditional=False, base=self.rope_theta, scale=1.0, offset=0)
96
-
97
- # Scaled dot-product attention (handles GQA natively)
98
- output = mx.fast.scaled_dot_product_attention(
99
- queries, keys, values,
100
- mask=mask.astype(queries.dtype) if mask is not None else None,
101
- scale=self.scale,
102
- )
103
-
104
- output = output.transpose(0, 2, 1, 3).reshape(B, L, -1)
105
- return self.o_proj(output)
106
-
107
-
108
- class MLP(nn.Module):
109
- def __init__(self, dim, hidden_dim):
110
- super().__init__()
111
- self.gate_proj = nn.Linear(dim, hidden_dim, bias=False)
112
- self.down_proj = nn.Linear(hidden_dim, dim, bias=False)
113
- self.up_proj = nn.Linear(dim, hidden_dim, bias=False)
114
-
115
- def __call__(self, x) -> mx.array:
116
- gate = nn.silu(self.gate_proj(x))
117
- return self.down_proj(gate * self.up_proj(x))
118
-
119
-
120
- class TransformerBlock(nn.Module):
121
- def __init__(self, args: ModelArgs):
122
- super().__init__()
123
- self.self_attn = Attention(args)
124
- self.mlp = MLP(args.hidden_size, args.intermediate_size)
125
- self.input_layernorm = nn.RMSNorm(args.hidden_size, eps=args.rms_norm_eps)
126
- self.post_attention_layernorm = nn.RMSNorm(args.hidden_size, eps=args.rms_norm_eps)
127
-
128
- def __call__(
129
- self,
130
- x: mx.array,
131
- mask: Optional[mx.array] = None,
132
- ) -> mx.array:
133
- r = self.self_attn(self.input_layernorm(x), mask)
134
- h = x + r
135
- r = self.mlp(self.post_attention_layernorm(h))
136
- out = h + r
137
- return out
138
-
139
-
140
- class Qwen3Model(nn.Module):
141
- def __init__(self, args: ModelArgs):
142
- super().__init__()
143
- self.embed_tokens = nn.Embedding(args.vocab_size, args.hidden_size)
144
- self.layers = [TransformerBlock(args=args) for _ in range(args.num_hidden_layers)]
145
- self.norm = nn.RMSNorm(args.hidden_size, eps=args.rms_norm_eps)
146
-
147
- def __call__(self, inputs: mx.array, mask: Optional[mx.array] = None):
148
- h = self.embed_tokens(inputs)
149
- for layer in self.layers:
150
- h = layer(h, mask)
151
- return self.norm(h)
152
-
153
-
154
- class JinaEmbeddingModel(nn.Module):
155
- """Jina v5-text-small embedding model with last-token pooling."""
156
-
157
- def __init__(self, config: dict):
158
- super().__init__()
159
- args = ModelArgs(**config)
160
- self.model = Qwen3Model(args)
161
- self.config = config
162
-
163
- def __call__(
164
- self,
165
- input_ids: mx.array,
166
- attention_mask: Optional[mx.array] = None,
167
- ):
168
- batch_size, seq_len = input_ids.shape
169
-
170
- # Causal mask (Qwen3 is a decoder model)
171
- causal_mask = mx.tril(mx.ones((seq_len, seq_len)))
172
- causal_mask = mx.where(causal_mask == 0, -1e4, 0.0)
173
- causal_mask = causal_mask[None, None, :, :]
174
-
175
- # Combine with padding mask
176
- if attention_mask is not None:
177
- padding_mask = mx.where(attention_mask == 0, -1e4, 0.0)
178
- padding_mask = padding_mask[:, None, None, :]
179
- mask = causal_mask + padding_mask
180
- else:
181
- mask = causal_mask
182
-
183
- hidden_states = self.model(input_ids, mask)
184
-
185
- # Last token pooling
186
- if attention_mask is not None:
187
- sequence_lengths = mx.sum(attention_mask, axis=1) - 1
188
- batch_indices = mx.arange(hidden_states.shape[0])
189
- embeddings = hidden_states[batch_indices, sequence_lengths]
190
- else:
191
- embeddings = hidden_states[:, -1, :]
192
-
193
- # L2 normalization
194
- norms = mx.linalg.norm(embeddings, axis=1, keepdims=True)
195
- embeddings = embeddings / norms
196
-
197
- return embeddings
198
-
199
- def encode(
200
- self,
201
- texts: list[str],
202
- tokenizer,
203
- max_length: int = 8192,
204
- truncate_dim: Optional[int] = None,
205
- task_type: str = "retrieval.query",
206
- ):
207
- """
208
- Encode texts to embeddings.
209
-
210
- Args:
211
- texts: List of input texts
212
- tokenizer: Tokenizer instance (from tokenizers library)
213
- max_length: Maximum sequence length
214
- truncate_dim: Optional Matryoshka dimension [32, 64, 128, 256, 512, 768, 1024]
215
- task_type: Task prefix ("retrieval.query", "retrieval.passage", etc.)
216
-
217
- Returns:
218
- Embeddings array [batch, dim]
219
- """
220
- prefix_map = {
221
- "retrieval.query": "Query: ",
222
- "retrieval.passage": "Document: ",
223
- "classification": "Document: ",
224
- "text-matching": "Document: ",
225
- "clustering": "Document: ",
226
- }
227
- prefix = prefix_map.get(task_type, "")
228
-
229
- if prefix:
230
- texts = [prefix + text for text in texts]
231
-
232
- encodings = tokenizer.encode_batch(texts)
233
-
234
- max_len = min(max_length, max(len(enc.ids) for enc in encodings))
235
- input_ids = []
236
- attention_mask = []
237
-
238
- for encoding in encodings:
239
- ids = encoding.ids[:max_len]
240
- mask = encoding.attention_mask[:max_len]
241
-
242
- pad_len = max_len - len(ids)
243
- if pad_len > 0:
244
- ids = ids + [0] * pad_len
245
- mask = mask + [0] * pad_len
246
-
247
- input_ids.append(ids)
248
- attention_mask.append(mask)
249
-
250
- input_ids = mx.array(input_ids)
251
- attention_mask = mx.array(attention_mask)
252
-
253
- embeddings = self(input_ids, attention_mask)
254
-
255
- if truncate_dim is not None:
256
- embeddings = embeddings[:, :truncate_dim]
257
- norms = mx.linalg.norm(embeddings, axis=1, keepdims=True)
258
- embeddings = embeddings / norms
259
-
260
- return embeddings