tricoder 1.2.8__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- tricoder/__about__.py +6 -0
- tricoder/__init__.py +19 -0
- tricoder/calibration.py +276 -0
- tricoder/cli.py +890 -0
- tricoder/context_view.py +228 -0
- tricoder/data_loader.py +144 -0
- tricoder/extract.py +622 -0
- tricoder/fusion.py +203 -0
- tricoder/git_tracker.py +203 -0
- tricoder/gpu_utils.py +414 -0
- tricoder/graph_view.py +583 -0
- tricoder/model.py +476 -0
- tricoder/optimize.py +263 -0
- tricoder/subtoken_utils.py +196 -0
- tricoder/train.py +857 -0
- tricoder/typed_view.py +313 -0
- tricoder-1.2.8.dist-info/METADATA +306 -0
- tricoder-1.2.8.dist-info/RECORD +22 -0
- tricoder-1.2.8.dist-info/WHEEL +4 -0
- tricoder-1.2.8.dist-info/entry_points.txt +3 -0
- tricoder-1.2.8.dist-info/licenses/LICENSE +56 -0
- tricoder-1.2.8.dist-info/licenses/LICENSE_COMMERCIAL.md +68 -0
tricoder/train.py
ADDED
|
@@ -0,0 +1,857 @@
|
|
|
1
|
+
"""Training pipeline for TriVector Code Intelligence."""
|
|
2
|
+
import json
|
|
3
|
+
import os
|
|
4
|
+
import time
|
|
5
|
+
from typing import Optional
|
|
6
|
+
|
|
7
|
+
import click
|
|
8
|
+
import numpy as np
|
|
9
|
+
from annoy import AnnoyIndex
|
|
10
|
+
from rich import box
|
|
11
|
+
from rich.console import Console
|
|
12
|
+
from rich.progress import Progress, SpinnerColumn, TextColumn, BarColumn, TimeElapsedColumn
|
|
13
|
+
from rich.table import Table
|
|
14
|
+
|
|
15
|
+
from .calibration import split_edges, learn_temperature
|
|
16
|
+
from .context_view import compute_context_view
|
|
17
|
+
from .data_loader import load_nodes, load_edges, load_types
|
|
18
|
+
from .fusion import fuse_embeddings, iterative_embedding_smoothing
|
|
19
|
+
from .graph_view import compute_graph_view
|
|
20
|
+
from .typed_view import compute_typed_view
|
|
21
|
+
|
|
22
|
+
# Shared rich console: train_model prints all status output through it and
# passes it to its rich Progress bars so progress and log lines interleave cleanly.
console = Console()
|
|
23
|
+
|
|
24
|
+
|
|
25
|
+
def estimate_training_time(num_nodes: int, num_edges: int, num_types: Optional[int],
|
|
26
|
+
graph_dim: int, context_dim: int, typed_dim: int,
|
|
27
|
+
num_walks: int, walk_length: int, final_dim: int,
|
|
28
|
+
n_jobs: int = 1, use_gpu: bool = False) -> str:
|
|
29
|
+
"""
|
|
30
|
+
Estimate training time based on data size and parameters.
|
|
31
|
+
|
|
32
|
+
Args:
|
|
33
|
+
use_gpu: Whether GPU acceleration will be used (affects SVD/PCA/matrix operation times)
|
|
34
|
+
|
|
35
|
+
Returns:
|
|
36
|
+
Estimated time as formatted string
|
|
37
|
+
"""
|
|
38
|
+
# Check if GPU is actually available
|
|
39
|
+
gpu_available = False
|
|
40
|
+
gpu_speedup_factor = 1.0
|
|
41
|
+
if use_gpu:
|
|
42
|
+
try:
|
|
43
|
+
from .gpu_utils import GPUAccelerator
|
|
44
|
+
gpu_accelerator = GPUAccelerator(use_gpu=True)
|
|
45
|
+
if gpu_accelerator.use_gpu:
|
|
46
|
+
gpu_available = True
|
|
47
|
+
# GPU speedup factors based on typical performance:
|
|
48
|
+
# - SVD/PCA: 5-20x faster (use conservative 8x)
|
|
49
|
+
# - Matrix operations: 10-50x faster (use conservative 15x)
|
|
50
|
+
# - Sparse operations: 3-10x faster (use conservative 5x)
|
|
51
|
+
# Use different factors for different operations
|
|
52
|
+
gpu_speedup_factor = 1.0 # Will be applied per-operation
|
|
53
|
+
except Exception:
|
|
54
|
+
pass # GPU not available, use CPU estimates
|
|
55
|
+
|
|
56
|
+
# More realistic time estimates based on actual performance
|
|
57
|
+
# Accounts for parallelization efficiency and overhead
|
|
58
|
+
|
|
59
|
+
# Parallelization efficiency factor (diminishing returns with more cores)
|
|
60
|
+
# More cores help but not linearly due to overhead
|
|
61
|
+
efficiency = min(0.85, 0.3 + 0.55 * (n_jobs / max(n_jobs, 8))) # Cap efficiency at ~85%
|
|
62
|
+
effective_cores = n_jobs * efficiency
|
|
63
|
+
|
|
64
|
+
# Graph view: PPMI + SVD
|
|
65
|
+
# PPMI computation: sparse matrix operations, scales with edges
|
|
66
|
+
# SVD: matrix decomposition, scales with nodes and dimensions
|
|
67
|
+
# Account for subtoken expansion (roughly doubles nodes)
|
|
68
|
+
expanded_nodes = num_nodes * 2.5 # Account for subtokens and expansion
|
|
69
|
+
ppmi_time = (expanded_nodes * num_edges ** 0.5) / (5000 * effective_cores)
|
|
70
|
+
# SVD benefits significantly from GPU (8x speedup)
|
|
71
|
+
svd_base_time = (expanded_nodes * graph_dim ** 2) / (20000 * effective_cores)
|
|
72
|
+
svd_time = svd_base_time / (8.0 if gpu_available else 1.0)
|
|
73
|
+
graph_time = ppmi_time + svd_time
|
|
74
|
+
|
|
75
|
+
# Context view: Random walks + Word2Vec
|
|
76
|
+
# Random walks: parallelized per node, but has overhead
|
|
77
|
+
total_walks = expanded_nodes * num_walks
|
|
78
|
+
walk_time = (total_walks * walk_length) / (8000 * effective_cores) # More realistic
|
|
79
|
+
|
|
80
|
+
# Word2Vec: training is CPU-intensive, benefits from parallelization
|
|
81
|
+
# window=10, epochs=5, negative=5
|
|
82
|
+
# Word2Vec doesn't benefit much from GPU (mostly CPU-bound)
|
|
83
|
+
w2v_time = (total_walks * walk_length * 10 * 5) / (150000 * effective_cores)
|
|
84
|
+
context_time = walk_time + w2v_time
|
|
85
|
+
|
|
86
|
+
# Typed view: PPMI + SVD (if available)
|
|
87
|
+
typed_time = 0
|
|
88
|
+
if num_types:
|
|
89
|
+
# Account for type expansion
|
|
90
|
+
expanded_types = num_types * 1.3 # Type expansion adds ~30% more types
|
|
91
|
+
typed_ppmi_time = (expanded_nodes * expanded_types ** 0.5) / (5000 * effective_cores)
|
|
92
|
+
# SVD benefits significantly from GPU (8x speedup)
|
|
93
|
+
typed_svd_base_time = (expanded_nodes * typed_dim ** 2) / (20000 * effective_cores)
|
|
94
|
+
typed_svd_time = typed_svd_base_time / (8.0 if gpu_available else 1.0)
|
|
95
|
+
typed_time = typed_ppmi_time + typed_svd_time
|
|
96
|
+
|
|
97
|
+
# Fusion: PCA
|
|
98
|
+
# PCA is memory-bound and benefits significantly from GPU (8x speedup)
|
|
99
|
+
total_input_dim = graph_dim + context_dim + (typed_dim if num_types else 0)
|
|
100
|
+
fusion_base_time = (expanded_nodes * total_input_dim * final_dim) / (30000 * effective_cores)
|
|
101
|
+
fusion_time = fusion_base_time / (8.0 if gpu_available else 1.0)
|
|
102
|
+
|
|
103
|
+
# Embedding smoothing: iterative neighbor averaging
|
|
104
|
+
# Sparse matrix operations benefit from GPU (5x speedup)
|
|
105
|
+
smoothing_base_time = (expanded_nodes * num_edges * 2) / (10000 * effective_cores) # 2 iterations
|
|
106
|
+
smoothing_time = smoothing_base_time / (5.0 if gpu_available else 1.0)
|
|
107
|
+
|
|
108
|
+
# Temperature calibration: grid search with parallel evaluation
|
|
109
|
+
tau_candidates = 50 # default
|
|
110
|
+
val_edges_est = int(num_edges * 0.2) # ~20% for validation
|
|
111
|
+
calibration_time = (tau_candidates * val_edges_est) / (5000 * effective_cores)
|
|
112
|
+
|
|
113
|
+
# ANN index building (single-threaded mostly, no GPU benefit)
|
|
114
|
+
ann_time = (num_nodes * final_dim * 10) / 200000 # 10 trees, less parallelizable
|
|
115
|
+
|
|
116
|
+
# I/O overhead (saving files)
|
|
117
|
+
io_time = 1.5
|
|
118
|
+
|
|
119
|
+
# Data loading overhead
|
|
120
|
+
load_time = 0.5
|
|
121
|
+
|
|
122
|
+
total_seconds = (graph_time + context_time + typed_time + fusion_time +
|
|
123
|
+
smoothing_time + calibration_time + ann_time + io_time + load_time)
|
|
124
|
+
|
|
125
|
+
# Format time estimate
|
|
126
|
+
if total_seconds < 60:
|
|
127
|
+
return f"{total_seconds:.1f}s"
|
|
128
|
+
elif total_seconds < 3600:
|
|
129
|
+
minutes = total_seconds / 60
|
|
130
|
+
return f"{minutes:.1f}m"
|
|
131
|
+
else:
|
|
132
|
+
hours = total_seconds / 3600
|
|
133
|
+
minutes = (total_seconds % 3600) / 60
|
|
134
|
+
return f"{int(hours)}h {int(minutes)}m"
|
|
135
|
+
|
|
136
|
+
|
|
137
|
+
def compute_default_dimensions(num_nodes: int, num_edges: int, num_types: Optional[int] = None) -> dict:
    """
    Compute default dimensions based on data characteristics.

    Args:
        num_nodes: number of original nodes.
        num_edges: number of edges (currently unused; kept for API compatibility).
        num_types: number of type tokens, or None/0 when there is no typed view.

    Returns:
        Dictionary with graph_dim, context_dim, typed_dim, final_dim
    """
    # Graph dimension: log2(num_nodes) times a multiplier that shrinks as the
    # graph grows, clamped per size band:
    #   < 50 nodes     -> 32
    #   50-199 nodes   -> log2(n) * 8 clamped to [32, 64]   (~45-61)
    #   200-999 nodes  -> log2(n) * 7 clamped to [48, 96]   (~53-69)
    #   >= 1000 nodes  -> log2(n) * 6 clamped to [64, 128]
    if num_nodes < 50:
        graph_dim = 32
    elif num_nodes < 200:
        graph_dim = max(32, min(64, int(np.log2(num_nodes) * 8)))
    elif num_nodes < 1000:
        graph_dim = max(48, min(96, int(np.log2(num_nodes) * 7)))
    else:
        graph_dim = max(64, min(128, int(np.log2(num_nodes) * 6)))

    # Context dimension: match graph dimension for balanced fusion
    context_dim = graph_dim

    # Typed dimension: sqrt(num_types) * 2 clamped per band. For 20-49 types
    # the formula never exceeds 16, so that band effectively yields 16.
    if num_types:
        if num_types < 20:
            typed_dim = 16
        elif num_types < 50:
            typed_dim = max(16, min(32, int(np.sqrt(num_types) * 2)))
        else:
            typed_dim = max(32, min(64, int(np.sqrt(num_types) * 2)))
    else:
        typed_dim = 32  # Default when no types (not used in fusion then)

    # Final dimension: ~60% of the concatenated input width, clamped to
    # [32, 256], then capped below the node count because PCA cannot produce
    # more components than it has samples. The cap is applied last so that it
    # takes precedence over the 32 floor for very small codebases.
    total_input_dim = graph_dim + context_dim + (typed_dim if num_types else 0)
    final_dim_candidate = int(total_input_dim * 0.6)  # 60% of input dims
    final_dim = max(32, min(256, final_dim_candidate))
    final_dim = min(final_dim, max(1, num_nodes - 1))

    return {
        'graph_dim': graph_dim,
        'context_dim': context_dim,
        'typed_dim': typed_dim,
        'final_dim': final_dim
    }
