tricoder 1.2.8__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
tricoder/train.py ADDED
@@ -0,0 +1,857 @@
1
+ """Training pipeline for TriVector Code Intelligence."""
2
+ import json
3
+ import os
4
+ import time
5
+ from typing import Optional
6
+
7
+ import click
8
+ import numpy as np
9
+ from annoy import AnnoyIndex
10
+ from rich import box
11
+ from rich.console import Console
12
+ from rich.progress import Progress, SpinnerColumn, TextColumn, BarColumn, TimeElapsedColumn
13
+ from rich.table import Table
14
+
15
+ from .calibration import split_edges, learn_temperature
16
+ from .context_view import compute_context_view
17
+ from .data_loader import load_nodes, load_edges, load_types
18
+ from .fusion import fuse_embeddings, iterative_embedding_smoothing
19
+ from .graph_view import compute_graph_view
20
+ from .typed_view import compute_typed_view
21
+
22
+ console = Console()
23
+
24
+
25
def estimate_training_time(num_nodes: int, num_edges: int, num_types: Optional[int],
                           graph_dim: int, context_dim: int, typed_dim: int,
                           num_walks: int, walk_length: int, final_dim: int,
                           n_jobs: int = 1, use_gpu: bool = False,
                           tau_candidates: int = 30,
                           smoothing_iterations: int = 2,
                           ann_trees: int = 10) -> str:
    """
    Estimate training time based on data size and parameters.

    This is a heuristic: every pipeline stage is modelled by a simple cost
    formula divided by an "effective core" count. Stages that the GPU
    accelerates (SVD, PCA, sparse smoothing) are further divided by a fixed,
    conservative speedup factor when a GPU accelerator can actually be created.

    Args:
        num_nodes: number of original (non-subtoken) nodes
        num_edges: number of loaded edges
        num_types: number of type tokens, or None/0 when no types are used
        graph_dim: graph view dimensionality
        context_dim: context view dimensionality
        typed_dim: typed view dimensionality (ignored when num_types is falsy)
        num_walks: random walks per node
        walk_length: length of each random walk
        final_dim: fused embedding dimensionality
        n_jobs: number of worker processes; values below 1 are treated as 1
        use_gpu: Whether GPU acceleration will be used (affects SVD/PCA/matrix operation times)
        tau_candidates: number of temperature candidates evaluated during
            calibration (train_model uses 30, or 20 in fast mode)
        smoothing_iterations: number of embedding smoothing iterations
            (train_model uses 2, or 1 in fast mode)
        ann_trees: number of Annoy trees built (train_model uses 10, or 7 in fast mode)

    Returns:
        Estimated time as formatted string, e.g. "12.3s", "4.5m" or "1h 7m"
    """
    # n_jobs <= 0 would make effective_cores 0 and every stage formula below
    # divide by zero; treat it as a single worker instead.
    n_jobs = max(1, n_jobs)

    # Only apply GPU speedups when an accelerator can actually be created.
    gpu_available = False
    if use_gpu:
        try:
            from .gpu_utils import GPUAccelerator
            gpu_available = bool(GPUAccelerator(use_gpu=True).use_gpu)
        except Exception:
            # Best-effort probe: any failure (missing backend, no device)
            # simply means the CPU estimates are used.
            gpu_available = False

    # Conservative per-operation GPU speedups: dense SVD/PCA ~8x,
    # sparse matrix operations ~5x. Random walks, Word2Vec and Annoy are
    # treated as CPU-bound and get no GPU factor.
    dense_speedup = 8.0 if gpu_available else 1.0
    sparse_speedup = 5.0 if gpu_available else 1.0

    # Parallelization efficiency ramps up linearly to a cap of 85% at 8+ jobs,
    # modelling diminishing returns from multiprocessing overhead.
    efficiency = min(0.85, 0.3 + 0.55 * (n_jobs / max(n_jobs, 8)))
    effective_cores = n_jobs * efficiency

    # Subtokens and graph enhancements inflate the node count (modelled as 2.5x).
    expanded_nodes = num_nodes * 2.5

    # Graph view: PPMI (sparse, scales with edges) + SVD (scales with nodes and dim^2).
    ppmi_time = (expanded_nodes * num_edges ** 0.5) / (5000 * effective_cores)
    svd_time = (expanded_nodes * graph_dim ** 2) / (20000 * effective_cores) / dense_speedup
    graph_time = ppmi_time + svd_time

    # Context view: random walks + Word2Vec (cost model assumes window=10, epochs=5).
    total_walks = expanded_nodes * num_walks
    walk_time = (total_walks * walk_length) / (8000 * effective_cores)
    w2v_time = (total_walks * walk_length * 10 * 5) / (150000 * effective_cores)
    context_time = walk_time + w2v_time

    # Typed view: PPMI + SVD, only when type tokens are available.
    typed_time = 0.0
    if num_types:
        expanded_types = num_types * 1.3  # type expansion adds ~30% more types
        typed_ppmi_time = (expanded_nodes * expanded_types ** 0.5) / (5000 * effective_cores)
        typed_svd_time = (expanded_nodes * typed_dim ** 2) / (20000 * effective_cores) / dense_speedup
        typed_time = typed_ppmi_time + typed_svd_time

    # Fusion: PCA over the concatenated views.
    total_input_dim = graph_dim + context_dim + (typed_dim if num_types else 0)
    fusion_time = (expanded_nodes * total_input_dim * final_dim) / (30000 * effective_cores) / dense_speedup

    # Embedding smoothing: iterative neighbor averaging via sparse ops.
    smoothing_time = ((expanded_nodes * num_edges * smoothing_iterations)
                      / (10000 * effective_cores) / sparse_speedup)

    # Temperature calibration: grid search over ~20% of edges held out for validation.
    val_edges_est = int(num_edges * 0.2)
    calibration_time = (tau_candidates * val_edges_est) / (5000 * effective_cores)

    # ANN index building: mostly single-threaded, no GPU benefit.
    ann_time = (num_nodes * final_dim * ann_trees) / 200000

    io_time = 1.5    # saving model files
    load_time = 0.5  # data loading overhead

    total_seconds = (graph_time + context_time + typed_time + fusion_time +
                     smoothing_time + calibration_time + ann_time + io_time + load_time)

    if total_seconds < 60:
        return f"{total_seconds:.1f}s"
    elif total_seconds < 3600:
        minutes = total_seconds / 60
        return f"{minutes:.1f}m"
    else:
        hours = total_seconds / 3600
        minutes = (total_seconds % 3600) / 60
        return f"{int(hours)}h {int(minutes)}m"
