@zuvia-software-solutions/code-mapper 2.4.1 → 2.5.1
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- package/dist/cli/analyze.d.ts +0 -1
- package/dist/cli/analyze.js +11 -87
- package/dist/cli/index.js +2 -2
- package/dist/core/embeddings/index.d.ts +2 -3
- package/dist/core/embeddings/index.js +2 -3
- package/dist/core/embeddings/nl-embed-worker.d.ts +8 -0
- package/dist/core/embeddings/nl-embed-worker.js +38 -0
- package/dist/core/embeddings/nl-embedder.d.ts +1 -1
- package/dist/core/embeddings/nl-embedder.js +199 -30
- package/dist/core/incremental/refresh.js +18 -26
- package/dist/core/ingestion/call-processor.js +11 -5
- package/dist/core/semantic/tsgo-service.js +16 -0
- package/dist/mcp/local/local-backend.js +40 -27
- package/dist/mcp/server.js +2 -2
- package/dist/mcp/tools.js +1 -0
- package/package.json +2 -5
- package/models/jina-code-0.5b-mlx/config.json +0 -73
- package/models/jina-code-0.5b-mlx/model.py +0 -127
- package/models/mlx-embedder.py +0 -604
|
@@ -346,10 +346,10 @@ export class LocalBackend {
|
|
|
346
346
|
}
|
|
347
347
|
this.loadNlEmbeddingCache(id); // NL cache loaded regardless (cheap, may not exist)
|
|
348
348
|
}
|
|
349
|
-
// Pre-warm
|
|
349
|
+
// Pre-warm bge-small embedder so first query has zero model-load latency
|
|
350
350
|
if (anyEmbeddings) {
|
|
351
|
-
import('../../core/embeddings/embedder.js').then(({
|
|
352
|
-
|
|
351
|
+
import('../../core/embeddings/nl-embedder.js').then(({ initNlEmbedder }) => {
|
|
352
|
+
initNlEmbedder().catch(() => { });
|
|
353
353
|
}).catch(() => { });
|
|
354
354
|
}
|
|
355
355
|
return this.repos.size > 0;
|
|
@@ -1119,19 +1119,7 @@ export class LocalBackend {
|
|
|
1119
1119
|
...(r.startLine != null ? { startLine: r.startLine } : {}),
|
|
1120
1120
|
...(r.endLine != null ? { endLine: r.endLine } : {}),
|
|
1121
1121
|
}));
|
|
1122
|
-
|
|
1123
|
-
const nlForRRF = nlSemanticResults.map((r) => ({
|
|
1124
|
-
nodeId: String(r.nodeId ?? ''), name: String(r.name ?? ''), label: String(r.type ?? ''),
|
|
1125
|
-
filePath: String(r.filePath ?? ''), distance: Number(r.distance ?? 1),
|
|
1126
|
-
...(r.startLine != null ? { startLine: r.startLine } : {}),
|
|
1127
|
-
...(r.endLine != null ? { endLine: r.endLine } : {}),
|
|
1128
|
-
}));
|
|
1129
|
-
// Merge code + NL semantic into one semantic list (best of both worlds)
|
|
1130
|
-
const combinedSemantic = [...semanticForRRF, ...nlForRRF]
|
|
1131
|
-
.sort((a, b) => a.distance - b.distance)
|
|
1132
|
-
.filter((r, i, arr) => arr.findIndex(x => x.nodeId === r.nodeId) === i) // dedupe by nodeId
|
|
1133
|
-
.slice(0, searchLimit);
|
|
1134
|
-
let rrfMerged = mergeWithRRF(bm25ForRRF, combinedSemantic, { limit: searchLimit });
|
|
1122
|
+
let rrfMerged = mergeWithRRF(bm25ForRRF, semanticForRRF, { limit: searchLimit });
|
|
1135
1123
|
// Store NL match reasons for display
|
|
1136
1124
|
const nlMatchReasons = new Map();
|
|
1137
1125
|
for (const r of nlSemanticResults) {
|
|
@@ -1139,6 +1127,27 @@ export class LocalBackend {
|
|
|
1139
1127
|
nlMatchReasons.set(r.nodeId, r.match_reason);
|
|
1140
1128
|
}
|
|
1141
1129
|
}
|
|
1130
|
+
// Inject NL semantic results directly — they bridge the vocabulary gap
|
|
1131
|
+
// that BM25 and code embeddings miss. Insert at high score so they
|
|
1132
|
+
// appear in results even when BM25 finds unrelated "prevent" matches.
|
|
1133
|
+
if (nlSemanticResults.length > 0) {
|
|
1134
|
+
const mainIds = new Set(rrfMerged.map(r => r.nodeId || r.filePath));
|
|
1135
|
+
const topMainScore = rrfMerged[0]?.score ?? 0.01;
|
|
1136
|
+
for (let i = 0; i < Math.min(nlSemanticResults.length, 5); i++) {
|
|
1137
|
+
const nlr = nlSemanticResults[i];
|
|
1138
|
+
if (mainIds.has(nlr.nodeId))
|
|
1139
|
+
continue; // already in results
|
|
1140
|
+
// Score NL results high — at or above the top BM25 result
|
|
1141
|
+
const nlScore = topMainScore * (1.0 - i * 0.1);
|
|
1142
|
+
rrfMerged.push({
|
|
1143
|
+
filePath: nlr.filePath, score: nlScore, rank: i + 1,
|
|
1144
|
+
sources: ['semantic'], nodeId: nlr.nodeId, name: nlr.name,
|
|
1145
|
+
label: nlr.type, startLine: nlr.startLine, endLine: nlr.endLine,
|
|
1146
|
+
});
|
|
1147
|
+
}
|
|
1148
|
+
rrfMerged.sort((a, b) => b.score - a.score);
|
|
1149
|
+
rrfMerged = rrfMerged.slice(0, searchLimit);
|
|
1150
|
+
}
|
|
1142
1151
|
// Merge refs + fileWords into the RRF results (lower weight)
|
|
1143
1152
|
if (refsForRRF.length > 0 || fileWordsForRRF.length > 0) {
|
|
1144
1153
|
const supplemental = mergeWithRRF(refsForRRF, fileWordsForRRF.map((r) => ({
|
|
@@ -1200,14 +1209,13 @@ export class LocalBackend {
|
|
|
1200
1209
|
data.match_reason = reason;
|
|
1201
1210
|
return { score: rrf.score, data };
|
|
1202
1211
|
});
|
|
1203
|
-
// Filter
|
|
1212
|
+
// Filter non-code files (JSON, MD, YAML). Test files are included by default.
|
|
1213
|
+
// Agents can pass exclude_tests: true to filter test files when not needed.
|
|
1204
1214
|
merged = merged.filter(item => {
|
|
1205
1215
|
const fp = String(item.data.filePath ?? '').toLowerCase();
|
|
1206
|
-
if (isTestFilePath(fp))
|
|
1207
|
-
return false;
|
|
1208
1216
|
if (fp.endsWith('.json') || fp.endsWith('.md') || fp.endsWith('.yml') || fp.endsWith('.yaml'))
|
|
1209
1217
|
return false;
|
|
1210
|
-
if (
|
|
1218
|
+
if (params.exclude_tests && isTestFilePath(fp))
|
|
1211
1219
|
return false;
|
|
1212
1220
|
return true;
|
|
1213
1221
|
});
|
|
@@ -1591,8 +1599,8 @@ export class LocalBackend {
|
|
|
1591
1599
|
return [];
|
|
1592
1600
|
}
|
|
1593
1601
|
const { DEFAULT_MAX_SEMANTIC_DISTANCE } = await import('../../core/search/types.js');
|
|
1594
|
-
const {
|
|
1595
|
-
const queryVec = await
|
|
1602
|
+
const { nlEmbed } = await import('../../core/embeddings/nl-embedder.js');
|
|
1603
|
+
const queryVec = await nlEmbed(query);
|
|
1596
1604
|
// In-memory cosine search — no disk I/O
|
|
1597
1605
|
const vecResults = this.searchEmbeddingsInMemory(repo.id, queryVec, limit, DEFAULT_MAX_SEMANTIC_DISTANCE);
|
|
1598
1606
|
if (vecResults.length === 0)
|
|
@@ -1626,9 +1634,14 @@ export class LocalBackend {
|
|
|
1626
1634
|
*/
|
|
1627
1635
|
async nlSemanticSearch(repo, query, limit) {
|
|
1628
1636
|
try {
|
|
1629
|
-
|
|
1630
|
-
if (!cache || cache.nodeIds.length === 0)
|
|
1631
|
-
|
|
1637
|
+
let cache = this.nlEmbeddingCaches.get(repo.id);
|
|
1638
|
+
if (!cache || cache.nodeIds.length === 0) {
|
|
1639
|
+
// Try loading on demand
|
|
1640
|
+
this.loadNlEmbeddingCache(repo.id);
|
|
1641
|
+
cache = this.nlEmbeddingCaches.get(repo.id);
|
|
1642
|
+
if (!cache || cache.nodeIds.length === 0)
|
|
1643
|
+
return [];
|
|
1644
|
+
}
|
|
1632
1645
|
const { nlEmbed } = await import('../../core/embeddings/nl-embedder.js');
|
|
1633
1646
|
const queryVec = await nlEmbed(query);
|
|
1634
1647
|
const vecResults = this.searchNlEmbeddingsInMemory(repo.id, queryVec, limit, 0.5);
|
|
@@ -1748,8 +1761,8 @@ export class LocalBackend {
|
|
|
1748
1761
|
const cache = this.embeddingCaches.get(repo.id);
|
|
1749
1762
|
if (!cache || cache.nodeIds.length === 0)
|
|
1750
1763
|
return [];
|
|
1751
|
-
const {
|
|
1752
|
-
const queryVec = await
|
|
1764
|
+
const { nlEmbed } = await import('../../core/embeddings/nl-embedder.js');
|
|
1765
|
+
const queryVec = await nlEmbed(query);
|
|
1753
1766
|
const neighbors = this.searchEmbeddingsInMemory(repo.id, queryVec, 5, 0.7);
|
|
1754
1767
|
// Extract symbol names from nodeIds (format: "Label:filePath:name")
|
|
1755
1768
|
return neighbors.map(n => {
|
package/dist/mcp/server.js
CHANGED
|
@@ -13,8 +13,8 @@ import { getResourceDefinitions, getResourceTemplates, readResource } from './re
|
|
|
13
13
|
// the MCP tool descriptions. Hints wasted ~40 tokens per response.
|
|
14
14
|
/** Create a configured MCP Server with all handlers registered (transport-agnostic) */
|
|
15
15
|
export function createMCPServer(backend) {
|
|
16
|
-
// Preload embedding model in background so first query doesn't pay cold-start cost
|
|
17
|
-
import('../core/embeddings/embedder.js').then(m => m.
|
|
16
|
+
// Preload bge-small embedding model in background so first query doesn't pay cold-start cost
|
|
17
|
+
import('../core/embeddings/nl-embedder.js').then(m => m.initNlEmbedder()).catch(() => { });
|
|
18
18
|
const require = createRequire(import.meta.url);
|
|
19
19
|
const pkgVersion = require('../../package.json').version;
|
|
20
20
|
const server = new Server({
|
package/dist/mcp/tools.js
CHANGED
|
@@ -44,6 +44,7 @@ Hybrid ranking: BM25 keyword + semantic vector search, ranked by Reciprocal Rank
|
|
|
44
44
|
limit: { type: 'number', description: 'Max processes to return (default: 5)', default: 5 },
|
|
45
45
|
max_symbols: { type: 'number', description: 'Max symbols per process (default: 10)', default: 10 },
|
|
46
46
|
include_content: { type: 'boolean', description: 'Include full symbol source code (default: false)', default: false },
|
|
47
|
+
exclude_tests: { type: 'boolean', description: 'Exclude test/spec/fixture files from results (default: false)', default: false },
|
|
47
48
|
repo: { type: 'string', description: 'Repository name or path. Omit if only one repo is indexed.' },
|
|
48
49
|
},
|
|
49
50
|
required: ['query'],
|
package/package.json
CHANGED
|
@@ -1,6 +1,6 @@
|
|
|
1
1
|
{
|
|
2
2
|
"name": "@zuvia-software-solutions/code-mapper",
|
|
3
|
-
"version": "2.
|
|
3
|
+
"version": "2.5.1",
|
|
4
4
|
"description": "Graph-powered code intelligence for AI agents. Index any codebase, query via MCP or CLI.",
|
|
5
5
|
"author": "Abhigyan Patwari",
|
|
6
6
|
"license": "PolyForm-Noncommercial-1.0.0",
|
|
@@ -34,10 +34,7 @@
|
|
|
34
34
|
"hooks",
|
|
35
35
|
"scripts",
|
|
36
36
|
"skills",
|
|
37
|
-
"vendor"
|
|
38
|
-
"models/mlx-embedder.py",
|
|
39
|
-
"models/jina-code-0.5b-mlx/model.py",
|
|
40
|
-
"models/jina-code-0.5b-mlx/config.json"
|
|
37
|
+
"vendor"
|
|
41
38
|
],
|
|
42
39
|
"scripts": {
|
|
43
40
|
"build": "tsc",
|
|
@@ -1,73 +0,0 @@
|
|
|
1
|
-
{
|
|
2
|
-
"architectures": [
|
|
3
|
-
"Qwen2ForCausalLM"
|
|
4
|
-
],
|
|
5
|
-
"attention_dropout": 0.0,
|
|
6
|
-
"bos_token_id": 151643,
|
|
7
|
-
"eos_token_id": 151643,
|
|
8
|
-
"hidden_act": "silu",
|
|
9
|
-
"hidden_size": 896,
|
|
10
|
-
"initializer_range": 0.02,
|
|
11
|
-
"intermediate_size": 4864,
|
|
12
|
-
"layer_types": [
|
|
13
|
-
"full_attention",
|
|
14
|
-
"full_attention",
|
|
15
|
-
"full_attention",
|
|
16
|
-
"full_attention",
|
|
17
|
-
"full_attention",
|
|
18
|
-
"full_attention",
|
|
19
|
-
"full_attention",
|
|
20
|
-
"full_attention",
|
|
21
|
-
"full_attention",
|
|
22
|
-
"full_attention",
|
|
23
|
-
"full_attention",
|
|
24
|
-
"full_attention",
|
|
25
|
-
"full_attention",
|
|
26
|
-
"full_attention",
|
|
27
|
-
"full_attention",
|
|
28
|
-
"full_attention",
|
|
29
|
-
"full_attention",
|
|
30
|
-
"full_attention",
|
|
31
|
-
"full_attention",
|
|
32
|
-
"full_attention",
|
|
33
|
-
"full_attention",
|
|
34
|
-
"full_attention",
|
|
35
|
-
"full_attention",
|
|
36
|
-
"full_attention"
|
|
37
|
-
],
|
|
38
|
-
"matryoshka_dims": [
|
|
39
|
-
64,
|
|
40
|
-
128,
|
|
41
|
-
256,
|
|
42
|
-
512,
|
|
43
|
-
896
|
|
44
|
-
],
|
|
45
|
-
"max_position_embeddings": 32768,
|
|
46
|
-
"max_window_layers": 24,
|
|
47
|
-
"model_type": "qwen2",
|
|
48
|
-
"num_attention_heads": 14,
|
|
49
|
-
"num_hidden_layers": 24,
|
|
50
|
-
"num_key_value_heads": 2,
|
|
51
|
-
"prompt_names": [
|
|
52
|
-
"query",
|
|
53
|
-
"passage"
|
|
54
|
-
],
|
|
55
|
-
"rms_norm_eps": 1e-06,
|
|
56
|
-
"rope_scaling": null,
|
|
57
|
-
"rope_theta": 1000000.0,
|
|
58
|
-
"sliding_window": null,
|
|
59
|
-
"task_names": [
|
|
60
|
-
"nl2code",
|
|
61
|
-
"qa",
|
|
62
|
-
"code2code",
|
|
63
|
-
"code2nl",
|
|
64
|
-
"code2completion"
|
|
65
|
-
],
|
|
66
|
-
"tie_word_embeddings": true,
|
|
67
|
-
"tokenizer_class": "Qwen2TokenizerFast",
|
|
68
|
-
"torch_dtype": "bfloat16",
|
|
69
|
-
"transformers_version": "4.53.0",
|
|
70
|
-
"use_cache": true,
|
|
71
|
-
"use_sliding_window": false,
|
|
72
|
-
"vocab_size": 151936
|
|
73
|
-
}
|
|
@@ -1,127 +0,0 @@
|
|
|
1
|
-
|
|
2
|
-
from dataclasses import dataclass
|
|
3
|
-
from typing import Optional, List
|
|
4
|
-
import mlx.core as mx
|
|
5
|
-
import mlx.nn as nn
|
|
6
|
-
|
|
7
|
-
@dataclass
|
|
8
|
-
class ModelArgs:
|
|
9
|
-
hidden_size: int
|
|
10
|
-
num_hidden_layers: int
|
|
11
|
-
intermediate_size: int
|
|
12
|
-
num_attention_heads: int
|
|
13
|
-
rms_norm_eps: float
|
|
14
|
-
vocab_size: int
|
|
15
|
-
num_key_value_heads: int
|
|
16
|
-
max_position_embeddings: int
|
|
17
|
-
rope_theta: float = 1000000.0
|
|
18
|
-
tie_word_embeddings: bool = True
|
|
19
|
-
|
|
20
|
-
class Attention(nn.Module):
|
|
21
|
-
def __init__(self, args):
|
|
22
|
-
super().__init__()
|
|
23
|
-
dim = args.hidden_size
|
|
24
|
-
self.n_heads = args.num_attention_heads
|
|
25
|
-
self.n_kv_heads = args.num_key_value_heads
|
|
26
|
-
self.head_dim = dim // self.n_heads
|
|
27
|
-
self.scale = self.head_dim ** -0.5
|
|
28
|
-
self.rope_theta = args.rope_theta
|
|
29
|
-
self.q_proj = nn.Linear(dim, self.n_heads * self.head_dim, bias=True)
|
|
30
|
-
self.k_proj = nn.Linear(dim, self.n_kv_heads * self.head_dim, bias=True)
|
|
31
|
-
self.v_proj = nn.Linear(dim, self.n_kv_heads * self.head_dim, bias=True)
|
|
32
|
-
self.o_proj = nn.Linear(self.n_heads * self.head_dim, dim, bias=False)
|
|
33
|
-
|
|
34
|
-
def __call__(self, x, mask=None):
|
|
35
|
-
B, L, D = x.shape
|
|
36
|
-
q = self.q_proj(x).reshape(B, L, self.n_heads, self.head_dim).transpose(0, 2, 1, 3)
|
|
37
|
-
k = self.k_proj(x).reshape(B, L, self.n_kv_heads, self.head_dim).transpose(0, 2, 1, 3)
|
|
38
|
-
v = self.v_proj(x).reshape(B, L, self.n_kv_heads, self.head_dim).transpose(0, 2, 1, 3)
|
|
39
|
-
q = mx.fast.rope(q, self.head_dim, traditional=False, base=self.rope_theta, scale=1.0, offset=0)
|
|
40
|
-
k = mx.fast.rope(k, self.head_dim, traditional=False, base=self.rope_theta, scale=1.0, offset=0)
|
|
41
|
-
out = mx.fast.scaled_dot_product_attention(q, k, v, mask=mask.astype(q.dtype) if mask is not None else None, scale=self.scale)
|
|
42
|
-
return self.o_proj(out.transpose(0, 2, 1, 3).reshape(B, L, -1))
|
|
43
|
-
|
|
44
|
-
class MLP(nn.Module):
|
|
45
|
-
def __init__(self, dim, hidden):
|
|
46
|
-
super().__init__()
|
|
47
|
-
self.gate_proj = nn.Linear(dim, hidden, bias=False)
|
|
48
|
-
self.down_proj = nn.Linear(hidden, dim, bias=False)
|
|
49
|
-
self.up_proj = nn.Linear(dim, hidden, bias=False)
|
|
50
|
-
def __call__(self, x):
|
|
51
|
-
return self.down_proj(nn.silu(self.gate_proj(x)) * self.up_proj(x))
|
|
52
|
-
|
|
53
|
-
class TransformerBlock(nn.Module):
|
|
54
|
-
def __init__(self, args):
|
|
55
|
-
super().__init__()
|
|
56
|
-
self.self_attn = Attention(args)
|
|
57
|
-
self.mlp = MLP(args.hidden_size, args.intermediate_size)
|
|
58
|
-
self.input_layernorm = nn.RMSNorm(args.hidden_size, eps=args.rms_norm_eps)
|
|
59
|
-
self.post_attention_layernorm = nn.RMSNorm(args.hidden_size, eps=args.rms_norm_eps)
|
|
60
|
-
def __call__(self, x, mask=None):
|
|
61
|
-
h = x + self.self_attn(self.input_layernorm(x), mask)
|
|
62
|
-
return h + self.mlp(self.post_attention_layernorm(h))
|
|
63
|
-
|
|
64
|
-
class Qwen2Model(nn.Module):
|
|
65
|
-
def __init__(self, args):
|
|
66
|
-
super().__init__()
|
|
67
|
-
self.embed_tokens = nn.Embedding(args.vocab_size, args.hidden_size)
|
|
68
|
-
self.layers = [TransformerBlock(args) for _ in range(args.num_hidden_layers)]
|
|
69
|
-
self.norm = nn.RMSNorm(args.hidden_size, eps=args.rms_norm_eps)
|
|
70
|
-
def __call__(self, inputs, mask=None):
|
|
71
|
-
h = self.embed_tokens(inputs)
|
|
72
|
-
for layer in self.layers:
|
|
73
|
-
h = layer(h, mask)
|
|
74
|
-
return self.norm(h)
|
|
75
|
-
|
|
76
|
-
class JinaCodeEmbeddingModel(nn.Module):
|
|
77
|
-
def __init__(self, config):
|
|
78
|
-
super().__init__()
|
|
79
|
-
args = ModelArgs(
|
|
80
|
-
hidden_size=config["hidden_size"],
|
|
81
|
-
num_hidden_layers=config["num_hidden_layers"],
|
|
82
|
-
intermediate_size=config["intermediate_size"],
|
|
83
|
-
num_attention_heads=config["num_attention_heads"],
|
|
84
|
-
rms_norm_eps=config["rms_norm_eps"],
|
|
85
|
-
vocab_size=config["vocab_size"],
|
|
86
|
-
num_key_value_heads=config["num_key_value_heads"],
|
|
87
|
-
max_position_embeddings=config["max_position_embeddings"],
|
|
88
|
-
rope_theta=config.get("rope_theta", 1000000.0),
|
|
89
|
-
)
|
|
90
|
-
self.model = Qwen2Model(args)
|
|
91
|
-
self.config = config
|
|
92
|
-
|
|
93
|
-
def __call__(self, input_ids, attention_mask=None):
|
|
94
|
-
B, L = input_ids.shape
|
|
95
|
-
causal = mx.tril(mx.ones((L, L)))
|
|
96
|
-
causal = mx.where(causal == 0, -1e4, 0.0)[None, None, :, :]
|
|
97
|
-
if attention_mask is not None:
|
|
98
|
-
pad = mx.where(attention_mask == 0, -1e4, 0.0)[:, None, None, :]
|
|
99
|
-
mask = causal + pad
|
|
100
|
-
else:
|
|
101
|
-
mask = causal
|
|
102
|
-
h = self.model(input_ids, mask)
|
|
103
|
-
if attention_mask is not None:
|
|
104
|
-
seq_lens = mx.sum(attention_mask.astype(mx.int32), axis=1) - 1
|
|
105
|
-
embs = h[mx.arange(B), seq_lens]
|
|
106
|
-
else:
|
|
107
|
-
embs = h[:, -1, :]
|
|
108
|
-
norms = mx.linalg.norm(embs, axis=1, keepdims=True)
|
|
109
|
-
return embs / norms
|
|
110
|
-
|
|
111
|
-
def encode(self, texts, tokenizer, max_length=8192, truncate_dim=None, task="nl2code", prompt_type="query"):
|
|
112
|
-
PREFIXES = {"nl2code": {"query": "Find the most relevant code snippet given the following query:\n", "passage": "Candidate code snippet:\n"}}
|
|
113
|
-
prefix = PREFIXES.get(task, {}).get(prompt_type, "")
|
|
114
|
-
if prefix:
|
|
115
|
-
texts = [prefix + t for t in texts]
|
|
116
|
-
encodings = tokenizer.encode_batch(texts)
|
|
117
|
-
ml = min(max_length, max(len(e.ids) for e in encodings))
|
|
118
|
-
iids, amask = [], []
|
|
119
|
-
for e in encodings:
|
|
120
|
-
ids = e.ids[:ml]; m = e.attention_mask[:ml]; p = ml - len(ids)
|
|
121
|
-
if p > 0: ids = ids + [0]*p; m = m + [0]*p
|
|
122
|
-
iids.append(ids); amask.append(m)
|
|
123
|
-
embs = self(mx.array(iids), mx.array(amask))
|
|
124
|
-
if truncate_dim:
|
|
125
|
-
embs = embs[:, :truncate_dim]
|
|
126
|
-
embs = embs / mx.linalg.norm(embs, axis=1, keepdims=True)
|
|
127
|
-
return embs
|