|
|
185
|
+
|
|
186
|
+
|
|
187
|
+
def train_model(nodes_path: str,
|
|
188
|
+
edges_path: str,
|
|
189
|
+
types_path: Optional[str],
|
|
190
|
+
output_dir: str,
|
|
191
|
+
graph_dim: Optional[int] = None,
|
|
192
|
+
context_dim: Optional[int] = None,
|
|
193
|
+
typed_dim: Optional[int] = None,
|
|
194
|
+
final_dim: Optional[int] = None,
|
|
195
|
+
num_walks: int = 10,
|
|
196
|
+
walk_length: int = 80,
|
|
197
|
+
train_ratio: float = 0.8,
|
|
198
|
+
random_state: int = 42,
|
|
199
|
+
fast_mode: bool = False,
|
|
200
|
+
use_gpu: bool = False):
|
|
201
|
+
"""
|
|
202
|
+
Train TriVector Code Intelligence model.
|
|
203
|
+
|
|
204
|
+
Optimizations applied for faster training (without sacrificing much quality):
|
|
205
|
+
- Vectorized embedding smoothing (uses sparse matrix operations, 2-5x faster)
|
|
206
|
+
- Reduced SVD iterations (5 instead of 10, ~2x faster)
|
|
207
|
+
- Reduced Word2Vec epochs (3 instead of 5, ~1.7x faster)
|
|
208
|
+
- Reduced Word2Vec window (7 instead of 10, ~1.4x faster)
|
|
209
|
+
- Reduced Word2Vec negative samples (3 instead of 5, ~1.7x faster)
|
|
210
|
+
- Reduced temperature calibration candidates (30 instead of 50, ~1.7x faster)
|
|
211
|
+
- Fast mode: further optimizations:
|
|
212
|
+
* Halves random walk parameters
|
|
213
|
+
* Reduces smoothing iterations (1 instead of 2)
|
|
214
|
+
* Reduces context window (3 instead of 5)
|
|
215
|
+
* Reduces call graph depth (2 instead of 3)
|
|
216
|
+
* Reduces ANN trees (7 instead of 10)
|
|
217
|
+
* Uses less validation data for calibration
|
|
218
|
+
* Fewer tau candidates (20 instead of 30)
|
|
219
|
+
* Fewer negative samples (3 instead of 5)
|
|
220
|
+
|
|
221
|
+
Args:
|
|
222
|
+
nodes_path: path to nodes.jsonl
|
|
223
|
+
edges_path: path to edges.jsonl
|
|
224
|
+
types_path: path to types.jsonl (optional)
|
|
225
|
+
output_dir: output directory for model
|
|
226
|
+
graph_dim: dimensionality for graph view
|
|
227
|
+
context_dim: dimensionality for context view
|
|
228
|
+
typed_dim: dimensionality for typed view
|
|
229
|
+
final_dim: final fused embedding dimensionality
|
|
230
|
+
num_walks: number of random walks per node
|
|
231
|
+
walk_length: length of each random walk
|
|
232
|
+
train_ratio: ratio of edges for training (rest for calibration)
|
|
233
|
+
random_state: random seed
|
|
234
|
+
fast_mode: if True, reduces walk parameters for faster training (lower quality)
|
|
235
|
+
use_gpu: if True, attempt GPU acceleration (requires CuPy and CUDA-capable GPU)
|
|
236
|
+
"""
|
|
237
|
+
# Initialize GPU accelerator if requested
|
|
238
|
+
gpu_accelerator = None
|
|
239
|
+
if use_gpu:
|
|
240
|
+
from .gpu_utils import GPUAccelerator, TORCH_AVAILABLE, TORCH_VERSION, diagnose_gpu_support
|
|
241
|
+
import platform
|
|
242
|
+
|
|
243
|
+
is_mac = platform.system() == 'Darwin'
|
|
244
|
+
if is_mac and not TORCH_AVAILABLE:
|
|
245
|
+
console.print(f"[yellow]⚠ PyTorch not installed. Install with: pip install torch[/yellow]")
|
|
246
|
+
console.print(f"[yellow] GPU acceleration requires PyTorch for Mac MPS support. Using CPU.[/yellow]\n")
|
|
247
|
+
else:
|
|
248
|
+
gpu_accelerator = GPUAccelerator(use_gpu=True)
|
|
249
|
+
if gpu_accelerator.use_gpu:
|
|
250
|
+
backend_name = "CUDA (NVIDIA)" if gpu_accelerator.device_type == 'cuda' else "MPS (Mac)"
|
|
251
|
+
console.print(f"[bold green]✓ GPU acceleration enabled ({backend_name})[/bold green]\n")
|
|
252
|
+
else:
|
|
253
|
+
if is_mac:
|
|
254
|
+
# Provide detailed diagnostics
|
|
255
|
+
diagnostics = diagnose_gpu_support()
|
|
256
|
+
console.print(f"[yellow]⚠ GPU acceleration requested but MPS not available.[/yellow]")
|
|
257
|
+
if not TORCH_AVAILABLE:
|
|
258
|
+
console.print(f"[yellow] PyTorch is not installed. Install with: pip install torch[/yellow]")
|
|
259
|
+
elif TORCH_VERSION:
|
|
260
|
+
# Simple version check (compare first two parts)
|
|
261
|
+
try:
|
|
262
|
+
version_parts = TORCH_VERSION.split('.')
|
|
263
|
+
major = int(version_parts[0])
|
|
264
|
+
minor = int(version_parts[1]) if len(version_parts) > 1 else 0
|
|
265
|
+
if major < 1 or (major == 1 and minor < 12):
|
|
266
|
+
console.print(f"[yellow] PyTorch version {TORCH_VERSION} is too old. Upgrade to 1.12+ for MPS support.[/yellow]")
|
|
267
|
+
console.print(f"[yellow] Upgrade with: pip install --upgrade torch[/yellow]")
|
|
268
|
+
except (ValueError, IndexError):
|
|
269
|
+
pass # Skip version check if parsing fails
|
|
270
|
+
elif 'mps_unavailable_reason' in diagnostics:
|
|
271
|
+
console.print(f"[yellow] {diagnostics['mps_unavailable_reason']}[/yellow]")
|
|
272
|
+
else:
|
|
273
|
+
console.print(f"[yellow] Requirements: macOS 12.3+, Apple Silicon (M1/M2/M3), PyTorch 1.12+[/yellow]")
|
|
274
|
+
console.print(f"[yellow] Using CPU.[/yellow]\n")
|
|
275
|
+
else:
|
|
276
|
+
console.print(f"[yellow]⚠ GPU acceleration requested but not available, using CPU[/yellow]\n")
|
|
277
|
+
|
|
278
|
+
# Set random seeds
|
|
279
|
+
np.random.seed(random_state)
|
|
280
|
+
|
|
281
|
+
# Record start time
|
|
282
|
+
start_time = time.time()
|
|
283
|
+
|
|
284
|
+
console.print("\n[bold cyan]TriVector Code Intelligence - Training Pipeline[/bold cyan]\n")
|
|
285
|
+
|
|
286
|
+
with Progress(
|
|
287
|
+
SpinnerColumn(),
|
|
288
|
+
TextColumn("[progress.description]{task.description}"),
|
|
289
|
+
BarColumn(),
|
|
290
|
+
TimeElapsedColumn(),
|
|
291
|
+
console=console
|
|
292
|
+
) as progress:
|
|
293
|
+
# Load data first to compute dimensions
|
|
294
|
+
task1 = progress.add_task("[cyan]Loading nodes...", total=None)
|
|
295
|
+
node_to_idx, node_metadata, node_subtokens, node_file_info = load_nodes(nodes_path)
|
|
296
|
+
num_nodes = len(node_to_idx)
|
|
297
|
+
progress.update(task1, completed=True)
|
|
298
|
+
console.print(f" [dim]✓ Loaded {num_nodes:,} nodes[/dim]")
|
|
299
|
+
|
|
300
|
+
if num_nodes == 0:
|
|
301
|
+
raise ValueError("No nodes found in input file")
|
|
302
|
+
|
|
303
|
+
task1b = progress.add_task("[cyan]Loading edges...", total=None)
|
|
304
|
+
edges, _ = load_edges(edges_path, node_to_idx)
|
|
305
|
+
num_edges = len(edges)
|
|
306
|
+
progress.update(task1b, completed=True)
|
|
307
|
+
console.print(f" [dim]✓ Loaded {num_edges:,} edges[/dim]")
|
|
308
|
+
|
|
309
|
+
if num_edges == 0:
|
|
310
|
+
raise ValueError("No edges found in input file")
|
|
311
|
+
|
|
312
|
+
# Load types if available
|
|
313
|
+
node_types = None
|
|
314
|
+
type_to_idx = None
|
|
315
|
+
num_types = None
|
|
316
|
+
if types_path and os.path.exists(types_path):
|
|
317
|
+
task2 = progress.add_task("[cyan]Loading types...", total=None)
|
|
318
|
+
node_types, type_to_idx = load_types(types_path, node_to_idx)
|
|
319
|
+
num_types = len(type_to_idx)
|
|
320
|
+
progress.update(task2, completed=True)
|
|
321
|
+
console.print(f" [dim]✓ Loaded {num_types:,} type tokens[/dim]")
|
|
322
|
+
else:
|
|
323
|
+
console.print(f" [dim]⊘ No types file provided[/dim]")
|
|
324
|
+
|
|
325
|
+
# Compute default dimensions if not provided
|
|
326
|
+
defaults = compute_default_dimensions(num_nodes, len(edges), num_types)
|
|
327
|
+
|
|
328
|
+
if graph_dim is None:
|
|
329
|
+
graph_dim = defaults['graph_dim']
|
|
330
|
+
graph_dim_source = "[dim](computed)[/dim]"
|
|
331
|
+
else:
|
|
332
|
+
graph_dim_source = ""
|
|
333
|
+
|
|
334
|
+
if context_dim is None:
|
|
335
|
+
context_dim = defaults['context_dim']
|
|
336
|
+
context_dim_source = "[dim](computed)[/dim]"
|
|
337
|
+
else:
|
|
338
|
+
context_dim_source = ""
|
|
339
|
+
|
|
340
|
+
if typed_dim is None:
|
|
341
|
+
typed_dim = defaults['typed_dim']
|
|
342
|
+
typed_dim_source = "[dim](computed)[/dim]"
|
|
343
|
+
else:
|
|
344
|
+
typed_dim_source = ""
|
|
345
|
+
|
|
346
|
+
if final_dim is None:
|
|
347
|
+
final_dim = defaults['final_dim']
|
|
348
|
+
final_dim_source = "[dim](computed)[/dim]"
|
|
349
|
+
else:
|
|
350
|
+
final_dim_source = ""
|
|
351
|
+
|
|
352
|
+
# Get number of workers for time estimation
|
|
353
|
+
from multiprocessing import cpu_count
|
|
354
|
+
n_jobs_est = max(1, cpu_count() - 1)
|
|
355
|
+
|
|
356
|
+
# Estimate training time (account for GPU if available)
|
|
357
|
+
estimated_time = estimate_training_time(
|
|
358
|
+
num_nodes, len(edges), num_types,
|
|
359
|
+
graph_dim, context_dim, typed_dim,
|
|
360
|
+
num_walks, walk_length, final_dim,
|
|
361
|
+
n_jobs_est, use_gpu=use_gpu
|
|
362
|
+
)
|
|
363
|
+
|
|
364
|
+
# Display configuration
|
|
365
|
+
config_table = Table(box=box.ROUNDED, show_header=False, title="Configuration")
|
|
366
|
+
config_table.add_column("Parameter", style="cyan", width=25)
|
|
367
|
+
config_table.add_column("Value", style="white", width=15)
|
|
368
|
+
config_table.add_column("Source", style="dim", width=12)
|
|
369
|
+
config_table.add_row("Graph Dimension", str(graph_dim), graph_dim_source)
|
|
370
|
+
config_table.add_row("Context Dimension", str(context_dim), context_dim_source)
|
|
371
|
+
config_table.add_row("Typed Dimension", str(typed_dim), typed_dim_source)
|
|
372
|
+
config_table.add_row("Final Dimension", str(final_dim), final_dim_source)
|
|
373
|
+
config_table.add_row("Random Walks", str(num_walks), "")
|
|
374
|
+
config_table.add_row("Walk Length", str(walk_length), "")
|
|
375
|
+
config_table.add_row("Train Ratio", f"{train_ratio:.2f}", "")
|
|
376
|
+
config_table.add_row("Random State", str(random_state), "")
|
|
377
|
+
config_table.add_row("", "", "") # Separator
|
|
378
|
+
config_table.add_row("Data: Nodes", str(num_nodes), "")
|
|
379
|
+
config_table.add_row("Data: Edges", str(len(edges)), "")
|
|
380
|
+
if num_types:
|
|
381
|
+
config_table.add_row("Data: Type Tokens", str(num_types), "")
|
|
382
|
+
config_table.add_row("", "", "") # Separator
|
|
383
|
+
gpu_status = "GPU" if (use_gpu and gpu_accelerator and gpu_accelerator.use_gpu) else "CPU"
|
|
384
|
+
config_table.add_row("[bold]Estimated Time[/bold]", f"[bold green]{estimated_time}[/bold green]", f"[dim]({gpu_status})[/dim]")
|
|
385
|
+
config_table.add_row("Workers", str(n_jobs_est), "")
|
|
386
|
+
console.print(config_table)
|
|
387
|
+
console.print()
|
|
388
|
+
|
|
389
|
+
with Progress(
|
|
390
|
+
SpinnerColumn(),
|
|
391
|
+
TextColumn("[progress.description]{task.description}"),
|
|
392
|
+
BarColumn(),
|
|
393
|
+
TimeElapsedColumn(),
|
|
394
|
+
console=console
|
|
395
|
+
) as progress:
|
|
396
|
+
|
|
397
|
+
# Split edges for calibration
|
|
398
|
+
# In fast mode, use less validation data for faster calibration
|
|
399
|
+
calibration_train_ratio = train_ratio if not fast_mode else min(0.9, train_ratio + 0.05)
|
|
400
|
+
task3 = progress.add_task("[cyan]Splitting edges for training/validation...", total=None)
|
|
401
|
+
train_edges, val_edges = split_edges(edges, calibration_train_ratio, random_state)
|
|
402
|
+
progress.update(task3, completed=True)
|
|
403
|
+
console.print(f" [dim]✓ Split edges: {len(train_edges):,} training, {len(val_edges):,} validation "
|
|
404
|
+
f"({calibration_train_ratio:.1%}/{1-calibration_train_ratio:.1%})[/dim]")
|
|
405
|
+
|
|
406
|
+
# Get number of workers (all cores - 1, but use all cores on Windows for better performance)
|
|
407
|
+
from multiprocessing import cpu_count
|
|
408
|
+
import platform
|
|
409
|
+
if platform.system() == 'Windows':
|
|
410
|
+
# On Windows, use all cores since spawn method handles it well
|
|
411
|
+
n_jobs = cpu_count()
|
|
412
|
+
else:
|
|
413
|
+
n_jobs = max(1, cpu_count() - 1)
|
|
414
|
+
|
|
415
|
+
# Apply fast mode optimizations
|
|
416
|
+
smoothing_iterations = 2
|
|
417
|
+
context_window_size = 5
|
|
418
|
+
call_graph_depth = 3
|
|
419
|
+
ann_trees = 10
|
|
420
|
+
|
|
421
|
+
if fast_mode:
|
|
422
|
+
# Reduce random walk parameters for faster training
|
|
423
|
+
num_walks = max(5, num_walks // 2) # Reduce walks by half (min 5)
|
|
424
|
+
walk_length = max(40, walk_length // 2) # Reduce walk length by half (min 40)
|
|
425
|
+
smoothing_iterations = 1 # Reduce smoothing iterations
|
|
426
|
+
context_window_size = 3 # Reduce context window
|
|
427
|
+
call_graph_depth = 2 # Reduce call graph expansion depth
|
|
428
|
+
ann_trees = 7 # Reduce ANN trees (slightly faster indexing)
|
|
429
|
+
console.print(f"[yellow]Fast mode: Using {num_walks} walks of length {walk_length}, "
|
|
430
|
+
f"{smoothing_iterations} smoothing iteration(s), "
|
|
431
|
+
f"context window {context_window_size}, call depth {call_graph_depth}[/yellow]\n")
|
|
432
|
+
|
|
433
|
+
# Create idx_to_node mapping
|
|
434
|
+
idx_to_node = {idx: node_id for node_id, idx in node_to_idx.items()}
|
|
435
|
+
|
|
436
|
+
# Step 1: Compute graph view with all enhancements
|
|
437
|
+
# This task runs completely and sequentially before moving to the next one
|
|
438
|
+
# This ensures full parallelization can be used for graph view operations
|
|
439
|
+
console.print(f"\n[bold cyan]Step 1/5: Graph View[/bold cyan]")
|
|
440
|
+
console.print(f" [dim]Computing graph embeddings (dim={graph_dim}) with {n_jobs} workers...[/dim]")
|
|
441
|
+
|
|
442
|
+
task4a = progress.add_task("[cyan] → Expanding call graph...", total=None)
|
|
443
|
+
task4 = progress.add_task(
|
|
444
|
+
f"[cyan] → Building adjacency matrix & computing PPMI...", total=None)
|
|
445
|
+
X_graph, svd_components_graph, final_num_nodes, subtoken_to_idx, expanded_edges = compute_graph_view(
|
|
446
|
+
train_edges, num_nodes, graph_dim, random_state, n_jobs=n_jobs,
|
|
447
|
+
node_to_idx=node_to_idx,
|
|
448
|
+
node_subtokens=node_subtokens,
|
|
449
|
+
node_file_info=node_file_info,
|
|
450
|
+
node_metadata=node_metadata,
|
|
451
|
+
idx_to_node=idx_to_node,
|
|
452
|
+
expand_calls=True,
|
|
453
|
+
add_subtokens=True,
|
|
454
|
+
add_hierarchy=True,
|
|
455
|
+
add_context=True,
|
|
456
|
+
context_window=context_window_size,
|
|
457
|
+
max_depth=call_graph_depth,
|
|
458
|
+
gpu_accelerator=gpu_accelerator
|
|
459
|
+
)
|
|
460
|
+
progress.update(task4a, completed=True)
|
|
461
|
+
progress.update(task4, completed=True)
|
|
462
|
+
|
|
463
|
+
num_subtokens = len(subtoken_to_idx) if subtoken_to_idx else 0
|
|
464
|
+
expanded_edges_count = len(expanded_edges)
|
|
465
|
+
console.print(f" [dim]✓ Expanded to {final_num_nodes:,} nodes ({num_nodes:,} original + {num_subtokens:,} subtokens)[/dim]")
|
|
466
|
+
console.print(f" [dim]✓ Expanded to {expanded_edges_count:,} edges ({len(train_edges):,} original)[/dim]")
|
|
467
|
+
console.print(f" [dim]✓ Computed graph embeddings: {X_graph.shape}[/dim]")
|
|
468
|
+
# Graph view is now complete - all resources released before next task
|
|
469
|
+
|
|
470
|
+
# Step 2: Compute context view using expanded edges (includes subtokens)
|
|
471
|
+
# This task runs completely after graph view finishes
|
|
472
|
+
# This ensures full parallelization can be used for context view operations
|
|
473
|
+
console.print(f"\n[bold cyan]Step 2/5: Context View[/bold cyan]")
|
|
474
|
+
console.print(f" [dim]Computing context embeddings (dim={context_dim}) with {n_jobs} workers...[/dim]")
|
|
475
|
+
console.print(f" [dim]Parameters: {num_walks} walks/node, length={walk_length}[/dim]")
|
|
476
|
+
|
|
477
|
+
task5a = progress.add_task("[cyan] → Generating random walks...", total=None)
|
|
478
|
+
task5 = progress.add_task(
|
|
479
|
+
f"[cyan] → Training Word2Vec model...", total=None)
|
|
480
|
+
# Use expanded_edges which includes subtokens and all enhancements
|
|
481
|
+
X_w2v, word2vec_kv = compute_context_view(
|
|
482
|
+
expanded_edges, final_num_nodes, context_dim, num_walks, walk_length, random_state, n_jobs=n_jobs
|
|
483
|
+
)
|
|
484
|
+
progress.update(task5a, completed=True)
|
|
485
|
+
progress.update(task5, completed=True)
|
|
486
|
+
|
|
487
|
+
total_walks = final_num_nodes * num_walks
|
|
488
|
+
console.print(f" [dim]✓ Generated {total_walks:,} random walks[/dim]")
|
|
489
|
+
console.print(f" [dim]✓ Trained Word2Vec model (vocab size: {len(word2vec_kv.key_to_index):,})[/dim]")
|
|
490
|
+
console.print(f" [dim]✓ Computed context embeddings: {X_w2v.shape}[/dim]")
|
|
491
|
+
# Context view is now complete - all resources released before next task
|
|
492
|
+
|
|
493
|
+
# Step 3: Compute typed view if available (with type expansion)
|
|
494
|
+
# This task runs completely after context view finishes
|
|
495
|
+
X_types = None
|
|
496
|
+
svd_components_types = None
|
|
497
|
+
final_type_to_idx = None
|
|
498
|
+
if node_types is not None and type_to_idx is not None:
|
|
499
|
+
console.print(f"\n[bold cyan]Step 3/5: Typed View[/bold cyan]")
|
|
500
|
+
console.print(f" [dim]Computing typed embeddings (dim={typed_dim}) with {n_jobs} workers...[/dim]")
|
|
501
|
+
|
|
502
|
+
task6a = progress.add_task("[cyan] → Building type-token matrix...", total=None)
|
|
503
|
+
task6 = progress.add_task(
|
|
504
|
+
f"[cyan] → Computing PPMI & SVD...", total=None)
|
|
505
|
+
X_types, svd_components_types, final_type_to_idx = compute_typed_view(
|
|
506
|
+
node_types, type_to_idx, final_num_nodes, typed_dim, random_state, n_jobs=n_jobs,
|
|
507
|
+
expand_types=True, gpu_accelerator=gpu_accelerator
|
|
508
|
+
)
|
|
509
|
+
progress.update(task6a, completed=True)
|
|
510
|
+
progress.update(task6, completed=True)
|
|
511
|
+
|
|
512
|
+
final_type_count = len(final_type_to_idx) if final_type_to_idx else num_types
|
|
513
|
+
console.print(f" [dim]✓ Expanded to {final_type_count:,} type tokens ({num_types:,} original)[/dim]")
|
|
514
|
+
console.print(f" [dim]✓ Computed typed embeddings: {X_types.shape}[/dim]")
|
|
515
|
+
# Typed view is now complete - all resources released before next task
|
|
516
|
+
else:
|
|
517
|
+
console.print(f"\n[bold cyan]Step 3/5: Typed View[/bold cyan]")
|
|
518
|
+
console.print(f" [dim]⊘ Skipped (no types provided)[/dim]")
|
|
519
|
+
|
|
520
|
+
# Step 4: Fuse embeddings
|
|
521
|
+
# This task runs completely after typed view finishes (if available)
|
|
522
|
+
console.print(f"\n[bold cyan]Step 4/5: Fusion[/bold cyan]")
|
|
523
|
+
|
|
524
|
+
# Build embeddings list
|
|
525
|
+
embeddings_list = [X_graph, X_w2v]
|
|
526
|
+
if X_types is not None:
|
|
527
|
+
embeddings_list.append(X_types)
|
|
528
|
+
|
|
529
|
+
input_dims = [graph_dim, context_dim]
|
|
530
|
+
if X_types is not None:
|
|
531
|
+
input_dims.append(typed_dim)
|
|
532
|
+
total_input_dim = sum(input_dims)
|
|
533
|
+
console.print(f" [dim]Fusing {len(embeddings_list)} views (total input dim={total_input_dim}) → {final_dim}...[/dim]")
|
|
534
|
+
|
|
535
|
+
task7a = progress.add_task("[cyan] → Concatenating views & applying PCA...", total=None)
|
|
536
|
+
E, pca_components, pca_mean = fuse_embeddings(
|
|
537
|
+
embeddings_list, final_num_nodes, final_dim, random_state, n_jobs=n_jobs,
|
|
538
|
+
gpu_accelerator=gpu_accelerator
|
|
539
|
+
)
|
|
540
|
+
progress.update(task7a, completed=True)
|
|
541
|
+
|
|
542
|
+
# Store embeddings before normalization for mean_norm computation
|
|
543
|
+
E_before_norm = E.copy()
|
|
544
|
+
|
|
545
|
+
# Compute mean_norm for length penalty
|
|
546
|
+
task7b = progress.add_task("[cyan] → Normalizing embeddings...", total=None)
|
|
547
|
+
mean_norm = float(np.mean(np.linalg.norm(E_before_norm, axis=1)))
|
|
548
|
+
progress.update(task7b, completed=True)
|
|
549
|
+
|
|
550
|
+
console.print(f" [dim]✓ Fused embeddings: {E.shape} (reduced from {total_input_dim}D)[/dim]")
|
|
551
|
+
console.print(f" [dim]✓ Mean norm: {mean_norm:.4f}[/dim]")
|
|
552
|
+
|
|
553
|
+
# Apply iterative embedding smoothing (diffusion)
|
|
554
|
+
# This runs after fusion completes
|
|
555
|
+
if smoothing_iterations > 0:
|
|
556
|
+
console.print(f" [dim]Applying smoothing (iterations={smoothing_iterations}, beta=0.35)...[/dim]")
|
|
557
|
+
task7c = progress.add_task(f"[cyan] → Smoothing embeddings via diffusion...", total=None)
|
|
558
|
+
E = iterative_embedding_smoothing(
|
|
559
|
+
E, expanded_edges, final_num_nodes,
|
|
560
|
+
num_iterations=smoothing_iterations, beta=0.35, random_state=random_state,
|
|
561
|
+
gpu_accelerator=gpu_accelerator
|
|
562
|
+
)
|
|
563
|
+
progress.update(task7c, completed=True)
|
|
564
|
+
console.print(f" [dim]✓ Applied {smoothing_iterations} smoothing iteration(s)[/dim]")
|
|
565
|
+
# Smoothing is now complete - all resources released before next task
|
|
566
|
+
|
|
567
|
+
# Step 5: Learn temperature with improved negative sampling
|
|
568
|
+
# This task runs completely after smoothing finishes
|
|
569
|
+
# In fast mode, use fewer tau candidates and negatives for faster calibration
|
|
570
|
+
console.print(f"\n[bold cyan]Step 5/5: Temperature Calibration[/bold cyan]")
|
|
571
|
+
num_negatives = 3 if fast_mode else 5
|
|
572
|
+
tau_candidates_count = 20 if fast_mode else 30
|
|
573
|
+
console.print(f" [dim]Calibrating temperature with {tau_candidates_count} candidates, "
|
|
574
|
+
f"{num_negatives} negatives/edge, {len(val_edges):,} validation edges...[/dim]")
|
|
575
|
+
|
|
576
|
+
task8a = progress.add_task("[cyan] → Evaluating tau candidates...", total=None)
|
|
577
|
+
tau_candidates = np.logspace(-2, 2, num=tau_candidates_count)
|
|
578
|
+
tau = learn_temperature(
|
|
579
|
+
E, val_edges, final_num_nodes,
|
|
580
|
+
num_negatives=num_negatives,
|
|
581
|
+
tau_candidates=tau_candidates,
|
|
582
|
+
random_state=random_state, n_jobs=n_jobs,
|
|
583
|
+
node_metadata=node_metadata,
|
|
584
|
+
node_file_info=node_file_info,
|
|
585
|
+
idx_to_node=idx_to_node
|
|
586
|
+
)
|
|
587
|
+
progress.update(task8a, completed=True)
|
|
588
|
+
console.print(f" [dim]✓ Learned optimal temperature: τ = {tau:.6f}[/dim]")
|
|
589
|
+
|
|
590
|
+
# Build ANN index (only for original nodes, not subtokens)
|
|
591
|
+
console.print(f"\n[bold cyan]Building ANN Index[/bold cyan]")
|
|
592
|
+
console.print(f" [dim]Building approximate nearest neighbor index ({ann_trees} trees)...[/dim]")
|
|
593
|
+
task9a = progress.add_task("[cyan] → Adding nodes to index...", total=None)
|
|
594
|
+
ann_index = AnnoyIndex(final_dim, 'angular')
|
|
595
|
+
# Only index original nodes (not subtokens)
|
|
596
|
+
for i in range(num_nodes):
|
|
597
|
+
ann_index.add_item(i, E[i])
|
|
598
|
+
progress.update(task9a, completed=True)
|
|
599
|
+
|
|
600
|
+
task9b = progress.add_task("[cyan] → Building index trees...", total=None)
|
|
601
|
+
ann_index.build(ann_trees) # Reduced trees in fast mode
|
|
602
|
+
progress.update(task9b, completed=True)
|
|
603
|
+
console.print(f" [dim]✓ Built ANN index with {ann_trees} trees for {num_nodes:,} nodes[/dim]")
|
|
604
|
+
|
|
605
|
+
# Save model
|
|
606
|
+
console.print(f"\n[bold cyan]Saving Model[/bold cyan]")
|
|
607
|
+
console.print(f" [dim]Writing model files to {output_dir}...[/dim]")
|
|
608
|
+
|
|
609
|
+
# Ensure E exists and is valid before saving
|
|
610
|
+
if E is None:
|
|
611
|
+
raise ValueError("Embeddings (E) are None - cannot save model")
|
|
612
|
+
if not isinstance(E, np.ndarray):
|
|
613
|
+
raise ValueError(f"Embeddings (E) must be numpy array, got {type(E)}")
|
|
614
|
+
if E.size == 0:
|
|
615
|
+
raise ValueError("Embeddings (E) is empty - cannot save model")
|
|
616
|
+
|
|
617
|
+
task10a = progress.add_task("[cyan] → Saving embeddings & components...", total=None)
|
|
618
|
+
os.makedirs(output_dir, exist_ok=True)
|
|
619
|
+
|
|
620
|
+
# Save embeddings (only original nodes for query, but keep full for future use)
|
|
621
|
+
# Save full embeddings including subtokens
|
|
622
|
+
embeddings_path = os.path.join(output_dir, 'embeddings.npy')
|
|
623
|
+
try:
|
|
624
|
+
np.save(embeddings_path, E)
|
|
625
|
+
# Verify file was written
|
|
626
|
+
if not os.path.exists(embeddings_path):
|
|
627
|
+
raise IOError(f"Failed to save embeddings.npy to {embeddings_path}")
|
|
628
|
+
# Verify file is readable
|
|
629
|
+
test_load = np.load(embeddings_path)
|
|
630
|
+
if test_load.shape != E.shape:
|
|
631
|
+
raise IOError(f"Saved embeddings shape mismatch: expected {E.shape}, got {test_load.shape}")
|
|
632
|
+
console.print(f" [dim]✓ Saved embeddings: {E.shape} → {embeddings_path}[/dim]")
|
|
633
|
+
except Exception as e:
|
|
634
|
+
console.print(f"[bold red]Error saving embeddings: {e}[/bold red]")
|
|
635
|
+
raise
|
|
636
|
+
progress.update(task10a, completed=True)
|
|
637
|
+
|
|
638
|
+
task10b = progress.add_task("[cyan] → Saving metadata & indices...", total=None)
|
|
639
|
+
# Save temperature
|
|
640
|
+
np.save(os.path.join(output_dir, 'tau.npy'), np.array(tau))
|
|
641
|
+
|
|
642
|
+
# Save mean_norm for length penalty
|
|
643
|
+
np.save(os.path.join(output_dir, 'mean_norm.npy'), np.array(mean_norm))
|
|
644
|
+
|
|
645
|
+
# Save metadata
|
|
646
|
+
metadata = {
|
|
647
|
+
'node_map': node_to_idx,
|
|
648
|
+
'node_metadata': node_metadata,
|
|
649
|
+
'embedding_dim': final_dim,
|
|
650
|
+
'num_nodes': num_nodes,
|
|
651
|
+
'final_num_nodes': final_num_nodes # Includes subtokens
|
|
652
|
+
}
|
|
653
|
+
with open(os.path.join(output_dir, 'metadata.json'), 'w') as f:
|
|
654
|
+
json.dump(metadata, f, indent=2)
|
|
655
|
+
|
|
656
|
+
# Save subtoken mapping
|
|
657
|
+
if subtoken_to_idx:
|
|
658
|
+
with open(os.path.join(output_dir, 'subtoken_map.json'), 'w') as f:
|
|
659
|
+
json.dump(subtoken_to_idx, f, indent=2)
|
|
660
|
+
|
|
661
|
+
# Save node subtokens
|
|
662
|
+
if node_subtokens:
|
|
663
|
+
with open(os.path.join(output_dir, 'node_subtokens.json'), 'w') as f:
|
|
664
|
+
json.dump(node_subtokens, f, indent=2)
|
|
665
|
+
|
|
666
|
+
# Save node types (for query expansion)
|
|
667
|
+
if node_types is not None:
|
|
668
|
+
# Convert node_types from idx-based to node_id-based
|
|
669
|
+
node_types_by_id = {}
|
|
670
|
+
for node_idx, types_dict in node_types.items():
|
|
671
|
+
node_id = idx_to_node.get(node_idx)
|
|
672
|
+
if node_id:
|
|
673
|
+
# Convert counts to int for JSON serialization
|
|
674
|
+
node_types_by_id[node_id] = {k: int(v) for k, v in types_dict.items()}
|
|
675
|
+
if node_types_by_id:
|
|
676
|
+
with open(os.path.join(output_dir, 'node_types.json'), 'w') as f:
|
|
677
|
+
json.dump(node_types_by_id, f, indent=2)
|
|
678
|
+
|
|
679
|
+
# Save PCA components
|
|
680
|
+
np.save(os.path.join(output_dir, 'fusion_pca_components.npy'), pca_components)
|
|
681
|
+
np.save(os.path.join(output_dir, 'fusion_pca_mean.npy'), pca_mean)
|
|
682
|
+
|
|
683
|
+
# Save SVD components
|
|
684
|
+
np.save(os.path.join(output_dir, 'svd_components.npy'), svd_components_graph)
|
|
685
|
+
|
|
686
|
+
if svd_components_types is not None:
|
|
687
|
+
np.save(os.path.join(output_dir, 'svd_components_types.npy'), svd_components_types)
|
|
688
|
+
|
|
689
|
+
# Save Word2Vec
|
|
690
|
+
word2vec_kv.save(os.path.join(output_dir, 'word2vec.kv'))
|
|
691
|
+
|
|
692
|
+
# Save ANN index
|
|
693
|
+
ann_index.save(os.path.join(output_dir, 'ann_index.ann'))
|
|
694
|
+
|
|
695
|
+
# Save type token map (use final expanded version)
|
|
696
|
+
if final_type_to_idx is not None:
|
|
697
|
+
with open(os.path.join(output_dir, 'type_token_map.json'), 'w') as f:
|
|
698
|
+
json.dump(final_type_to_idx, f, indent=2)
|
|
699
|
+
elif type_to_idx is not None:
|
|
700
|
+
with open(os.path.join(output_dir, 'type_token_map.json'), 'w') as f:
|
|
701
|
+
json.dump(type_to_idx, f, indent=2)
|
|
702
|
+
|
|
703
|
+
progress.update(task10b, completed=True)
|
|
704
|
+
|
|
705
|
+
# Verify critical files were saved
|
|
706
|
+
critical_files = ['embeddings.npy', 'tau.npy', 'metadata.json', 'ann_index.ann']
|
|
707
|
+
missing_files = []
|
|
708
|
+
for filename in critical_files:
|
|
709
|
+
filepath = os.path.join(output_dir, filename)
|
|
710
|
+
if not os.path.exists(filepath):
|
|
711
|
+
missing_files.append(filename)
|
|
712
|
+
|
|
713
|
+
if missing_files:
|
|
714
|
+
raise IOError(
|
|
715
|
+
f"Critical model files missing after save: {', '.join(missing_files)}\n"
|
|
716
|
+
f"Model directory: {output_dir}\n"
|
|
717
|
+
f"This indicates a save failure. Please check disk space and permissions."
|
|
718
|
+
)
|
|
719
|
+
|
|
720
|
+
console.print(f" [dim]✓ Saved {len(os.listdir(output_dir))} model files[/dim]")
|
|
721
|
+
console.print(f" [dim]✓ Verified critical files: {', '.join(critical_files)}[/dim]")
|
|
722
|
+
|
|
723
|
+
# Calculate elapsed time
|
|
724
|
+
end_time = time.time()
|
|
725
|
+
elapsed_time = end_time - start_time
|
|
726
|
+
|
|
727
|
+
# Format elapsed time
|
|
728
|
+
if elapsed_time < 60:
|
|
729
|
+
time_str = f"{elapsed_time:.2f} seconds"
|
|
730
|
+
elif elapsed_time < 3600:
|
|
731
|
+
minutes = int(elapsed_time // 60)
|
|
732
|
+
seconds = elapsed_time % 60
|
|
733
|
+
time_str = f"{minutes}m {seconds:.2f}s"
|
|
734
|
+
else:
|
|
735
|
+
hours = int(elapsed_time // 3600)
|
|
736
|
+
minutes = int((elapsed_time % 3600) // 60)
|
|
737
|
+
seconds = elapsed_time % 60
|
|
738
|
+
time_str = f"{hours}h {minutes}m {seconds:.2f}s"
|
|
739
|
+
|
|
740
|
+
# Display statistics
|
|
741
|
+
console.print("\n[bold green]✓ Training Complete![/bold green]\n")
|
|
742
|
+
|
|
743
|
+
stats_table = Table(box=box.ROUNDED, title="Training Statistics")
|
|
744
|
+
stats_table.add_column("Metric", style="cyan", width=25)
|
|
745
|
+
stats_table.add_column("Value", style="white")
|
|
746
|
+
stats_table.add_row("Total Nodes", str(num_nodes))
|
|
747
|
+
stats_table.add_row("Total Nodes (with subtokens)", str(final_num_nodes))
|
|
748
|
+
stats_table.add_row("Total Edges", str(len(edges)))
|
|
749
|
+
stats_table.add_row("Training Edges", str(len(train_edges)))
|
|
750
|
+
stats_table.add_row("Validation Edges", str(len(val_edges)))
|
|
751
|
+
if final_type_to_idx:
|
|
752
|
+
stats_table.add_row("Type Tokens (expanded)", str(len(final_type_to_idx)))
|
|
753
|
+
elif type_to_idx:
|
|
754
|
+
stats_table.add_row("Type Tokens", str(len(type_to_idx)))
|
|
755
|
+
if subtoken_to_idx:
|
|
756
|
+
stats_table.add_row("Subtoken Nodes", str(len(subtoken_to_idx)))
|
|
757
|
+
stats_table.add_row("Graph View Dim", f"{X_graph.shape[1]}")
|
|
758
|
+
stats_table.add_row("Context View Dim", f"{X_w2v.shape[1]}")
|
|
759
|
+
if X_types is not None:
|
|
760
|
+
stats_table.add_row("Typed View Dim", f"{X_types.shape[1]}")
|
|
761
|
+
stats_table.add_row("Final Embedding Dim", f"{E.shape[1]}")
|
|
762
|
+
stats_table.add_row("Temperature (τ)", f"{tau:.6f}")
|
|
763
|
+
stats_table.add_row("Mean Norm", f"{mean_norm:.6f}")
|
|
764
|
+
stats_table.add_row("Model Directory", output_dir)
|
|
765
|
+
stats_table.add_row("", "") # Separator
|
|
766
|
+
stats_table.add_row("[bold]Total Training Time[/bold]", f"[bold green]{time_str}[/bold green]")
|
|
767
|
+
|
|
768
|
+
console.print(stats_table)
|
|
769
|
+
console.print()
|
|
770
|
+
|
|
771
|
+
|
|
772
|
+
def _validate_train_ratio(ctx, param, value):
    """Click callback: accept --train-ratio only in the open interval (0, 1).

    The ratio splits edges into a training part and a held-out part used for
    temperature calibration, so 0 (no training edges), 1 (no held-out edges),
    values outside that range, and NaN are rejected here with a clear CLI
    error instead of failing later inside the training pipeline.

    Returns:
        The unchanged value when valid.

    Raises:
        click.BadParameter: if the value is not strictly between 0 and 1.
    """
    # `not (0 < v < 1)` also rejects NaN, since every comparison with NaN is False.
    if not 0.0 < value < 1.0:
        raise click.BadParameter(
            f"must be strictly between 0 and 1 (got {value}); "
            "recommended range is 0.5-0.95."
        )
    return value


@click.command()
@click.option('--nodes', '-n', required=True, type=click.Path(exists=True),
              help='Path to nodes.jsonl file containing symbol definitions.\n'
                   'Each line should be a JSON object with: id, kind, name, meta.')
@click.option('--edges', '-e', required=True, type=click.Path(exists=True),
              help='Path to edges.jsonl file containing symbol relationships.\n'
                   'Each line should be a JSON object with: src, dst, rel, weight.')
@click.option('--types', '-t', default=None, type=click.Path(exists=True),
              help='[Optional] Path to types.jsonl file containing type token information.\n'
                   'Each line should be a JSON object with: symbol, type_token, count.')
@click.option('--out', '-o', required=True, type=click.Path(),
              help='Output directory where the trained model will be saved.\n'
                   'Directory will be created if it does not exist.')
@click.option('--graph-dim', default=None, type=int, show_default=False,
              help='Dimensionality for the graph view embeddings.\n'
                   'Range: 8-512. Higher values capture more graph structure but increase computation.\n'
                   'If not specified, computed automatically based on number of nodes.\n'
                   'Recommended: 32-128 for small codebases, 64-256 for large ones.')
@click.option('--context-dim', default=None, type=int, show_default=False,
              help='Dimensionality for the context view embeddings (Node2Vec + Word2Vec).\n'
                   'Range: 8-512. Higher values capture more semantic context.\n'
                   'If not specified, computed automatically to match graph-dim.\n'
                   'Recommended: 32-128. Should match graph-dim for balanced fusion.')
@click.option('--typed-dim', default=None, type=int, show_default=False,
              help='Dimensionality for the typed view embeddings.\n'
                   'Range: 8-512. Only used if --types file is provided.\n'
                   'If not specified, computed automatically based on number of type tokens.\n'
                   'Recommended: 32-128. Should match other view dimensions.')
@click.option('--final-dim', default=None, type=int, show_default=False,
              help='Final dimensionality of fused embeddings after PCA reduction.\n'
                   'Range: 16-512. This is the output embedding size used for queries.\n'
                   'If not specified, computed automatically based on input dimensions and node count.\n'
                   'Recommended: 64-256. Higher values preserve more information but increase memory.')
@click.option('--num-walks', default=10, type=int, show_default=True,
              help='Number of random walks to generate per node for context view.\n'
                   'Range: 5-100. More walks improve context quality but increase training time.\n'
                   'Recommended: 10-20 for most codebases.')
@click.option('--walk-length', default=80, type=int, show_default=True,
              help='Length of each random walk in the context view.\n'
                   'Range: 20-200. Longer walks capture more distant relationships.\n'
                   'Recommended: 40-100. Shorter for small graphs, longer for large ones.')
@click.option('--train-ratio', default=0.8, type=float, show_default=True,
              callback=_validate_train_ratio,
              help='Fraction of edges used for training (rest used for temperature calibration).\n'
                   'Range: 0.5-0.95. Higher values use more data for training but less for calibration.\n'
                   'Recommended: 0.8. Lower values (0.7-0.75) improve temperature estimation.')
@click.option('--random-state', default=42, type=int, show_default=True,
              help='Random seed for reproducibility.\n'
                   'Range: Any integer. Use the same seed to get identical results.\n'
                   'Recommended: 42 (classic choice) or any fixed integer for reproducibility.')
# CLI entry point: parses the options above and forwards them unchanged to
# train_model(); all training logic lives there.
def main(nodes, edges, types, out, graph_dim, context_dim, typed_dim, final_dim,
         num_walks, walk_length, train_ratio, random_state):
    """
    Train TriVector Code Intelligence (TVI) model.

    TVI learns symbol-level semantics from codebases using three complementary views:

    \b
    1. Graph View: Structural relationships via PPMI and SVD
    2. Context View: Semantic context via Node2Vec random walks and Word2Vec
    3. Typed View: Type information via type-token co-occurrence (optional)

    The views are fused using PCA and normalized to produce final embeddings.
    A temperature parameter is learned for calibrated similarity scores.

    \b
    Example:
      python -m tricoder.train --nodes nodes.jsonl --edges edges.jsonl --out model_dir
    """
    train_model(
        nodes_path=nodes,
        edges_path=edges,
        types_path=types,
        output_dir=out,
        graph_dim=graph_dim,
        context_dim=context_dim,
        typed_dim=typed_dim,
        final_dim=final_dim,
        num_walks=num_walks,
        walk_length=walk_length,
        train_ratio=train_ratio,
        random_state=random_state
    )
|
|
854
|
+
|
|
855
|
+
|
|
856
|
+
# Script entry point: run the click command only when executed directly
# (e.g. `python -m tricoder.train`), not when the module is imported.
if __name__ == "__main__":
    main()
|