135
+
136
+
137
def compute_default_dimensions(num_nodes: int, num_edges: int, num_types: Optional[int] = None) -> dict:
    """
    Compute default dimensions based on data characteristics.

    Args:
        num_nodes: number of original nodes
        num_edges: number of edges (not used by the current heuristics; kept
            for interface compatibility)
        num_types: number of type tokens, or None/0 when no types are available

    Returns:
        Dictionary with graph_dim, context_dim, typed_dim, final_dim
    """
    # Graph dimension: log2(num_nodes)-based heuristic with a per-tier multiplier
    # and clamp, so dimensionality grows slowly with codebase size.
    if num_nodes < 50:
        graph_dim = 32
    elif num_nodes < 200:
        graph_dim = max(32, min(64, int(np.log2(num_nodes) * 8)))
    elif num_nodes < 1000:
        graph_dim = max(48, min(96, int(np.log2(num_nodes) * 7)))
    else:
        graph_dim = max(64, min(128, int(np.log2(num_nodes) * 6)))

    # Context dimension matches the graph dimension for balanced fusion.
    context_dim = graph_dim

    # Typed dimension: sqrt-based heuristic on the number of type tokens.
    if num_types:
        if num_types < 20:
            typed_dim = 16
        elif num_types < 50:
            # NOTE: sqrt(num_types) * 2 < 16 for every num_types < 64, so this
            # tier always resolves to the 16 floor.
            typed_dim = max(16, min(32, int(np.sqrt(num_types) * 2)))
        else:
            typed_dim = max(32, min(64, int(np.sqrt(num_types) * 2)))
    else:
        typed_dim = 32  # default when no types are available

    # Final dimension: ~60% of the concatenated input dims, preferably within
    # [32, 256]. It must stay below the node count for PCA to work, so the
    # node cap is applied last (taking precedence over the preferred floor)
    # and the result never drops below 1.
    total_input_dim = graph_dim + context_dim + (typed_dim if num_types else 0)
    final_dim_candidate = int(total_input_dim * 0.6)
    final_dim = max(32, min(256, final_dim_candidate))
    final_dim = max(1, min(num_nodes - 1, final_dim))

    return {
        'graph_dim': graph_dim,
        'context_dim': context_dim,
        'typed_dim': typed_dim,
        'final_dim': final_dim
    }
