tricoder-1.2.8-py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
tricoder/cli.py ADDED
@@ -0,0 +1,890 @@
+ #!/usr/bin/env python3
+ """Command-line interface for TriCoder."""
+ import json
+ import os
+ import shlex
+ from typing import List
+
+ import click
+ from rich import box
+ from rich.console import Console
+ from rich.table import Table
+ from rich.prompt import Prompt
+
+ from .git_tracker import (
+     get_git_commit_hash, get_git_commit_timestamp, get_changed_files_for_retraining,
+     save_training_metadata, extract_files_from_jsonl, get_all_python_files
+ )
+ from .model import SymbolModel
+ from .train import train_model
+
+ console = Console()
+
+
+ def get_tricoder_dir() -> str:
+     """Get the default .tricoder directory path and ensure it exists."""
+     cwd = os.getcwd()
+     tricoder_dir = os.path.join(cwd, '.tricoder')
+     os.makedirs(tricoder_dir, exist_ok=True)
+     return tricoder_dir
+
+
+ def get_model_dir() -> str:
+     """Get the default model directory path and ensure it exists."""
+     tricoder_dir = get_tricoder_dir()
+     model_dir = os.path.join(tricoder_dir, 'model')
+     os.makedirs(model_dir, exist_ok=True)
+     return model_dir
+
+
+ def is_valid_model_dir(path: str) -> bool:
+     """Check if a directory contains a valid model."""
+     embeddings_path = os.path.join(path, 'embeddings.npy')
+     metadata_path = os.path.join(path, 'metadata.json')
+     return os.path.exists(embeddings_path) and os.path.exists(metadata_path)
+
+
+ def discover_models(root_dir: str) -> List[str]:
+     """
+     Recursively discover all model directories starting from root_dir.
+
+     Args:
+         root_dir: Root directory to search for models
+
+     Returns:
+         List of model directory paths (sorted, with root_dir first if it's a model)
+     """
+     models = []
+
+     if not os.path.exists(root_dir):
+         return models
+
+     # Check if root_dir itself is a model
+     if is_valid_model_dir(root_dir):
+         models.append(root_dir)
+
+     # Recursively search subdirectories
+     for root, dirs, files in os.walk(root_dir):
+         # Skip hidden directories
+         dirs[:] = [d for d in dirs if not d.startswith('.')]
+
+         for d in dirs:
+             subdir = os.path.join(root, d)
+             if is_valid_model_dir(subdir):
+                 models.append(subdir)
+
+     # Sort: root_dir first if it's a model, then alphabetically
+     def sort_key(path):
+         if path == root_dir:
+             return (0, path)
+         return (1, path)
+
+     return sorted(models, key=sort_key)
+
+
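# A minimal, hypothetical sketch of the ordering discover_models() produces,
# assuming the package is importable as `tricoder` (the module path shown in
# this diff). is_valid_model_dir() only checks for two files, so empty stub
# files are enough to mark a directory as a "model" here.
import os
import tempfile
from tricoder.cli import discover_models

with tempfile.TemporaryDirectory() as root:
    for sub in ('', 'backend', 'frontend'):
        d = os.path.join(root, sub)
        os.makedirs(d, exist_ok=True)
        for name in ('embeddings.npy', 'metadata.json'):
            open(os.path.join(d, name), 'w').close()
    print(discover_models(root))
    # -> [root, root/backend, root/frontend]: the root model sorts first,
    #    then subdirectory models alphabetically.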
+ @click.group()
+ def cli():
+     """TriCoder - TriVector Code Intelligence for semantic code analysis."""
+     pass
+
+
+ @cli.command(name='train')
+ @click.option('--nodes', '-n', default=None, type=click.Path(),
+               help='Path to nodes.jsonl file (default: .tricoder/nodes.jsonl).')
+ @click.option('--edges', '-e', default=None, type=click.Path(),
+               help='Path to edges.jsonl file (default: .tricoder/edges.jsonl).')
+ @click.option('--types', '-t', default=None, type=click.Path(),
+               help='[Optional] Path to types.jsonl file (default: .tricoder/types.jsonl).')
+ @click.option('--out', '-o', default=None, type=click.Path(),
+               help='Output directory for trained model (default: .tricoder/model).')
+ @click.option('--graph-dim', default=None, type=int, show_default=False,
+               help='Dimensionality for the graph view embeddings.')
+ @click.option('--context-dim', default=None, type=int, show_default=False,
+               help='Dimensionality for the context view embeddings.')
+ @click.option('--typed-dim', default=None, type=int, show_default=False,
+               help='Dimensionality for the typed view embeddings.')
+ @click.option('--final-dim', default=None, type=int, show_default=False,
+               help='Final dimensionality of fused embeddings after PCA reduction.')
+ @click.option('--num-walks', default=10, type=int, show_default=True,
+               help='Number of random walks to generate per node for context view.')
+ @click.option('--walk-length', default=80, type=int, show_default=True,
+               help='Length of each random walk in the context view.')
+ @click.option('--train-ratio', default=0.8, type=float, show_default=True,
+               help='Fraction of edges used for training (rest used for temperature calibration).')
+ @click.option('--random-state', default=42, type=int, show_default=True,
+               help='Random seed for reproducibility.')
+ @click.option('--fast', is_flag=True, default=False,
+               help='Enable fast mode: reduces walk parameters by half for faster training (slightly lower quality).')
+ @click.option('--use-gpu', is_flag=True, default=False,
+               help='Enable GPU acceleration: CUDA (NVIDIA) via CuPy, or MPS (Mac) via PyTorch. Falls back to CPU if GPU unavailable.')
+ def train(nodes, edges, types, out, graph_dim, context_dim, typed_dim, final_dim,
+           num_walks, walk_length, train_ratio, random_state, fast, use_gpu):
+     """Train TriCoder model on codebase symbols and relationships."""
+     # Use default .tricoder directory if paths not specified
+     tricoder_dir = get_tricoder_dir()
+     nodes_path = nodes if nodes else os.path.join(tricoder_dir, 'nodes.jsonl')
+     edges_path = edges if edges else os.path.join(tricoder_dir, 'edges.jsonl')
+     types_path = types if types else os.path.join(tricoder_dir, 'types.jsonl')
+     output_dir = out if out else get_model_dir()
+
+     # Handle optional types file - only use if it exists
+     if not os.path.exists(types_path):
+         types_path = None
+
+     train_model(
+         nodes_path=nodes_path,
+         edges_path=edges_path,
+         types_path=types_path,
+         output_dir=output_dir,
+         graph_dim=graph_dim,
+         context_dim=context_dim,
+         typed_dim=typed_dim,
+         final_dim=final_dim,
+         num_walks=num_walks,
+         walk_length=walk_length,
+         train_ratio=train_ratio,
+         random_state=random_state,
+         fast_mode=fast,
+         use_gpu=use_gpu
+     )
+
+     # Save git metadata after training
+     commit_hash = get_git_commit_hash()
+     commit_timestamp = get_git_commit_timestamp()
+     files_trained = extract_files_from_jsonl(nodes_path)
+     files_trained.update(extract_files_from_jsonl(edges_path))
+     if types_path and os.path.exists(types_path):
+         files_trained.update(extract_files_from_jsonl(types_path))
+
+     save_training_metadata(output_dir, commit_hash, commit_timestamp, files_trained)
+     console.print(f"[dim]Saved training metadata (commit: {commit_hash[:8] if commit_hash else 'N/A'})[/dim]")
+
+
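# A hedged sketch of driving the train command programmatically with click's
# built-in test runner (click.testing.CliRunner, a standard click API). It
# assumes `tricoder extract` has already written .tricoder/nodes.jsonl and
# .tricoder/edges.jsonl in the current directory; otherwise train_model()
# will not find its inputs.
from click.testing import CliRunner
from tricoder.cli import cli

runner = CliRunner()
result = runner.invoke(cli, ['train', '--fast', '--random-state', '0'])
print(result.exit_code)   # 0 on success
print(result.output)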
+ @cli.command(name='query')
+ @click.option('--model-dir', '-m', default=None, help='Path to model directory (default: discovers models in .tricoder/model)')
+ @click.option('--symbol', '-s', help='Symbol ID to query')
+ @click.option('--keywords', '-w', help='Keywords to search for (use quotes for multi-word: "my function")')
+ @click.option('--top-k', '-k', default=10, help='Number of results to return')
+ @click.option('--exclude-keywords', '--exclude', multiple=True,
+               help='Additional keywords to exclude from search (can be specified multiple times). '
+                    'These are appended to the default excluded keywords (Python builtins, etc.)')
+ @click.option('--interactive', '-i', is_flag=True, help='Interactive mode')
+ @click.option('--no-recursive', '--no-discover', is_flag=True, default=False,
+               help='Disable recursive model discovery. Only use the base model directory directly.')
+ def query(model_dir, symbol, keywords, top_k, exclude_keywords, interactive, no_recursive):
+     """Query the TriCoder model for similar symbols."""
+     # Discover models if not specified
+     if model_dir is None:
+         base_model_dir = get_model_dir()
+         if no_recursive:
+             # Only check the base directory itself, don't search recursively
+             if is_valid_model_dir(base_model_dir):
+                 discovered_models = [base_model_dir]
+             else:
+                 discovered_models = []
+         else:
+             discovered_models = discover_models(base_model_dir)
+
+         if not discovered_models:
+             console.print(f"[bold red]No models found in {base_model_dir}[/bold red]")
+             console.print("[yellow]Please train a model first using: tricoder train[/yellow]")
+             return
+
+         # Print one-liner summary at the beginning
+         if len(discovered_models) > 1:
+             console.print(f"[dim]Found {len(discovered_models)} models[/dim]")
+
+         if len(discovered_models) == 1:
+             # Only one model found, use it automatically
+             model_dir = discovered_models[0]
+             console.print(f"[dim]Found 1 model: {os.path.relpath(model_dir)}[/dim]")
+         else:
+             # Multiple models found, ask user to select
+             console.print(f"[bold cyan]Found {len(discovered_models)} models:[/bold cyan]\n")
+
+             # Display models in a tree-like structure
+             try:
+                 base_path = os.path.commonpath([base_model_dir] + discovered_models)
+             except ValueError:
+                 # Fallback if paths are on different drives (Windows) or can't find common path
+                 base_path = base_model_dir
+
+             for idx, model_path in enumerate(discovered_models, 1):
+                 try:
+                     rel_path = os.path.relpath(model_path, base_path)
+                     # Normalize path separators for display
+                     rel_path = rel_path.replace('\\', '/')
+                     if rel_path == '.':
+                         rel_path = 'model'
+                 except ValueError:
+                     rel_path = os.path.basename(model_path) or model_path
+
+                 # Count symbols in model
+                 num_nodes = 0
+                 try:
+                     metadata_path = os.path.join(model_path, 'metadata.json')
+                     if os.path.exists(metadata_path):
+                         with open(metadata_path, 'r') as f:
+                             metadata = json.load(f)
+                             num_nodes = metadata.get('num_nodes', 0)
+                 except Exception:
+                     pass
+
+                 if num_nodes > 0:
+                     console.print(f" [cyan]{idx}.[/cyan] [white]{rel_path}[/white] [dim]({num_nodes:,} symbols)[/dim]")
+                 else:
+                     console.print(f" [cyan]{idx}.[/cyan] [white]{rel_path}[/white]")
+
+             console.print()
+             try:
+                 choice = Prompt.ask(
+                     "[bold cyan]Select model to query[/bold cyan]",
+                     default="1",
+                     choices=[str(i) for i in range(1, len(discovered_models) + 1)]
+                 )
+                 model_dir = discovered_models[int(choice) - 1]
+             except (ValueError, IndexError, KeyboardInterrupt):
+                 console.print("[bold yellow]Cancelled[/bold yellow]")
+                 return
+
+     console.print(f"[bold green]Loading model from {os.path.relpath(model_dir)}...[/bold green]")
+
+     try:
+         model = SymbolModel()
+         model.load(model_dir)
+         console.print("[bold green]✓ Model loaded successfully[/bold green]\n")
+     except Exception as e:
+         console.print(f"[bold red]Error loading model: {e}[/bold red]")
+         return
+
+     # Build excluded keywords set (default + user-provided)
+     excluded_keywords_set = None
+     if exclude_keywords:
+         from .model import DEFAULT_EXCLUDED_KEYWORDS
+         excluded_keywords_set = DEFAULT_EXCLUDED_KEYWORDS | {kw.lower() for kw in exclude_keywords}
+         console.print(f"[dim]Excluding {len(excluded_keywords_set)} keywords "
+                       f"({len(exclude_keywords)} user-added)[/dim]\n")
+
+     if interactive:
+         interactive_mode(model, excluded_keywords_set)
+     elif symbol:
+         display_results(model, symbol, top_k)
+     elif keywords:
+         # Parse keywords (handle quoted strings)
+         keywords_parsed = parse_keywords(keywords)
+         search_and_display_results(model, keywords_parsed, top_k, excluded_keywords_set)
+     else:
+         console.print("[bold yellow]Please provide --symbol, --keywords, or use --interactive mode[/bold yellow]")
+
+
+ def display_results(model, symbol_id, top_k):
+     """Display query results in a formatted table."""
+     results = model.query(symbol_id, top_k)
+
+     if not results:
+         console.print(f"[bold yellow]No results found for symbol: {symbol_id}[/bold yellow]")
+         return
+
+     # Get query symbol info
+     query_meta = None
+     if model.metadata_lookup:
+         query_meta = model.metadata_lookup.get(symbol_id)
+
+     console.print(f"\n[bold cyan]Query:[/bold cyan] {symbol_id}")
+     if query_meta:
+         console.print(f" [dim]Kind:[/dim] {query_meta.get('kind', 'unknown')}")
+         console.print(f" [dim]Name:[/dim] {query_meta.get('name', 'unknown')}")
+         query_meta_dict = query_meta.get('meta', {})
+         if isinstance(query_meta_dict, dict):
+             query_file = query_meta_dict.get('file', '')
+             query_lineno = query_meta_dict.get('lineno', None)
+             if query_file:
+                 if query_lineno is not None and query_lineno >= 0:
+                     console.print(f" [dim]File:[/dim] {query_file}:{query_lineno}")
+                 else:
+                     console.print(f" [dim]File:[/dim] {query_file}")
+
+     console.print(f"\n[bold cyan]Top {len(results)} Similar Symbols:[/bold cyan]\n")
+
+     for idx, result in enumerate(results, 1):
+         meta = result.get('meta', {})
+         meta_dict = meta.get('meta', {}) if isinstance(meta.get('meta'), dict) else {}
+         file_path = meta_dict.get('file') or ''
+         lineno = meta_dict.get('lineno', None)
+
+         console.print(f"[dim]{idx}.[/dim] [cyan]{result['symbol']:15}[/cyan] "
+                       f"[green]Score: {result['score']:8.4f}[/green] "
+                       f"[yellow]Dist: {result['distance']:6.4f}[/yellow] "
+                       f"[blue]{meta.get('kind', 'unknown'):10}[/blue] "
+                       f"[white]{meta.get('name', ''):30}[/white]")
+         if file_path:
+             if lineno is not None and lineno >= 0:
+                 console.print(f" [dim]→ {file_path}:{lineno}[/dim]")
+             else:
+                 console.print(f" [dim]→ {file_path}[/dim]")
+
+     console.print()
+
+
+ def parse_keywords(keywords_str: str) -> str:
+     """
+     Parse keywords string, handling quoted strings.
+
+     Args:
+         keywords_str: Input string that may contain quoted keywords
+
+     Returns:
+         Parsed keywords string
+     """
+     try:
+         # Use shlex to properly parse quoted strings
+         parts = shlex.split(keywords_str)
+         return ' '.join(parts)
+     except ValueError:
+         # If parsing fails, return as-is (handles unclosed quotes)
+         return keywords_str.strip()
+
+
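# A small illustration of the quoting behaviour parse_keywords() relies on:
# shlex.split() keeps quoted phrases together and the parts are re-joined
# with single spaces, while an unclosed quote raises ValueError and falls
# back to the stripped input. The import path assumes the package layout
# shown in this diff.
from tricoder.cli import parse_keywords

print(parse_keywords('"my function" helper'))   # -> my function helper
print(parse_keywords('plain words'))            # -> plain words
print(parse_keywords('unclosed "quote'))        # -> unclosed "quote  (fallback)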
+ def search_and_display_results(model, keywords: str, top_k: int, excluded_keywords: set = None):
+     """Search for symbols by keywords and display results."""
+     from .model import DEFAULT_EXCLUDED_KEYWORDS
+
+     # Use provided excluded keywords or default
+     if excluded_keywords is None:
+         excluded_keywords = DEFAULT_EXCLUDED_KEYWORDS
+
+     # Check if any keywords are excluded
+     keyword_words = keywords.lower().split()
+     excluded_found = [w for w in keyword_words if w in excluded_keywords]
+
+     matches = model.search_by_keywords(keywords, top_k, excluded_keywords=excluded_keywords)
+
+     # Show warning if excluded keywords were filtered
+     if excluded_found:
+         console.print(f"[yellow]Note: Filtered out excluded keywords: {', '.join(excluded_found)}[/yellow]")
+         console.print("[dim]These are Python builtins/keywords that don't provide useful search results.[/dim]\n")
+
+     if not matches:
+         if excluded_found and len(excluded_found) == len(keyword_words):
+             console.print("[bold yellow]All keywords were filtered out (Python builtins/keywords).[/bold yellow]")
+             console.print("[dim]Try searching for user-defined code patterns instead of language constructs.[/dim]")
+         else:
+             console.print(f"[bold yellow]No symbols found matching keywords: {keywords}[/bold yellow]")
+         return
+
+     console.print(f"\n[bold cyan]Search Results for:[/bold cyan] \"{keywords}\"")
+     console.print(f"[bold cyan]Found {len(matches)} matching symbol(s):[/bold cyan]\n")
+
+     for idx, match in enumerate(matches, 1):
+         meta = match.get('meta', {})
+         meta_dict = meta.get('meta', {}) if isinstance(meta.get('meta'), dict) else {}
+         file_path = meta_dict.get('file') or ''
+         lineno = meta_dict.get('lineno', None)
+
+         console.print(f"[dim]{idx}.[/dim] [cyan]{match['symbol']:15}[/cyan] "
+                       f"[green]Relevance: {match['score']:6.4f}[/green] "
+                       f"[blue]{meta.get('kind', 'unknown'):10}[/blue] "
+                       f"[white]{meta.get('name', ''):30}[/white]")
+         if file_path:
+             if lineno is not None and lineno >= 0:
+                 console.print(f" [dim]→ {file_path}:{lineno}[/dim]")
+             else:
+                 console.print(f" [dim]→ {file_path}[/dim]")
+
+     console.print()
+
+     # Point the user at the first match for a follow-up similarity query
+     if matches:
+         first_match = matches[0]
+         console.print(f"[dim]Tip: Query similar symbols with: --symbol {first_match['symbol']}[/dim]\n")
+
+
+ def interactive_mode(model, excluded_keywords: set = None):
+     """Interactive query mode."""
+     from .model import DEFAULT_EXCLUDED_KEYWORDS
+
+     # Use provided excluded keywords or default
+     if excluded_keywords is None:
+         excluded_keywords = DEFAULT_EXCLUDED_KEYWORDS
+
+     console.print("[bold green]Entering interactive mode. Type 'quit' or 'exit' to quit.[/bold green]")
+     console.print("[dim]You can search by symbol ID or keywords (use quotes for multi-word)[/dim]")
+     console.print(f"[dim]Excluding {len(excluded_keywords)} keywords (Python builtins, etc.)[/dim]\n")
+
+     while True:
+         try:
+             # Use rich's Prompt so the markup is rendered instead of printed literally
+             query_input = Prompt.ask("\n[bold cyan]Enter symbol ID or keywords[/bold cyan]")
+
+             if query_input.lower() in ['quit', 'exit', 'q']:
+                 console.print("[bold yellow]Goodbye![/bold yellow]")
+                 break
+
+             top_k = click.prompt("Number of results", default=10, type=int)
+
+             # Check if it looks like a symbol ID (starts with 'sym_') or try as keywords
+             if query_input.startswith('sym_') and query_input in model.node_map:
+                 display_results(model, query_input, top_k)
+             else:
+                 # Try as keywords
+                 keywords_parsed = parse_keywords(query_input)
+                 search_and_display_results(model, keywords_parsed, top_k, excluded_keywords)
+
+         except KeyboardInterrupt:
+             console.print("\n[bold yellow]Goodbye![/bold yellow]")
+             break
+         except Exception as e:
+             console.print(f"[bold red]Error: {e}[/bold red]")
+
+
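# Sketch of the design choice behind using rich's Prompt in interactive_mode:
# click.prompt() writes its prompt string verbatim, so rich markup such as
# "[bold cyan]...[/bold cyan]" would appear literally on screen, while
# rich.prompt.Prompt.ask() routes the prompt through a rich Console and
# renders the markup. Run in a terminal to compare (both calls wait for input):
import click
from rich.prompt import Prompt

click.prompt("[bold cyan]shows the tags literally[/bold cyan]", default="")
Prompt.ask("[bold cyan]renders as bold cyan[/bold cyan]", default="")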
+ @cli.command(name='extract')
+ @click.option('--input-dir', '--root', '-r', default='.',
+               type=click.Path(exists=True, file_okay=False, dir_okay=True),
+               help='Input directory to scan for files.')
+ @click.option('--include-dirs', '-i', multiple=True,
+               help='Include only these subdirectories (can be specified multiple times).')
+ @click.option('--exclude-dirs', '-e', multiple=True,
+               default=['.venv', '__pycache__', '.git', 'node_modules', '.pytest_cache'],
+               help='Exclude these directories (can be specified multiple times).')
+ @click.option('--extensions', '--ext', default='py',
+               help='Comma-separated list of file extensions to process (e.g., "py,js,ts"). Default: py')
+ @click.option('--exclude-keywords', '--exclude', multiple=True,
+               help='Symbol names to exclude from extraction (can be specified multiple times). '
+                    'These are appended to the default excluded keywords (Python builtins, etc.).')
+ @click.option('--output-nodes', '-n', default=None,
+               help='Output file for nodes (default: .tricoder/nodes.jsonl)')
+ @click.option('--output-edges', '-d', default=None,
+               help='Output file for edges (default: .tricoder/edges.jsonl)')
+ @click.option('--output-types', '-t', default=None,
+               help='Output file for types (default: .tricoder/types.jsonl)')
+ @click.option('--no-gitignore', is_flag=True, default=False,
+               help='Disable .gitignore filtering (enabled by default)')
+ def extract(input_dir, include_dirs, exclude_dirs, extensions, exclude_keywords,
+             output_nodes, output_edges, output_types, no_gitignore):
+     """Extract symbols and relationships from codebase."""
+     from .extract import extract_from_directory
+     from .model import DEFAULT_EXCLUDED_KEYWORDS
+
+     # Use default .tricoder directory if paths not specified
+     tricoder_dir = get_tricoder_dir()
+     output_nodes_path = output_nodes if output_nodes else os.path.join(tricoder_dir, 'nodes.jsonl')
+     output_edges_path = output_edges if output_edges else os.path.join(tricoder_dir, 'edges.jsonl')
+     output_types_path = output_types if output_types else os.path.join(tricoder_dir, 'types.jsonl')
+
+     # Parse extensions: split by comma, strip whitespace, remove leading dots if present
+     ext_list = [ext.strip().lstrip('.') for ext in extensions.split(',') if ext.strip()]
+     if not ext_list:
+         ext_list = ['py']  # Default to Python if empty
+
+     # Build excluded keywords set (default + user-provided)
+     if exclude_keywords:
+         excluded_keywords_set = DEFAULT_EXCLUDED_KEYWORDS | {kw.lower() for kw in exclude_keywords}
+         console.print(f"[dim]Excluding {len(excluded_keywords_set)} symbol names "
+                       f"({len(exclude_keywords)} user-added)[/dim]\n")
+     else:
+         excluded_keywords_set = DEFAULT_EXCLUDED_KEYWORDS
+
+     extract_from_directory(
+         root_dir=input_dir,
+         include_dirs=list(include_dirs) if include_dirs else [],
+         exclude_dirs=list(exclude_dirs) if exclude_dirs else [],
+         extensions=ext_list,
+         excluded_keywords=excluded_keywords_set,
+         output_nodes=output_nodes_path,
+         output_edges=output_edges_path,
+         output_types=output_types_path,
+         use_gitignore=not no_gitignore
+     )
+
+
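# How the --extensions value is normalised (dots and whitespace tolerated),
# mirroring the list comprehension in extract() above:
extensions = 'py, .js ,ts,'
ext_list = [e.strip().lstrip('.') for e in extensions.split(',') if e.strip()]
print(ext_list)  # ['py', 'js', 'ts']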
+ @cli.command(name='optimize')
+ @click.option('--nodes', '-n', default=None, type=click.Path(),
+               help='Path to nodes.jsonl file (default: .tricoder/nodes.jsonl)')
+ @click.option('--edges', '-e', default=None, type=click.Path(),
+               help='Path to edges.jsonl file (default: .tricoder/edges.jsonl)')
+ @click.option('--types', '-t', default=None, type=click.Path(),
+               help='Path to types.jsonl file (default: .tricoder/types.jsonl, optional)')
+ @click.option('--output-nodes', '-N', default=None, type=click.Path(),
+               help='Output path for optimized nodes (default: overwrites input)')
+ @click.option('--output-edges', '-E', default=None, type=click.Path(),
+               help='Output path for optimized edges (default: overwrites input)')
+ @click.option('--output-types', '-T', default=None, type=click.Path(),
+               help='Output path for optimized types (default: overwrites input)')
+ @click.option('--min-edge-weight', default=0.3, type=float,
+               help='Minimum edge weight to keep (default: 0.3)')
+ @click.option('--remove-isolated', is_flag=True, default=True,
+               help='Remove nodes with no edges (default: True)')
+ @click.option('--keep-isolated', is_flag=True, default=False,
+               help='Keep isolated nodes (overrides --remove-isolated)')
+ @click.option('--remove-generic', is_flag=True, default=True,
+               help='Remove nodes with generic names (default: True)')
+ @click.option('--keep-generic', is_flag=True, default=False,
+               help='Keep generic names (overrides --remove-generic)')
+ @click.option('--exclude-keywords', '--exclude', multiple=True,
+               help='Additional keywords to exclude (can be specified multiple times)')
+ def optimize(nodes, edges, types, output_nodes, output_edges, output_types,
+              min_edge_weight, remove_isolated, keep_isolated, remove_generic, keep_generic, exclude_keywords):
+     """Optimize nodes and edges by filtering out low-value entries.
+
+     This command removes:
+     - Nodes with generic names (single letters, common names like 'temp', 'var', etc.)
+     - Isolated nodes (nodes with no edges)
+     - Low-weight edges (below minimum threshold)
+     - Nodes matching excluded keywords
+
+     This reduces the graph size while preserving meaningful relationships.
+     """
+     from .optimize import optimize_nodes_and_edges
+     from .model import DEFAULT_EXCLUDED_KEYWORDS
+
+     # Use default .tricoder directory if paths not specified
+     tricoder_dir = get_tricoder_dir()
+     nodes_path = nodes if nodes else os.path.join(tricoder_dir, 'nodes.jsonl')
+     edges_path = edges if edges else os.path.join(tricoder_dir, 'edges.jsonl')
+     types_path = types if types else os.path.join(tricoder_dir, 'types.jsonl')
+
+     # Default output paths: overwrite input if not specified
+     output_nodes_path = output_nodes if output_nodes else nodes_path
+     output_edges_path = output_edges if output_edges else edges_path
+     output_types_path = output_types if output_types else types_path
+
+     # Build excluded keywords set
+     excluded_keywords_set = DEFAULT_EXCLUDED_KEYWORDS
+     if exclude_keywords:
+         excluded_keywords_set = excluded_keywords_set | {kw.lower() for kw in exclude_keywords}
+
+     # Check if input files exist
+     if not os.path.exists(nodes_path):
+         console.print(f"[bold red]Error: Nodes file not found: {nodes_path}[/bold red]")
+         return
+     if not os.path.exists(edges_path):
+         console.print(f"[bold red]Error: Edges file not found: {edges_path}[/bold red]")
+         return
+
+     # Resolve the flag pairs: --keep-* wins over --remove-*
+     remove_isolated_nodes = remove_isolated and not keep_isolated
+     remove_generic_names = remove_generic and not keep_generic
+
+     console.print("[bold cyan]Optimizing nodes and edges...[/bold cyan]\n")
+     console.print(f"[dim]Min edge weight: {min_edge_weight}[/dim]")
+     console.print(f"[dim]Remove isolated nodes: {remove_isolated_nodes}[/dim]")
+     console.print(f"[dim]Remove generic names: {remove_generic_names}[/dim]")
+     console.print(f"[dim]Excluded keywords: {len(excluded_keywords_set)}[/dim]\n")
+
+     try:
+         nodes_removed, edges_removed, types_removed, stats = optimize_nodes_and_edges(
+             nodes_path=nodes_path,
+             edges_path=edges_path,
+             types_path=types_path if types_path and os.path.exists(types_path) else None,
+             output_nodes=output_nodes_path,
+             output_edges=output_edges_path,
+             output_types=output_types_path,
+             min_edge_weight=min_edge_weight,
+             remove_isolated=remove_isolated_nodes,
+             remove_generic_names=remove_generic_names,
+             excluded_keywords=excluded_keywords_set
+         )
+
+         console.print("\n[bold green]✓ Optimization complete![/bold green]\n")
+
+         # Overall statistics
+         stats_table = Table(title="Optimization Statistics", box=box.ROUNDED, show_header=True)
+         stats_table.add_column("Metric", style="cyan", width=25)
+         stats_table.add_column("Original", style="white", justify="right", width=12)
+         stats_table.add_column("Final", style="green", justify="right", width=12)
+         stats_table.add_column("Removed", style="yellow", justify="right", width=12)
+         stats_table.add_column("Reduction", style="dim", justify="right", width=12)
+
+         # Calculate percentages
+         node_reduction = (nodes_removed / stats['original']['nodes'] * 100) if stats['original']['nodes'] > 0 else 0
+         edge_reduction = (edges_removed / stats['original']['edges'] * 100) if stats['original']['edges'] > 0 else 0
+         type_reduction = (types_removed / stats['original']['types'] * 100) if stats['original']['types'] > 0 else 0
+
+         stats_table.add_row(
+             "Nodes",
+             f"{stats['original']['nodes']:,}",
+             f"{stats['final']['nodes']:,}",
+             f"{stats['removed']['nodes']:,}",
+             f"{node_reduction:.1f}%"
+         )
+         stats_table.add_row(
+             "Edges",
+             f"{stats['original']['edges']:,}",
+             f"{stats['final']['edges']:,}",
+             f"{stats['removed']['edges']:,}",
+             f"{edge_reduction:.1f}%"
+         )
+         if stats['original']['types'] > 0:
+             stats_table.add_row(
+                 "Types",
+                 f"{stats['original']['types']:,}",
+                 f"{stats['final']['types']:,}",
+                 f"{stats['removed']['types']:,}",
+                 f"{type_reduction:.1f}%"
+             )
+
+         console.print(stats_table)
+
+         # Removal reasons
+         console.print("\n[bold cyan]Removal Breakdown:[/bold cyan]")
+         reasons_table = Table(show_header=False, box=None)
+         reasons_table.add_column("Reason", style="dim", width=30)
+         reasons_table.add_column("Count", style="yellow", justify="right", width=15)
+
+         if stats['removal_reasons']['excluded_keywords'] > 0:
+             reasons_table.add_row("Excluded keywords", f"{stats['removal_reasons']['excluded_keywords']:,}")
+         if stats['removal_reasons']['generic_names'] > 0:
+             reasons_table.add_row("Generic names", f"{stats['removal_reasons']['generic_names']:,}")
+         if stats['removal_reasons']['isolated'] > 0:
+             reasons_table.add_row("Isolated nodes", f"{stats['removal_reasons']['isolated']:,}")
+         if stats['removal_reasons']['orphaned_edges'] > 0:
+             reasons_table.add_row("Orphaned edges (node removed)", f"{stats['removal_reasons']['orphaned_edges']:,}")
+         if stats['removal_reasons']['low_weight_edges'] > 0:
+             reasons_table.add_row(f"Low-weight edges (<{min_edge_weight})", f"{stats['removal_reasons']['low_weight_edges']:,}")
+
+         console.print(reasons_table)
+
+         # Statistics by kind
+         console.print("\n[bold cyan]Statistics by Kind:[/bold cyan]")
+         kind_table = Table(show_header=True, box=box.ROUNDED)
+         kind_table.add_column("Kind", style="cyan", width=15)
+         kind_table.add_column("Original", style="white", justify="right", width=12)
+         kind_table.add_column("Removed", style="yellow", justify="right", width=12)
+         kind_table.add_column("Final", style="green", justify="right", width=12)
+
+         for kind in sorted(stats['by_kind'].keys()):
+             kind_stats = stats['by_kind'][kind]
+             if kind_stats['original'] > 0:
+                 kind_table.add_row(
+                     kind,
+                     f"{kind_stats['original']:,}",
+                     f"{kind_stats['removed']:,}",
+                     f"{kind_stats['final']:,}"
+                 )
+
+         console.print(kind_table)
+
+         # Show output paths
+         console.print("\n[dim]Optimized files written to:[/dim]")
+         console.print(f" [dim]Nodes: {output_nodes_path}[/dim]")
+         console.print(f" [dim]Edges: {output_edges_path}[/dim]")
+         if output_types_path and os.path.exists(output_types_path):
+             console.print(f" [dim]Types: {output_types_path}[/dim]")
+
+     except Exception as e:
+         console.print(f"[bold red]Error during optimization: {e}[/bold red]")
+         raise
+
+
+ @cli.command(name='retrain')
+ @click.option('--model-dir', '-m', default=None, type=click.Path(exists=True),
+               help='Path to existing model directory (default: .tricoder/model)')
+ @click.option('--codebase-dir', '-c', default='.', type=click.Path(exists=True, file_okay=False),
+               help='Path to codebase root directory (default: current directory)')
+ @click.option('--output-nodes', '-n', default=None,
+               help='Temporary output file for nodes (default: .tricoder/nodes_retrain.jsonl)')
+ @click.option('--output-edges', '-d', default=None,
+               help='Temporary output file for edges (default: .tricoder/edges_retrain.jsonl)')
+ @click.option('--output-types', '-t', default=None,
+               help='Temporary output file for types (default: .tricoder/types_retrain.jsonl)')
+ @click.option('--graph-dim', default=None, type=int, show_default=False,
+               help='Dimensionality for the graph view embeddings (uses model default if not specified).')
+ @click.option('--context-dim', default=None, type=int, show_default=False,
+               help='Dimensionality for the context view embeddings (uses model default if not specified).')
+ @click.option('--typed-dim', default=None, type=int, show_default=False,
+               help='Dimensionality for the typed view embeddings (uses model default if not specified).')
+ @click.option('--final-dim', default=None, type=int, show_default=False,
+               help='Final dimensionality of fused embeddings (uses model default if not specified).')
+ @click.option('--num-walks', default=10, type=int, show_default=True,
+               help='Number of random walks per node for context view.')
+ @click.option('--walk-length', default=80, type=int, show_default=True,
+               help='Length of each random walk in the context view.')
+ @click.option('--train-ratio', default=0.8, type=float, show_default=True,
+               help='Fraction of edges used for training.')
+ @click.option('--random-state', default=42, type=int, show_default=True,
+               help='Random seed for reproducibility.')
+ @click.option('--force', is_flag=True, default=False,
+               help='Force full retraining even if no files changed.')
+ def retrain(model_dir, codebase_dir, output_nodes, output_edges, output_types,
+             graph_dim, context_dim, typed_dim, final_dim, num_walks, walk_length,
+             train_ratio, random_state, force):
+     """Retrain TriCoder model incrementally on changed files only."""
+     from .git_tracker import load_training_metadata
+     from .extract import extract_from_directory
+
+     # Use default .tricoder directory if paths not specified
+     if model_dir is None:
+         model_dir = get_model_dir()
+     tricoder_dir = get_tricoder_dir()
+     output_nodes_path = output_nodes if output_nodes else os.path.join(tricoder_dir, 'nodes_retrain.jsonl')
+     output_edges_path = output_edges if output_edges else os.path.join(tricoder_dir, 'edges_retrain.jsonl')
+     output_types_path = output_types if output_types else os.path.join(tricoder_dir, 'types_retrain.jsonl')
+
+     console.print("[bold cyan]TriCoder Incremental Retraining[/bold cyan]\n")
+
+     # Load previous training metadata
+     metadata = load_training_metadata(model_dir)
+     if not metadata and not force:
+         console.print("[bold yellow]No previous training metadata found. Use 'train' command for initial training.[/bold yellow]")
+         return
+
+     if metadata:
+         commit = metadata.get('commit_hash')
+         console.print(f"[dim]Previous training: commit {commit[:8] if commit else 'N/A'}[/dim]")
+         console.print(f"[dim]Training timestamp: {metadata.get('training_timestamp', 'N/A')}[/dim]\n")
+
+     # Get changed files
+     if force:
+         console.print("[yellow]Force flag set - retraining on all files[/yellow]\n")
+         changed_files = get_all_python_files(codebase_dir)
+     else:
+         changed_files = get_changed_files_for_retraining(model_dir, codebase_dir)
+
+     if not changed_files:
+         console.print("[bold green]✓ No files changed since last training. Model is up to date![/bold green]")
+         return
+
+     console.print(f"[cyan]Found {len(changed_files)} changed file(s):[/cyan]")
+     for f in sorted(changed_files)[:10]:  # Show first 10
+         console.print(f" [dim]- {f}[/dim]")
+     if len(changed_files) > 10:
+         console.print(f" [dim]... and {len(changed_files) - 10} more[/dim]")
+     console.print()
+
+     # Extract symbols from changed files only
+     console.print("[cyan]Extracting symbols from changed files...[/cyan]")
+     extract_from_directory(
+         root_dir=codebase_dir,
+         include_dirs=[],
+         exclude_dirs=['.venv', '__pycache__', '.git', 'node_modules', '.pytest_cache'],
+         output_nodes=output_nodes_path,
+         output_edges=output_edges_path,
+         output_types=output_types_path,
+         use_gitignore=True
+     )
+
+     # Filter extracted data to only include changed files
+     console.print("[cyan]Filtering extracted data to changed files...[/cyan]")
+     filtered_nodes = output_nodes_path + '.filtered'
+     filtered_edges = output_edges_path + '.filtered'
+     filtered_types = output_types_path + '.filtered'
+
+     # Filter nodes
+     node_count = 0
+     with open(filtered_nodes, 'w') as out_f:
+         with open(output_nodes_path, 'r') as in_f:
+             for line in in_f:
+                 if not line.strip():
+                     continue
+                 try:
+                     data = json.loads(line)
+                     file_path = data.get('meta', {}).get('file', '') if isinstance(data.get('meta'), dict) else ''
+                     # Normalize path for comparison
+                     normalized_path = file_path.replace('\\', '/')
+                     if normalized_path in changed_files or any(
+                             normalized_path.endswith('/' + f) for f in changed_files):
+                         out_f.write(line)
+                         node_count += 1
+                 except json.JSONDecodeError:
+                     continue
+
+     # Filter edges (include if either endpoint is in changed files)
+     edge_count = 0
+     node_ids_from_changed = set()
+     with open(filtered_nodes, 'r') as f:
+         for line in f:
+             if line.strip():
+                 try:
+                     data = json.loads(line)
+                     node_ids_from_changed.add(data.get('id'))
+                 except json.JSONDecodeError:
+                     continue
+
+     with open(filtered_edges, 'w') as out_f:
+         with open(output_edges_path, 'r') as in_f:
+             for line in in_f:
+                 if not line.strip():
+                     continue
+                 try:
+                     data = json.loads(line)
+                     src = data.get('src', '')
+                     dst = data.get('dst', '')
+                     # Include edge if either endpoint is from changed files
+                     if src in node_ids_from_changed or dst in node_ids_from_changed:
+                         out_f.write(line)
+                         edge_count += 1
+                 except json.JSONDecodeError:
+                     continue
+
+     # Filter types
+     type_count = 0
+     if os.path.exists(output_types_path):
+         with open(filtered_types, 'w') as out_f:
+             with open(output_types_path, 'r') as in_f:
+                 for line in in_f:
+                     if not line.strip():
+                         continue
+                     try:
+                         data = json.loads(line)
+                         symbol_id = data.get('symbol', '')
+                         if symbol_id in node_ids_from_changed:
+                             out_f.write(line)
+                             type_count += 1
+                     except json.JSONDecodeError:
+                         continue
+
+     console.print(f"[green]✓ Extracted {node_count} nodes, {edge_count} edges, {type_count} type tokens from changed files[/green]\n")
+
+     if node_count == 0:
+         console.print("[bold yellow]No nodes found in changed files. Nothing to retrain.[/bold yellow]")
+         # Cleanup
+         for f in [output_nodes_path, output_edges_path, output_types_path, filtered_nodes, filtered_edges, filtered_types]:
+             if os.path.exists(f):
+                 os.remove(f)
+         return
+
+     # Retrain the model
+     console.print("[cyan]Retraining model...[/cyan]\n")
+     train_model(
+         nodes_path=filtered_nodes,
+         edges_path=filtered_edges,
+         types_path=filtered_types if os.path.exists(filtered_types) else None,
+         output_dir=model_dir,
+         graph_dim=graph_dim,
+         context_dim=context_dim,
+         typed_dim=typed_dim,
+         final_dim=final_dim,
+         num_walks=num_walks,
+         walk_length=walk_length,
+         train_ratio=train_ratio,
+         random_state=random_state,
+         use_gpu=False  # Retrain doesn't support GPU yet
+     )
+
+     # Save updated git metadata
+     commit_hash = get_git_commit_hash(codebase_dir)
+     commit_timestamp = get_git_commit_timestamp(codebase_dir)
+     all_files = extract_files_from_jsonl(filtered_nodes)
+     all_files.update(extract_files_from_jsonl(filtered_edges))
+     if os.path.exists(filtered_types):
+         all_files.update(extract_files_from_jsonl(filtered_types))
+
+     save_training_metadata(model_dir, commit_hash, commit_timestamp, all_files)
+     console.print(f"[dim]Updated training metadata (commit: {commit_hash[:8] if commit_hash else 'N/A'})[/dim]")
+
+     # Cleanup temporary files
+     console.print("\n[cyan]Cleaning up temporary files...[/cyan]")
+     for f in [output_nodes_path, output_edges_path, output_types_path, filtered_nodes, filtered_edges, filtered_types]:
+         if os.path.exists(f):
+             os.remove(f)
+
+     console.print("[bold green]✓ Incremental retraining complete![/bold green]")
+
+
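# A small sketch of the suffix matching retrain() uses when filtering
# extracted nodes down to the changed-file set: paths are normalised to
# forward slashes and matched either exactly or by a '/'-anchored suffix,
# so absolute extractor paths still match repo-relative entries from git.
changed_files = {'tricoder/cli.py'}

file_path = 'C:\\work\\repo\\tricoder\\cli.py'
normalized = file_path.replace('\\', '/')
hit = normalized in changed_files or any(
    normalized.endswith('/' + f) for f in changed_files)
print(hit)  # True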
+ if __name__ == '__main__':
+     cli()