185
+
186
+
187
+ def train_model(nodes_path: str,
188
+ edges_path: str,
189
+ types_path: Optional[str],
190
+ output_dir: str,
191
+ graph_dim: Optional[int] = None,
192
+ context_dim: Optional[int] = None,
193
+ typed_dim: Optional[int] = None,
194
+ final_dim: Optional[int] = None,
195
+ num_walks: int = 10,
196
+ walk_length: int = 80,
197
+ train_ratio: float = 0.8,
198
+ random_state: int = 42,
199
+ fast_mode: bool = False,
200
+ use_gpu: bool = False):
201
+ """
202
+ Train TriVector Code Intelligence model.
203
+
204
+ Optimizations applied for faster training (without sacrificing much quality):
205
+ - Vectorized embedding smoothing (uses sparse matrix operations, 2-5x faster)
206
+ - Reduced SVD iterations (5 instead of 10, ~2x faster)
207
+ - Reduced Word2Vec epochs (3 instead of 5, ~1.7x faster)
208
+ - Reduced Word2Vec window (7 instead of 10, ~1.4x faster)
209
+ - Reduced Word2Vec negative samples (3 instead of 5, ~1.7x faster)
210
+ - Reduced temperature calibration candidates (30 instead of 50, ~1.7x faster)
211
+ - Fast mode: further optimizations:
212
+ * Halves random walk parameters
213
+ * Reduces smoothing iterations (1 instead of 2)
214
+ * Reduces context window (3 instead of 5)
215
+ * Reduces call graph depth (2 instead of 3)
216
+ * Reduces ANN trees (7 instead of 10)
217
+ * Uses less validation data for calibration
218
+ * Fewer tau candidates (20 instead of 30)
219
+ * Fewer negative samples (3 instead of 5)
220
+
221
+ Args:
222
+ nodes_path: path to nodes.jsonl
223
+ edges_path: path to edges.jsonl
224
+ types_path: path to types.jsonl (optional)
225
+ output_dir: output directory for model
226
+ graph_dim: dimensionality for graph view
227
+ context_dim: dimensionality for context view
228
+ typed_dim: dimensionality for typed view
229
+ final_dim: final fused embedding dimensionality
230
+ num_walks: number of random walks per node
231
+ walk_length: length of each random walk
232
+ train_ratio: ratio of edges for training (rest for calibration)
233
+ random_state: random seed
234
+ fast_mode: if True, reduces walk parameters for faster training (lower quality)
235
+ use_gpu: if True, attempt GPU acceleration (requires CuPy and CUDA-capable GPU)
236
+ """
237
+ # Initialize GPU accelerator if requested
238
+ gpu_accelerator = None
239
+ if use_gpu:
240
+ from .gpu_utils import GPUAccelerator, TORCH_AVAILABLE, TORCH_VERSION, diagnose_gpu_support
241
+ import platform
242
+
243
+ is_mac = platform.system() == 'Darwin'
244
+ if is_mac and not TORCH_AVAILABLE:
245
+ console.print(f"[yellow]⚠ PyTorch not installed. Install with: pip install torch[/yellow]")
246
+ console.print(f"[yellow] GPU acceleration requires PyTorch for Mac MPS support. Using CPU.[/yellow]\n")
247
+ else:
248
+ gpu_accelerator = GPUAccelerator(use_gpu=True)
249
+ if gpu_accelerator.use_gpu:
250
+ backend_name = "CUDA (NVIDIA)" if gpu_accelerator.device_type == 'cuda' else "MPS (Mac)"
251
+ console.print(f"[bold green]✓ GPU acceleration enabled ({backend_name})[/bold green]\n")
252
+ else:
253
+ if is_mac:
254
+ # Provide detailed diagnostics
255
+ diagnostics = diagnose_gpu_support()
256
+ console.print(f"[yellow]⚠ GPU acceleration requested but MPS not available.[/yellow]")
257
+ if not TORCH_AVAILABLE:
258
+ console.print(f"[yellow] PyTorch is not installed. Install with: pip install torch[/yellow]")
259
+ elif TORCH_VERSION:
260
+ # Simple version check (compare first two parts)
261
+ try:
262
+ version_parts = TORCH_VERSION.split('.')
263
+ major = int(version_parts[0])
264
+ minor = int(version_parts[1]) if len(version_parts) > 1 else 0
265
+ if major < 1 or (major == 1 and minor < 12):
266
+ console.print(f"[yellow] PyTorch version {TORCH_VERSION} is too old. Upgrade to 1.12+ for MPS support.[/yellow]")
267
+ console.print(f"[yellow] Upgrade with: pip install --upgrade torch[/yellow]")
268
+ except (ValueError, IndexError):
269
+ pass # Skip version check if parsing fails
270
+ elif 'mps_unavailable_reason' in diagnostics:
271
+ console.print(f"[yellow] {diagnostics['mps_unavailable_reason']}[/yellow]")
272
+ else:
273
+ console.print(f"[yellow] Requirements: macOS 12.3+, Apple Silicon (M1/M2/M3), PyTorch 1.12+[/yellow]")
274
+ console.print(f"[yellow] Using CPU.[/yellow]\n")
275
+ else:
276
+ console.print(f"[yellow]⚠ GPU acceleration requested but not available, using CPU[/yellow]\n")
277
+
278
+ # Set random seeds
279
+ np.random.seed(random_state)
280
+
281
+ # Record start time
282
+ start_time = time.time()
283
+
284
+ console.print("\n[bold cyan]TriVector Code Intelligence - Training Pipeline[/bold cyan]\n")
285
+
286
+ with Progress(
287
+ SpinnerColumn(),
288
+ TextColumn("[progress.description]{task.description}"),
289
+ BarColumn(),
290
+ TimeElapsedColumn(),
291
+ console=console
292
+ ) as progress:
293
+ # Load data first to compute dimensions
294
+ task1 = progress.add_task("[cyan]Loading nodes...", total=None)
295
+ node_to_idx, node_metadata, node_subtokens, node_file_info = load_nodes(nodes_path)
296
+ num_nodes = len(node_to_idx)
297
+ progress.update(task1, completed=True)
298
+ console.print(f" [dim]✓ Loaded {num_nodes:,} nodes[/dim]")
299
+
300
+ if num_nodes == 0:
301
+ raise ValueError("No nodes found in input file")
302
+
303
+ task1b = progress.add_task("[cyan]Loading edges...", total=None)
304
+ edges, _ = load_edges(edges_path, node_to_idx)
305
+ num_edges = len(edges)
306
+ progress.update(task1b, completed=True)
307
+ console.print(f" [dim]✓ Loaded {num_edges:,} edges[/dim]")
308
+
309
+ if num_edges == 0:
310
+ raise ValueError("No edges found in input file")
311
+
312
+ # Load types if available
313
+ node_types = None
314
+ type_to_idx = None
315
+ num_types = None
316
+ if types_path and os.path.exists(types_path):
317
+ task2 = progress.add_task("[cyan]Loading types...", total=None)
318
+ node_types, type_to_idx = load_types(types_path, node_to_idx)
319
+ num_types = len(type_to_idx)
320
+ progress.update(task2, completed=True)
321
+ console.print(f" [dim]✓ Loaded {num_types:,} type tokens[/dim]")
322
+ else:
323
+ console.print(f" [dim]⊘ No types file provided[/dim]")
324
+
325
+ # Compute default dimensions if not provided
326
+ defaults = compute_default_dimensions(num_nodes, len(edges), num_types)
327
+
328
+ if graph_dim is None:
329
+ graph_dim = defaults['graph_dim']
330
+ graph_dim_source = "[dim](computed)[/dim]"
331
+ else:
332
+ graph_dim_source = ""
333
+
334
+ if context_dim is None:
335
+ context_dim = defaults['context_dim']
336
+ context_dim_source = "[dim](computed)[/dim]"
337
+ else:
338
+ context_dim_source = ""
339
+
340
+ if typed_dim is None:
341
+ typed_dim = defaults['typed_dim']
342
+ typed_dim_source = "[dim](computed)[/dim]"
343
+ else:
344
+ typed_dim_source = ""
345
+
346
+ if final_dim is None:
347
+ final_dim = defaults['final_dim']
348
+ final_dim_source = "[dim](computed)[/dim]"
349
+ else:
350
+ final_dim_source = ""
351
+
352
+ # Get number of workers for time estimation
353
+ from multiprocessing import cpu_count
354
+ n_jobs_est = max(1, cpu_count() - 1)
355
+
356
+ # Estimate training time (account for GPU if available)
357
+ estimated_time = estimate_training_time(
358
+ num_nodes, len(edges), num_types,
359
+ graph_dim, context_dim, typed_dim,
360
+ num_walks, walk_length, final_dim,
361
+ n_jobs_est, use_gpu=use_gpu
362
+ )
363
+
364
+ # Display configuration
365
+ config_table = Table(box=box.ROUNDED, show_header=False, title="Configuration")
366
+ config_table.add_column("Parameter", style="cyan", width=25)
367
+ config_table.add_column("Value", style="white", width=15)
368
+ config_table.add_column("Source", style="dim", width=12)
369
+ config_table.add_row("Graph Dimension", str(graph_dim), graph_dim_source)
370
+ config_table.add_row("Context Dimension", str(context_dim), context_dim_source)
371
+ config_table.add_row("Typed Dimension", str(typed_dim), typed_dim_source)
372
+ config_table.add_row("Final Dimension", str(final_dim), final_dim_source)
373
+ config_table.add_row("Random Walks", str(num_walks), "")
374
+ config_table.add_row("Walk Length", str(walk_length), "")
375
+ config_table.add_row("Train Ratio", f"{train_ratio:.2f}", "")
376
+ config_table.add_row("Random State", str(random_state), "")
377
+ config_table.add_row("", "", "") # Separator
378
+ config_table.add_row("Data: Nodes", str(num_nodes), "")
379
+ config_table.add_row("Data: Edges", str(len(edges)), "")
380
+ if num_types:
381
+ config_table.add_row("Data: Type Tokens", str(num_types), "")
382
+ config_table.add_row("", "", "") # Separator
383
+ gpu_status = "GPU" if (use_gpu and gpu_accelerator and gpu_accelerator.use_gpu) else "CPU"
384
+ config_table.add_row("[bold]Estimated Time[/bold]", f"[bold green]{estimated_time}[/bold green]", f"[dim]({gpu_status})[/dim]")
385
+ config_table.add_row("Workers", str(n_jobs_est), "")
386
+ console.print(config_table)
387
+ console.print()
388
+
389
+ with Progress(
390
+ SpinnerColumn(),
391
+ TextColumn("[progress.description]{task.description}"),
392
+ BarColumn(),
393
+ TimeElapsedColumn(),
394
+ console=console
395
+ ) as progress:
396
+
397
+ # Split edges for calibration
398
+ # In fast mode, use less validation data for faster calibration
399
+ calibration_train_ratio = train_ratio if not fast_mode else min(0.9, train_ratio + 0.05)
400
+ task3 = progress.add_task("[cyan]Splitting edges for training/validation...", total=None)
401
+ train_edges, val_edges = split_edges(edges, calibration_train_ratio, random_state)
402
+ progress.update(task3, completed=True)
403
+ console.print(f" [dim]✓ Split edges: {len(train_edges):,} training, {len(val_edges):,} validation "
404
+ f"({calibration_train_ratio:.1%}/{1-calibration_train_ratio:.1%})[/dim]")
405
+
406
+ # Get number of workers (all cores - 1, but use all cores on Windows for better performance)
407
+ from multiprocessing import cpu_count
408
+ import platform
409
+ if platform.system() == 'Windows':
410
+ # On Windows, use all cores since spawn method handles it well
411
+ n_jobs = cpu_count()
412
+ else:
413
+ n_jobs = max(1, cpu_count() - 1)
414
+
415
+ # Apply fast mode optimizations
416
+ smoothing_iterations = 2
417
+ context_window_size = 5
418
+ call_graph_depth = 3
419
+ ann_trees = 10
420
+
421
+ if fast_mode:
422
+ # Reduce random walk parameters for faster training
423
+ num_walks = max(5, num_walks // 2) # Reduce walks by half (min 5)
424
+ walk_length = max(40, walk_length // 2) # Reduce walk length by half (min 40)
425
+ smoothing_iterations = 1 # Reduce smoothing iterations
426
+ context_window_size = 3 # Reduce context window
427
+ call_graph_depth = 2 # Reduce call graph expansion depth
428
+ ann_trees = 7 # Reduce ANN trees (slightly faster indexing)
429
+ console.print(f"[yellow]Fast mode: Using {num_walks} walks of length {walk_length}, "
430
+ f"{smoothing_iterations} smoothing iteration(s), "
431
+ f"context window {context_window_size}, call depth {call_graph_depth}[/yellow]\n")
432
+
433
+ # Create idx_to_node mapping
434
+ idx_to_node = {idx: node_id for node_id, idx in node_to_idx.items()}
435
+
436
+ # Step 1: Compute graph view with all enhancements
437
+ # This task runs completely and sequentially before moving to the next one
438
+ # This ensures full parallelization can be used for graph view operations
439
+ console.print(f"\n[bold cyan]Step 1/5: Graph View[/bold cyan]")
440
+ console.print(f" [dim]Computing graph embeddings (dim={graph_dim}) with {n_jobs} workers...[/dim]")
441
+
442
+ task4a = progress.add_task("[cyan] → Expanding call graph...", total=None)
443
+ task4 = progress.add_task(
444
+ f"[cyan] → Building adjacency matrix & computing PPMI...", total=None)
445
+ X_graph, svd_components_graph, final_num_nodes, subtoken_to_idx, expanded_edges = compute_graph_view(
446
+ train_edges, num_nodes, graph_dim, random_state, n_jobs=n_jobs,
447
+ node_to_idx=node_to_idx,
448
+ node_subtokens=node_subtokens,
449
+ node_file_info=node_file_info,
450
+ node_metadata=node_metadata,
451
+ idx_to_node=idx_to_node,
452
+ expand_calls=True,
453
+ add_subtokens=True,
454
+ add_hierarchy=True,
455
+ add_context=True,
456
+ context_window=context_window_size,
457
+ max_depth=call_graph_depth,
458
+ gpu_accelerator=gpu_accelerator
459
+ )
460
+ progress.update(task4a, completed=True)
461
+ progress.update(task4, completed=True)
462
+
463
+ num_subtokens = len(subtoken_to_idx) if subtoken_to_idx else 0
464
+ expanded_edges_count = len(expanded_edges)
465
+ console.print(f" [dim]✓ Expanded to {final_num_nodes:,} nodes ({num_nodes:,} original + {num_subtokens:,} subtokens)[/dim]")
466
+ console.print(f" [dim]✓ Expanded to {expanded_edges_count:,} edges ({len(train_edges):,} original)[/dim]")
467
+ console.print(f" [dim]✓ Computed graph embeddings: {X_graph.shape}[/dim]")
468
+ # Graph view is now complete - all resources released before next task
469
+
470
+ # Step 2: Compute context view using expanded edges (includes subtokens)
471
+ # This task runs completely after graph view finishes
472
+ # This ensures full parallelization can be used for context view operations
473
+ console.print(f"\n[bold cyan]Step 2/5: Context View[/bold cyan]")
474
+ console.print(f" [dim]Computing context embeddings (dim={context_dim}) with {n_jobs} workers...[/dim]")
475
+ console.print(f" [dim]Parameters: {num_walks} walks/node, length={walk_length}[/dim]")
476
+
477
+ task5a = progress.add_task("[cyan] → Generating random walks...", total=None)
478
+ task5 = progress.add_task(
479
+ f"[cyan] → Training Word2Vec model...", total=None)
480
+ # Use expanded_edges which includes subtokens and all enhancements
481
+ X_w2v, word2vec_kv = compute_context_view(
482
+ expanded_edges, final_num_nodes, context_dim, num_walks, walk_length, random_state, n_jobs=n_jobs
483
+ )
484
+ progress.update(task5a, completed=True)
485
+ progress.update(task5, completed=True)
486
+
487
+ total_walks = final_num_nodes * num_walks
488
+ console.print(f" [dim]✓ Generated {total_walks:,} random walks[/dim]")
489
+ console.print(f" [dim]✓ Trained Word2Vec model (vocab size: {len(word2vec_kv.key_to_index):,})[/dim]")
490
+ console.print(f" [dim]✓ Computed context embeddings: {X_w2v.shape}[/dim]")
491
+ # Context view is now complete - all resources released before next task
492
+
493
+ # Step 3: Compute typed view if available (with type expansion)
494
+ # This task runs completely after context view finishes
495
+ X_types = None
496
+ svd_components_types = None
497
+ final_type_to_idx = None
498
+ if node_types is not None and type_to_idx is not None:
499
+ console.print(f"\n[bold cyan]Step 3/5: Typed View[/bold cyan]")
500
+ console.print(f" [dim]Computing typed embeddings (dim={typed_dim}) with {n_jobs} workers...[/dim]")
501
+
502
+ task6a = progress.add_task("[cyan] → Building type-token matrix...", total=None)
503
+ task6 = progress.add_task(
504
+ f"[cyan] → Computing PPMI & SVD...", total=None)
505
+ X_types, svd_components_types, final_type_to_idx = compute_typed_view(
506
+ node_types, type_to_idx, final_num_nodes, typed_dim, random_state, n_jobs=n_jobs,
507
+ expand_types=True, gpu_accelerator=gpu_accelerator
508
+ )
509
+ progress.update(task6a, completed=True)
510
+ progress.update(task6, completed=True)
511
+
512
+ final_type_count = len(final_type_to_idx) if final_type_to_idx else num_types
513
+ console.print(f" [dim]✓ Expanded to {final_type_count:,} type tokens ({num_types:,} original)[/dim]")
514
+ console.print(f" [dim]✓ Computed typed embeddings: {X_types.shape}[/dim]")
515
+ # Typed view is now complete - all resources released before next task
516
+ else:
517
+ console.print(f"\n[bold cyan]Step 3/5: Typed View[/bold cyan]")
518
+ console.print(f" [dim]⊘ Skipped (no types provided)[/dim]")
519
+
520
+ # Step 4: Fuse embeddings
521
+ # This task runs completely after typed view finishes (if available)
522
+ console.print(f"\n[bold cyan]Step 4/5: Fusion[/bold cyan]")
523
+
524
+ # Build embeddings list
525
+ embeddings_list = [X_graph, X_w2v]
526
+ if X_types is not None:
527
+ embeddings_list.append(X_types)
528
+
529
+ input_dims = [graph_dim, context_dim]
530
+ if X_types is not None:
531
+ input_dims.append(typed_dim)
532
+ total_input_dim = sum(input_dims)
533
+ console.print(f" [dim]Fusing {len(embeddings_list)} views (total input dim={total_input_dim}) → {final_dim}...[/dim]")
534
+
535
+ task7a = progress.add_task("[cyan] → Concatenating views & applying PCA...", total=None)
536
+ E, pca_components, pca_mean = fuse_embeddings(
537
+ embeddings_list, final_num_nodes, final_dim, random_state, n_jobs=n_jobs,
538
+ gpu_accelerator=gpu_accelerator
539
+ )
540
+ progress.update(task7a, completed=True)
541
+
542
+ # Store embeddings before normalization for mean_norm computation
543
+ E_before_norm = E.copy()
544
+
545
+ # Compute mean_norm for length penalty
546
+ task7b = progress.add_task("[cyan] → Normalizing embeddings...", total=None)
547
+ mean_norm = float(np.mean(np.linalg.norm(E_before_norm, axis=1)))
548
+ progress.update(task7b, completed=True)
549
+
550
+ console.print(f" [dim]✓ Fused embeddings: {E.shape} (reduced from {total_input_dim}D)[/dim]")
551
+ console.print(f" [dim]✓ Mean norm: {mean_norm:.4f}[/dim]")
552
+
553
+ # Apply iterative embedding smoothing (diffusion)
554
+ # This runs after fusion completes
555
+ if smoothing_iterations > 0:
556
+ console.print(f" [dim]Applying smoothing (iterations={smoothing_iterations}, beta=0.35)...[/dim]")
557
+ task7c = progress.add_task(f"[cyan] → Smoothing embeddings via diffusion...", total=None)
558
+ E = iterative_embedding_smoothing(
559
+ E, expanded_edges, final_num_nodes,
560
+ num_iterations=smoothing_iterations, beta=0.35, random_state=random_state,
561
+ gpu_accelerator=gpu_accelerator
562
+ )
563
+ progress.update(task7c, completed=True)
564
+ console.print(f" [dim]✓ Applied {smoothing_iterations} smoothing iteration(s)[/dim]")
565
+ # Smoothing is now complete - all resources released before next task
566
+
567
+ # Step 5: Learn temperature with improved negative sampling
568
+ # This task runs completely after smoothing finishes
569
+ # In fast mode, use fewer tau candidates and negatives for faster calibration
570
+ console.print(f"\n[bold cyan]Step 5/5: Temperature Calibration[/bold cyan]")
571
+ num_negatives = 3 if fast_mode else 5
572
+ tau_candidates_count = 20 if fast_mode else 30
573
+ console.print(f" [dim]Calibrating temperature with {tau_candidates_count} candidates, "
574
+ f"{num_negatives} negatives/edge, {len(val_edges):,} validation edges...[/dim]")
575
+
576
+ task8a = progress.add_task("[cyan] → Evaluating tau candidates...", total=None)
577
+ tau_candidates = np.logspace(-2, 2, num=tau_candidates_count)
578
+ tau = learn_temperature(
579
+ E, val_edges, final_num_nodes,
580
+ num_negatives=num_negatives,
581
+ tau_candidates=tau_candidates,
582
+ random_state=random_state, n_jobs=n_jobs,
583
+ node_metadata=node_metadata,
584
+ node_file_info=node_file_info,
585
+ idx_to_node=idx_to_node
586
+ )
587
+ progress.update(task8a, completed=True)
588
+ console.print(f" [dim]✓ Learned optimal temperature: τ = {tau:.6f}[/dim]")
589
+
590
+ # Build ANN index (only for original nodes, not subtokens)
591
+ console.print(f"\n[bold cyan]Building ANN Index[/bold cyan]")
592
+ console.print(f" [dim]Building approximate nearest neighbor index ({ann_trees} trees)...[/dim]")
593
+ task9a = progress.add_task("[cyan] → Adding nodes to index...", total=None)
594
+ ann_index = AnnoyIndex(final_dim, 'angular')
595
+ # Only index original nodes (not subtokens)
596
+ for i in range(num_nodes):
597
+ ann_index.add_item(i, E[i])
598
+ progress.update(task9a, completed=True)
599
+
600
+ task9b = progress.add_task("[cyan] → Building index trees...", total=None)
601
+ ann_index.build(ann_trees) # Reduced trees in fast mode
602
+ progress.update(task9b, completed=True)
603
+ console.print(f" [dim]✓ Built ANN index with {ann_trees} trees for {num_nodes:,} nodes[/dim]")
604
+
605
+ # Save model
606
+ console.print(f"\n[bold cyan]Saving Model[/bold cyan]")
607
+ console.print(f" [dim]Writing model files to {output_dir}...[/dim]")
608
+
609
+ # Ensure E exists and is valid before saving
610
+ if E is None:
611
+ raise ValueError("Embeddings (E) are None - cannot save model")
612
+ if not isinstance(E, np.ndarray):
613
+ raise ValueError(f"Embeddings (E) must be numpy array, got {type(E)}")
614
+ if E.size == 0:
615
+ raise ValueError("Embeddings (E) is empty - cannot save model")
616
+
617
+ task10a = progress.add_task("[cyan] → Saving embeddings & components...", total=None)
618
+ os.makedirs(output_dir, exist_ok=True)
619
+
620
+ # Save embeddings (only original nodes for query, but keep full for future use)
621
+ # Save full embeddings including subtokens
622
+ embeddings_path = os.path.join(output_dir, 'embeddings.npy')
623
+ try:
624
+ np.save(embeddings_path, E)
625
+ # Verify file was written
626
+ if not os.path.exists(embeddings_path):
627
+ raise IOError(f"Failed to save embeddings.npy to {embeddings_path}")
628
+ # Verify file is readable
629
+ test_load = np.load(embeddings_path)
630
+ if test_load.shape != E.shape:
631
+ raise IOError(f"Saved embeddings shape mismatch: expected {E.shape}, got {test_load.shape}")
632
+ console.print(f" [dim]✓ Saved embeddings: {E.shape} → {embeddings_path}[/dim]")
633
+ except Exception as e:
634
+ console.print(f"[bold red]Error saving embeddings: {e}[/bold red]")
635
+ raise
636
+ progress.update(task10a, completed=True)
637
+
638
+ task10b = progress.add_task("[cyan] → Saving metadata & indices...", total=None)
639
+ # Save temperature
640
+ np.save(os.path.join(output_dir, 'tau.npy'), np.array(tau))
641
+
642
+ # Save mean_norm for length penalty
643
+ np.save(os.path.join(output_dir, 'mean_norm.npy'), np.array(mean_norm))
644
+
645
+ # Save metadata
646
+ metadata = {
647
+ 'node_map': node_to_idx,
648
+ 'node_metadata': node_metadata,
649
+ 'embedding_dim': final_dim,
650
+ 'num_nodes': num_nodes,
651
+ 'final_num_nodes': final_num_nodes # Includes subtokens
652
+ }
653
+ with open(os.path.join(output_dir, 'metadata.json'), 'w') as f:
654
+ json.dump(metadata, f, indent=2)
655
+
656
+ # Save subtoken mapping
657
+ if subtoken_to_idx:
658
+ with open(os.path.join(output_dir, 'subtoken_map.json'), 'w') as f:
659
+ json.dump(subtoken_to_idx, f, indent=2)
660
+
661
+ # Save node subtokens
662
+ if node_subtokens:
663
+ with open(os.path.join(output_dir, 'node_subtokens.json'), 'w') as f:
664
+ json.dump(node_subtokens, f, indent=2)
665
+
666
+ # Save node types (for query expansion)
667
+ if node_types is not None:
668
+ # Convert node_types from idx-based to node_id-based
669
+ node_types_by_id = {}
670
+ for node_idx, types_dict in node_types.items():
671
+ node_id = idx_to_node.get(node_idx)
672
+ if node_id:
673
+ # Convert counts to int for JSON serialization
674
+ node_types_by_id[node_id] = {k: int(v) for k, v in types_dict.items()}
675
+ if node_types_by_id:
676
+ with open(os.path.join(output_dir, 'node_types.json'), 'w') as f:
677
+ json.dump(node_types_by_id, f, indent=2)
678
+
679
+ # Save PCA components
680
+ np.save(os.path.join(output_dir, 'fusion_pca_components.npy'), pca_components)
681
+ np.save(os.path.join(output_dir, 'fusion_pca_mean.npy'), pca_mean)
682
+
683
+ # Save SVD components
684
+ np.save(os.path.join(output_dir, 'svd_components.npy'), svd_components_graph)
685
+
686
+ if svd_components_types is not None:
687
+ np.save(os.path.join(output_dir, 'svd_components_types.npy'), svd_components_types)
688
+
689
+ # Save Word2Vec
690
+ word2vec_kv.save(os.path.join(output_dir, 'word2vec.kv'))
691
+
692
+ # Save ANN index
693
+ ann_index.save(os.path.join(output_dir, 'ann_index.ann'))
694
+
695
+ # Save type token map (use final expanded version)
696
+ if final_type_to_idx is not None:
697
+ with open(os.path.join(output_dir, 'type_token_map.json'), 'w') as f:
698
+ json.dump(final_type_to_idx, f, indent=2)
699
+ elif type_to_idx is not None:
700
+ with open(os.path.join(output_dir, 'type_token_map.json'), 'w') as f:
701
+ json.dump(type_to_idx, f, indent=2)
702
+
703
+ progress.update(task10b, completed=True)
704
+
705
+ # Verify critical files were saved
706
+ critical_files = ['embeddings.npy', 'tau.npy', 'metadata.json', 'ann_index.ann']
707
+ missing_files = []
708
+ for filename in critical_files:
709
+ filepath = os.path.join(output_dir, filename)
710
+ if not os.path.exists(filepath):
711
+ missing_files.append(filename)
712
+
713
+ if missing_files:
714
+ raise IOError(
715
+ f"Critical model files missing after save: {', '.join(missing_files)}\n"
716
+ f"Model directory: {output_dir}\n"
717
+ f"This indicates a save failure. Please check disk space and permissions."
718
+ )
719
+
720
+ console.print(f" [dim]✓ Saved {len(os.listdir(output_dir))} model files[/dim]")
721
+ console.print(f" [dim]✓ Verified critical files: {', '.join(critical_files)}[/dim]")
722
+
723
+ # Calculate elapsed time
724
+ end_time = time.time()
725
+ elapsed_time = end_time - start_time
726
+
727
+ # Format elapsed time
728
+ if elapsed_time < 60:
729
+ time_str = f"{elapsed_time:.2f} seconds"
730
+ elif elapsed_time < 3600:
731
+ minutes = int(elapsed_time // 60)
732
+ seconds = elapsed_time % 60
733
+ time_str = f"{minutes}m {seconds:.2f}s"
734
+ else:
735
+ hours = int(elapsed_time // 3600)
736
+ minutes = int((elapsed_time % 3600) // 60)
737
+ seconds = elapsed_time % 60
738
+ time_str = f"{hours}h {minutes}m {seconds:.2f}s"
739
+
740
+ # Display statistics
741
+ console.print("\n[bold green]✓ Training Complete![/bold green]\n")
742
+
743
+ stats_table = Table(box=box.ROUNDED, title="Training Statistics")
744
+ stats_table.add_column("Metric", style="cyan", width=25)
745
+ stats_table.add_column("Value", style="white")
746
+ stats_table.add_row("Total Nodes", str(num_nodes))
747
+ stats_table.add_row("Total Nodes (with subtokens)", str(final_num_nodes))
748
+ stats_table.add_row("Total Edges", str(len(edges)))
749
+ stats_table.add_row("Training Edges", str(len(train_edges)))
750
+ stats_table.add_row("Validation Edges", str(len(val_edges)))
751
+ if final_type_to_idx:
752
+ stats_table.add_row("Type Tokens (expanded)", str(len(final_type_to_idx)))
753
+ elif type_to_idx:
754
+ stats_table.add_row("Type Tokens", str(len(type_to_idx)))
755
+ if subtoken_to_idx:
756
+ stats_table.add_row("Subtoken Nodes", str(len(subtoken_to_idx)))
757
+ stats_table.add_row("Graph View Dim", f"{X_graph.shape[1]}")
758
+ stats_table.add_row("Context View Dim", f"{X_w2v.shape[1]}")
759
+ if X_types is not None:
760
+ stats_table.add_row("Typed View Dim", f"{X_types.shape[1]}")
761
+ stats_table.add_row("Final Embedding Dim", f"{E.shape[1]}")
762
+ stats_table.add_row("Temperature (τ)", f"{tau:.6f}")
763
+ stats_table.add_row("Mean Norm", f"{mean_norm:.6f}")
764
+ stats_table.add_row("Model Directory", output_dir)
765
+ stats_table.add_row("", "") # Separator
766
+ stats_table.add_row("[bold]Total Training Time[/bold]", f"[bold green]{time_str}[/bold green]")
767
+
768
+ console.print(stats_table)
769
+ console.print()
770
+
771
+
772
# CLI entry point: parses options with click, validates them, and forwards
# them to train_model(). The docstring of main() doubles as the `--help` text,
# so keep click's `\b` (no-rewrap) markers intact when editing it.
@click.command()
@click.option('--nodes', '-n', required=True, type=click.Path(exists=True),
              help='Path to nodes.jsonl file containing symbol definitions.\n'
                   'Each line should be a JSON object with: id, kind, name, meta.')
@click.option('--edges', '-e', required=True, type=click.Path(exists=True),
              help='Path to edges.jsonl file containing symbol relationships.\n'
                   'Each line should be a JSON object with: src, dst, rel, weight.')
@click.option('--types', '-t', default=None, type=click.Path(exists=True),
              help='[Optional] Path to types.jsonl file containing type token information.\n'
                   'Each line should be a JSON object with: symbol, type_token, count.')
@click.option('--out', '-o', required=True, type=click.Path(),
              help='Output directory where the trained model will be saved.\n'
                   'Directory will be created if it does not exist.')
# Dimension / count options must be positive; IntRange rejects 0 and negatives
# at parse time. A None default (auto-compute) is not passed through the type.
@click.option('--graph-dim', default=None, type=click.IntRange(min=1), show_default=False,
              help='Dimensionality for the graph view embeddings.\n'
                   'Range: 8-512. Higher values capture more graph structure but increase computation.\n'
                   'If not specified, computed automatically based on number of nodes.\n'
                   'Recommended: 32-128 for small codebases, 64-256 for large ones.')
@click.option('--context-dim', default=None, type=click.IntRange(min=1), show_default=False,
              help='Dimensionality for the context view embeddings (Node2Vec + Word2Vec).\n'
                   'Range: 8-512. Higher values capture more semantic context.\n'
                   'If not specified, computed automatically to match graph-dim.\n'
                   'Recommended: 32-128. Should match graph-dim for balanced fusion.')
@click.option('--typed-dim', default=None, type=click.IntRange(min=1), show_default=False,
              help='Dimensionality for the typed view embeddings.\n'
                   'Range: 8-512. Only used if --types file is provided.\n'
                   'If not specified, computed automatically based on number of type tokens.\n'
                   'Recommended: 32-128. Should match other view dimensions.')
@click.option('--final-dim', default=None, type=click.IntRange(min=1), show_default=False,
              help='Final dimensionality of fused embeddings after PCA reduction.\n'
                   'Range: 16-512. This is the output embedding size used for queries.\n'
                   'If not specified, computed automatically based on input dimensions and node count.\n'
                   'Recommended: 64-256. Higher values preserve more information but increase memory.')
@click.option('--num-walks', default=10, type=click.IntRange(min=1), show_default=True,
              help='Number of random walks to generate per node for context view.\n'
                   'Range: 5-100. More walks improve context quality but increase training time.\n'
                   'Recommended: 10-20 for most codebases.')
@click.option('--walk-length', default=80, type=click.IntRange(min=1), show_default=True,
              help='Length of each random walk in the context view.\n'
                   'Range: 20-200. Longer walks capture more distant relationships.\n'
                   'Recommended: 40-100. Shorter for small graphs, longer for large ones.')
@click.option('--train-ratio', default=0.8, type=float, show_default=True,
              help='Fraction of edges used for training (rest used for temperature calibration).\n'
                   'Range: 0.5-0.95. Higher values use more data for training but less for calibration.\n'
                   'Recommended: 0.8. Lower values (0.7-0.75) improve temperature estimation.')
@click.option('--random-state', default=42, type=int, show_default=True,
              help='Random seed for reproducibility.\n'
                   'Range: Any integer. Use the same seed to get identical results.\n'
                   'Recommended: 42 (classic choice) or any fixed integer for reproducibility.')
def main(nodes, edges, types, out, graph_dim, context_dim, typed_dim, final_dim,
         num_walks, walk_length, train_ratio, random_state):
    """
    Train TriVector Code Intelligence (TVI) model.

    TVI learns symbol-level semantics from codebases using three complementary views:

    \b
    1. Graph View: Structural relationships via PPMI and SVD
    2. Context View: Semantic context via Node2Vec random walks and Word2Vec
    3. Typed View: Type information via type-token co-occurrence (optional)

    The views are fused using PCA and normalized to produce final embeddings.
    A temperature parameter is learned for calibrated similarity scores.

    \b
    Example:
        python -m tricoder.train --nodes nodes.jsonl --edges edges.jsonl --out model_dir
    """
    # train_ratio splits edges into training and calibration sets; 0 or 1 would
    # leave one side empty, and NaN fails the chained comparison too. This is
    # checked by hand because open bounds on click.FloatRange need click >= 8.
    if not 0.0 < train_ratio < 1.0:
        raise click.BadParameter(
            f'must be strictly between 0 and 1, got {train_ratio}.',
            param_hint="'--train-ratio'",
        )

    train_model(
        nodes_path=nodes,
        edges_path=edges,
        types_path=types,
        output_dir=out,
        graph_dim=graph_dim,
        context_dim=context_dim,
        typed_dim=typed_dim,
        final_dim=final_dim,
        num_walks=num_walks,
        walk_length=walk_length,
        train_ratio=train_ratio,
        random_state=random_state,
    )
854
+
855
+
856
if __name__ == "__main__":
    # Script entry point. The package-relative imports at the top of this
    # module mean it must be launched as `python -m tricoder.train`.
    main()