octotui 0.1.1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -0,0 +1,1010 @@
1
+ from pathlib import Path
2
+ from typing import Dict, List, Optional, Tuple, Set, FrozenSet, Iterable
3
+ import re
4
+ import tempfile
5
+ import os
6
+ import git
7
+ from dataclasses import dataclass
8
+ from datetime import datetime
9
+ import time
10
+ from collections import defaultdict
11
+
12
+ # Import for backward compatibility with existing code
13
+
14
+
15
@dataclass
class Hunk:
    """One diff hunk: its ``@@`` header line plus the body lines under it.

    Attributes:
        header: The hunk header line. A single trailing newline is removed
            when the instance is created.
        lines: The hunk body lines, each still carrying its diff prefix.
    """

    header: str
    lines: List[str]

    def __post_init__(self):
        # Drop exactly one trailing newline; any other whitespace is kept.
        if self.header[-1:] == "\n":
            self.header = self.header[: len(self.header) - 1]
26
+
27
+
28
@dataclass
class CommitInfo:
    """A commit summary used by the history view.

    Attributes:
        sha: Commit identifier as supplied by the caller.
        message: Commit message; one trailing newline is removed on creation.
        author: Author display name.
        date: Commit timestamp.
    """

    sha: str
    message: str
    author: str
    date: datetime

    def __post_init__(self):
        text = self.message
        # Only a single trailing newline is stripped.
        if text[-1:] == "\n":
            self.message = text[:-1]
41
+
42
+
43
class GitStatusSidebar:
    """Manages git repository status and file tree display.

    Wraps a GitPython ``Repo`` and exposes status queries, diff-hunk
    extraction and staging / commit / branch / remote operations. Query
    results are cached for ``cache_ttl`` seconds. The mutating methods of
    this class invalidate that cache, so a caller never reads state from
    before its own change. Changes made outside this object (an editor, a
    shell) become visible once the cache entry expires.

    When no repository could be opened, ``self.repo`` is None and every
    method returns an empty / failure result instead of raising.
    """

    # Unified-diff hunk header. The ",count" parts are optional (git omits
    # them when the count is 1). Text after the closing "@@" is git's
    # section heading.
    _HUNK_HEADER_RE = re.compile(r"@@ -(\d+)(,\d+)? \+(\d+)(,\d+)? @@(.*)")

    def __init__(self, repo_path: Optional[str] = None, cache_ttl: float = 5.0):
        """Initialize the git status sidebar.

        Args:
            repo_path: Path inside the git repository; parent directories are
                searched. If None, uses the current directory.
            cache_ttl: Seconds a cached query result stays valid.
        """
        try:
            self.repo = git.Repo(repo_path or ".", search_parent_directories=True)
            self.repo_path = Path(self.repo.working_dir)
        except Exception:
            # InvalidGitRepositoryError, NoSuchPathError, bare repo, ...
            self.repo = None
            self.repo_path = Path("")

        # Cache for expensive git operations: key -> value / monotonic stamp.
        self._cache: Dict[str, object] = {}
        self._cache_timestamps: Dict[str, float] = {}
        self._cache_ttl = cache_ttl

        # Track which files were affected by recent operations.
        self._recently_modified_files: Set[str] = set()

    # ------------------------------------------------------------------ cache

    def _get_cache_key(self, method_name: str, *args) -> str:
        """Generate the cache key ``"<method>:<arg1>:<arg2>..."``."""
        return f"{method_name}:{':'.join(map(str, args))}"

    def _is_cache_valid(self, cache_key: str) -> bool:
        """Check whether a cache entry exists and is younger than the TTL."""
        stamp = self._cache_timestamps.get(cache_key)
        if stamp is None:
            return False
        # Monotonic clock: immune to wall-clock jumps (NTP, manual changes).
        return (time.monotonic() - stamp) < self._cache_ttl

    def _get_cached(self, cache_key: str):
        """Return the cached value if still valid, else None."""
        if self._is_cache_valid(cache_key):
            return self._cache.get(cache_key)
        return None

    def _set_cache(self, cache_key: str, value):
        """Store ``value`` under ``cache_key`` stamped with the current time."""
        self._cache[cache_key] = value
        self._cache_timestamps[cache_key] = time.monotonic()

    def _invalidate_cache(self, pattern: Optional[str] = None):
        """Drop cache entries whose key contains ``pattern``, or all if None."""
        if pattern is None:
            self._cache.clear()
            self._cache_timestamps.clear()
            return
        for key in [k for k in self._cache if pattern in k]:
            self._cache.pop(key, None)
            self._cache_timestamps.pop(key, None)

    def _mark_file_modified(self, file_path: str):
        """Record ``file_path`` as changed and drop caches that depend on it."""
        self._recently_modified_files.add(file_path)
        for method_name in (
            "get_file_statuses",
            "get_files_with_unstaged_changes",
            "get_staged_files",
            "get_file_tree",
        ):
            self._invalidate_cache(method_name)
        # Diff hunks for this specific file (staged and unstaged variants).
        self._invalidate_cache(f"get_diff_hunks:{file_path}")

    def get_recently_modified_files(self) -> set:
        """Get and clear the set of recently modified files."""
        modified_files = self._recently_modified_files.copy()
        self._recently_modified_files.clear()
        return modified_files

    def has_recent_modifications(self) -> bool:
        """Check if there are any recent modifications."""
        return bool(self._recently_modified_files)

    # ----------------------------------------------------------------- status

    def _staged_paths(self) -> List[str]:
        """Return paths whose index entry differs from HEAD.

        On an unborn branch (no commits yet) there is no HEAD tree to diff
        against, so every path in the index counts as staged.

        Raises:
            Whatever GitPython raises on failure; callers decide how to degrade.
        """
        if not self.repo.head.is_valid():
            # index.entries is keyed by (path, stage); dedupe conflict stages.
            return sorted({path for path, _stage in self.repo.index.entries})
        return [diff.b_path for diff in self.repo.index.diff("HEAD")]

    def get_file_statuses(self) -> Dict[str, FrozenSet[str]]:
        """Get git status flags for files in the repository.

        Returns:
            Dictionary mapping file paths to frozen sets of git status flags.
            Flags include "staged", "modified", and "untracked". Empty when
            not in a repository or when any git query fails.
        """
        if not self.repo:
            return {}

        cache_key = self._get_cache_key("get_file_statuses")
        cached_result = self._get_cached(cache_key)
        if cached_result is not None:
            return dict(cached_result)

        statuses: Dict[str, Set[str]] = defaultdict(set)
        try:
            # Staged changes (index vs HEAD)
            for path in self._staged_paths():
                statuses[path].add("staged")
            # Unstaged changes (working tree vs index)
            for diff in self.repo.index.diff(None):
                statuses[diff.b_path].add("modified")
            # Untracked files
            for path in self.repo.untracked_files:
                statuses[path].add("untracked")
        except Exception:
            # Return an empty dict on error, but don't cache it.
            return {}

        frozen_statuses: Dict[str, FrozenSet[str]] = {
            path: frozenset(flags) for path, flags in statuses.items()
        }
        self._set_cache(cache_key, frozen_statuses)
        return dict(frozen_statuses)

    def get_staged_files(self) -> List[str]:
        """Get list of staged files in the repository.

        Returns:
            List of file paths that are staged (empty on error).
        """
        if not self.repo:
            return []

        cache_key = self._get_cache_key("get_staged_files")
        cached_result = self._get_cached(cache_key)
        if cached_result is not None:
            # Copy so callers cannot mutate the cached list.
            return list(cached_result)

        try:
            result = self._staged_paths()
        except Exception:
            return []
        self._set_cache(cache_key, result)
        return list(result)

    def get_unstaged_files(self) -> List[str]:
        """Get a list of unstaged (modified) files; never cached.

        Returns:
            List of file paths that are modified but not staged (empty on error).
        """
        if not self.repo:
            return []

        try:
            return [diff.b_path for diff in self.repo.index.diff(None)]
        except Exception:
            return []

    def _resolve_primary_status(self, status_flags: Iterable[str]) -> str:
        """Pick a single status that best represents the file for tree display.

        Priority is staged > modified > untracked; no flags means "unchanged".
        """
        flags = set(status_flags)
        for candidate in ("staged", "modified", "untracked"):
            if candidate in flags:
                return candidate
        return "unchanged"

    @staticmethod
    def _empty_file_data() -> Dict[str, object]:
        """The ``collect_file_data`` result used outside a repo or on error."""
        return {
            "files": [],
            "directories": set(),
            "staged_files": [],
            "unstaged_files": [],
            "untracked_files": [],
        }

    def collect_file_data(self) -> Dict[str, object]:
        """Collect consolidated file and directory data for minimal git calls.

        Returns:
            Dict containing:
            - files: List of tuples (file_path, primary git status)
            - directories: Set of directory paths (from HEAD plus the parents
              of every listed file)
            - staged_files: List of staged file paths
            - unstaged_files: List of modified *or untracked* file paths
            - untracked_files: List of untracked file paths
            Paths containing a ``.git`` component are excluded from files and
            directories.
        """
        if not self.repo:
            return self._empty_file_data()

        try:
            statuses = self.get_file_statuses()

            # -z: NUL-separated, unquoted output, so paths with spaces or
            # non-ASCII characters match the keys in ``statuses``.
            tracked_files = [
                p for p in self.repo.git.ls_files("-z").split("\0") if p
            ]
            tracked_set = set(tracked_files)
            files = [
                (f, self._resolve_primary_status(statuses.get(f, frozenset())))
                for f in tracked_files
            ]

            # Untracked files are not part of ls-files; add them explicitly.
            untracked_files = [
                f for f, flags in statuses.items() if "untracked" in flags
            ]
            files.extend(
                (f, "untracked") for f in untracked_files if f not in tracked_set
            )

            # Every directory in HEAD (-r recurses; -d keeps only trees).
            try:
                ls_tree_dirs = self.repo.git.ls_tree(
                    "--full-tree", "-r", "-d", "--name-only", "-z", "HEAD"
                )
                directories = {d for d in ls_tree_dirs.split("\0") if d}
            except Exception:
                # e.g. no commits yet, so HEAD does not resolve.
                directories = set()

            # Parent directories of every listed file, so directories that
            # only contain new or untracked files still appear.
            for path, _status in files:
                parts = path.split("/")[:-1]
                for depth in range(1, len(parts) + 1):
                    directories.add("/".join(parts[:depth]))

            # Always ensure .git paths are excluded no matter what.
            files = [(f, s) for f, s in files if ".git" not in f.split("/")]
            directories = {d for d in directories if ".git" not in d.split("/")}

            return {
                "files": files,
                "directories": directories,
                "staged_files": [
                    f for f, flags in statuses.items() if "staged" in flags
                ],
                "unstaged_files": [
                    f
                    for f, flags in statuses.items()
                    if "modified" in flags or "untracked" in flags
                ],
                "untracked_files": untracked_files,
            }
        except Exception:
            return self._empty_file_data()

    def get_file_tree(self) -> List[Tuple[str, str, str]]:
        """Get a flattened list of all files and directories with git status.

        Returns:
            List of tuples (path, file_type, git_status) where file_type is
            "file" or "directory" and git_status is "staged", "modified",
            "untracked", or "unchanged" (always "unchanged" for directories).
            Directories sort first, then entries sort by path.
        """
        file_data = self.collect_file_data()
        file_entries = [
            (f_path, "file", status) for f_path, status in file_data["files"]
        ]
        dir_entries = [
            (d_path, "directory", "unchanged") for d_path in file_data["directories"]
        ]
        return sorted(
            file_entries + dir_entries, key=lambda x: (x[1] != "directory", x[0])
        )

    # ------------------------------------------------------------------ diffs

    def get_diff_hunks(
        self, file_path: str, staged: bool = False, filter_whitespace: bool = True
    ) -> List["Hunk"]:
        """Get diff hunks for a specific file.

        Args:
            file_path: Path to the file relative to repository root.
            staged: Whether to get the staged (index vs HEAD) diff.
            filter_whitespace: For ``.md`` files, hide whitespace-only line
                edits. The filtered list has the same number of hunks as the
                raw one, so indices can be used with ``stage_hunk`` etc.

        Returns:
            List of Hunk objects (empty on error). With no diff, an untracked
            file yields one all-added hunk, and an unchanged file yields one
            header-less hunk holding its content.
        """
        if not self.repo:
            return []

        cache_key = self._get_cache_key(
            "get_diff_hunks", file_path, staged, filter_whitespace
        )
        cached_result = self._get_cached(cache_key)
        if cached_result is not None:
            return list(cached_result)

        try:
            diff_args = ["--cached", "--", file_path] if staged else ["--", file_path]
            diff = self.repo.git.diff(*diff_args)
            if diff:
                result = self._parse_diff_into_hunks(diff)
                if filter_whitespace and file_path.endswith(".md"):
                    result = self._filter_whitespace_hunks(result)
            elif staged:
                result = []
            else:
                result = self._hunks_without_diff(file_path)

            self._set_cache(cache_key, result)
            return list(result)
        except Exception:
            return []

    def _hunks_without_diff(self, file_path: str) -> List["Hunk"]:
        """Build display hunks for a file that ``git diff`` reports nothing for.

        Raises:
            OSError / UnicodeDecodeError when the file cannot be read as UTF-8.
        """
        status = self.get_file_status(file_path)
        if status not in ("untracked", "unchanged"):
            return []
        content = (self.repo_path / file_path).read_text(encoding="utf-8")
        if status == "untracked":
            lines = ["+" + line for line in content.splitlines()]
            return [Hunk("@@ -0,0 +1," + str(len(lines)) + " @@", lines)]
        return [Hunk("", content.splitlines())]

    def _is_whitespace_only_change(self, old_line: str, new_line: str) -> bool:
        """Check if a change only differs in leading/trailing whitespace.

        Args:
            old_line: The original line (without diff prefix).
            new_line: The new line (without diff prefix).

        Returns:
            True if both lines are equal after ``strip()``. (An earlier
            markdown-bullet special case could only succeed when the stripped
            lines were already equal, so it added nothing beyond this check.)
        """
        return old_line.strip() == new_line.strip()

    def _filter_whitespace_hunks(self, hunks: List["Hunk"]) -> List["Hunk"]:
        """Drop whitespace-only edit pairs from each hunk, for display only.

        A "-" line immediately followed by a "+" line is treated as an edit
        pair; the pair is removed when only surrounding whitespace differs.
        Header line counts are not adjusted, so the result must not be turned
        into a patch.

        Args:
            hunks: List of hunks to filter.

        Returns:
            One filtered hunk per input hunk, in the same order.
        """
        filtered_hunks = []
        for hunk in hunks:
            lines = hunk.lines
            kept: List[str] = []
            i = 0
            while i < len(lines):
                line = lines[i]
                following = lines[i + 1] if i + 1 < len(lines) else ""
                # Compare only the first character: content may itself start
                # with "-" or "+".
                if line[:1] == "-" and following[:1] == "+":
                    if not self._is_whitespace_only_change(line[1:], following[1:]):
                        kept.extend((line, following))
                    i += 2
                else:
                    # Context lines and unpaired additions/removals.
                    kept.append(line)
                    i += 1
            # Keep every hunk, even if no +/- lines survive, so indices stay
            # aligned with the unfiltered diff used for staging.
            filtered_hunks.append(Hunk(header=hunk.header, lines=kept))
        return filtered_hunks

    def _parse_diff_into_hunks(self, diff: str) -> List["Hunk"]:
        """Parse a unified diff string into hunks.

        Lines are split on "\\n" only. ``str.splitlines`` would also split on
        "\\r", form feeds and other separators, which corrupts CRLF (or
        form-feed) content when the hunk is turned back into a patch.

        Args:
            diff: Unified diff string.

        Returns:
            List of Hunk objects. File-header lines before the first ``@@``
            are skipped. If the diff has no ``@@`` header at all (e.g. a
            binary diff), all lines are returned as one header-less hunk.
        """
        if diff.endswith("\n"):
            diff = diff[:-1]
        lines = diff.split("\n") if diff else []

        hunks = []
        current_hunk_lines: List[str] = []
        current_hunk_header = ""

        for line in lines:
            if line.startswith("@@"):
                if current_hunk_header and current_hunk_lines:
                    hunks.append(
                        Hunk(header=current_hunk_header, lines=current_hunk_lines)
                    )
                current_hunk_lines = []
                current_hunk_header = line
            elif current_hunk_header:
                current_hunk_lines.append(line)

        # Don't forget the last hunk.
        if current_hunk_header and current_hunk_lines:
            hunks.append(Hunk(header=current_hunk_header, lines=current_hunk_lines))

        if not hunks and lines:
            hunks.append(Hunk(header="", lines=lines))

        return hunks

    def get_file_status(self, file_path: str) -> str:
        """Get the git status of a specific file (uncached).

        Args:
            file_path: Relative path to the file.

        Returns:
            "staged", "modified", "untracked", or "unchanged", checked in that
            order. A check that fails is skipped rather than raising.
        """
        if not self.repo:
            return "unchanged"

        checks = (
            ("staged", self._staged_paths),
            ("modified", self.get_unstaged_files),
            ("untracked", lambda: self.repo.untracked_files),
        )
        for status, list_paths in checks:
            try:
                if file_path in list_paths():
                    return status
            except Exception:
                continue
        return "unchanged"

    # ------------------------------------------------------- file operations

    def stage_file(self, file_path: str) -> bool:
        """Stage a file, including a deletion of a tracked file.

        Uses ``git add`` rather than GitPython's ``index.add``. This way
        removed files can be staged and git's clean filters (e.g. LFS) run.

        Args:
            file_path: Path to the file relative to repository root.

        Returns:
            True if successful, False otherwise.
        """
        if not self.repo:
            return False

        try:
            self.repo.git.add("--", file_path)
        except Exception:
            return False
        self._mark_file_modified(file_path)
        return True

    def unstage_file(self, file_path: str) -> bool:
        """Unstage a file from the index (remove all its staged changes).

        Uses ``git restore --staged``, which is safer for partials. Falls back
        to ``git reset`` when restore is unavailable (git < 2.23) or fails,
        e.g. on a branch with no commits yet.
        """
        if not self.repo:
            return False
        try:
            self.repo.git.restore("--staged", "--", file_path)
        except git.GitCommandError:
            try:
                self.repo.git.reset("-q", "--", file_path)
            except Exception:
                return False
        except Exception:
            return False
        self._mark_file_modified(file_path)
        return True

    def unstage_file_all(self, file_path: str) -> bool:
        """Unstage all changes for a file (alias of ``unstage_file``)."""
        return self.unstage_file(file_path)

    def discard_file_changes(self, file_path: str) -> bool:
        """Discard unstaged changes to a file (``git checkout -- path``).

        Args:
            file_path: Path to the file relative to repository root.

        Returns:
            True if successful, False otherwise.
        """
        if not self.repo:
            return False

        try:
            self.repo.git.checkout("--", file_path)
        except Exception:
            return False
        self._mark_file_modified(file_path)
        return True

    # ------------------------------------------------------ history/branches

    def get_commit_history(self, max_count: Optional[int] = None) -> List["CommitInfo"]:
        """Get commit history reachable from HEAD, newest first.

        Args:
            max_count: Maximum number of commits to return; None for all.

        Returns:
            List of CommitInfo objects (empty on error).
        """
        if not self.repo:
            return []

        rev_list_options = {} if max_count is None else {"max_count": max_count}
        try:
            return [
                CommitInfo(
                    sha=commit.hexsha[:8],  # Short SHA
                    message=commit.message.strip(),
                    author=commit.author.name,
                    date=commit.committed_datetime,
                )
                for commit in self.repo.iter_commits("HEAD", **rev_list_options)
            ]
        except Exception:
            return []

    def get_current_branch(self) -> str:
        """Get the current branch name.

        Returns:
            Current branch name, or 'unknown' outside a repo or on detached HEAD.
        """
        if not self.repo:
            return "unknown"

        try:
            return self.repo.active_branch.name
        except Exception:
            return "unknown"

    def commit_staged_changes(self, message: str) -> bool:
        """Commit staged changes.

        Args:
            message: Commit message.

        Returns:
            True if successful, False otherwise.
        """
        if not self.repo:
            return False

        try:
            self.repo.index.commit(message)
        except Exception:
            return False
        # Staged set, statuses and diffs all change after a commit.
        self._invalidate_cache()
        return True

    def get_all_branches(self) -> List[str]:
        """Get all local branch names in the repository.

        Reads ``refs/heads`` directly. Unlike parsing ``git branch`` output,
        this does not pick up a "(HEAD detached at ...)" pseudo-entry or
        current/worktree markers.

        Returns:
            List of branch names sorted by ref name (empty on error).
        """
        if not self.repo:
            return []

        prefix = "refs/heads/"
        try:
            output = self.repo.git.for_each_ref("--format=%(refname)", "refs/heads")
            return [
                ref[len(prefix):] for ref in output.splitlines() if ref.startswith(prefix)
            ]
        except Exception:
            try:
                return [head.name for head in self.repo.heads]
            except Exception:
                return []

    def _get_remote_and_branch(self) -> Tuple[str, str]:
        """Resolve the remote/branch pair for push and pull operations.

        Uses the active branch's upstream if configured. Otherwise uses
        ``origin`` (or the first remote) and the local branch name.

        Raises:
            ValueError: outside a repository, when HEAD cannot be inspected,
                on a detached HEAD, or when no remote is configured.
        """
        if not self.repo:
            raise ValueError("Not inside a git repository")

        try:
            detached = self.repo.head.is_detached
        except Exception as err:
            raise ValueError("Unable to determine HEAD state") from err
        # Raised outside the try so this specific message is not replaced by
        # the generic one above.
        if detached:
            raise ValueError("Detached HEAD state; cannot infer branch")

        active_branch = self.repo.active_branch
        tracking_branch = active_branch.tracking_branch()

        if tracking_branch is not None:
            remote_name = tracking_branch.remote_name
            branch_name = tracking_branch.remote_head or active_branch.name
        else:
            remote_name = "origin"
            if remote_name not in self.repo.remotes:
                if not self.repo.remotes:
                    raise ValueError("No remotes configured")
                remote_name = self.repo.remotes[0].name
            branch_name = active_branch.name

        return remote_name, branch_name

    def push_current_branch(self) -> Tuple[bool, str]:
        """Push the current branch to its remote tracking branch."""
        if not self.repo:
            return False, "Not inside a git repository"

        try:
            remote_name, branch_name = self._get_remote_and_branch()
            local_name = self.repo.active_branch.name
            # Explicit local:remote refspec. The upstream may be named
            # differently, and `git push <remote> <name>` resolves <name> as
            # a *local* ref.
            self.repo.git.push(remote_name, f"{local_name}:{branch_name}")
            return True, f"Pushed {branch_name} to {remote_name}"
        except ValueError as err:
            return False, str(err)
        except git.GitCommandError as err:
            return False, f"Git push failed: {err}"
        except Exception as err:
            return False, f"Unexpected push failure: {err}"

    def pull_current_branch(self) -> Tuple[bool, str]:
        """Pull the latest changes for the current branch."""
        if not self.repo:
            return False, "Not inside a git repository"

        try:
            remote_name, branch_name = self._get_remote_and_branch()
            self.repo.git.pull(remote_name, branch_name)
            return True, f"Pulled {branch_name} from {remote_name}"
        except ValueError as err:
            return False, str(err)
        except git.GitCommandError as err:
            return False, f"Git pull failed: {err}"
        except Exception as err:
            return False, f"Unexpected pull failure: {err}"
        finally:
            # Even a failed pull can leave a partial merge in the worktree.
            self._invalidate_cache()

    def is_dirty(self) -> bool:
        """Check if the repository has modified or staged changes (fresh, uncached).

        Untracked files do not count.

        Returns:
            True if repository is dirty, False otherwise (including on error).
        """
        if not self.repo:
            return False

        try:
            return bool(self._staged_paths()) or bool(self.get_unstaged_files())
        except Exception:
            return False

    def switch_branch(self, branch_name: str) -> bool:
        """Switch to a different branch; refused while the repo is dirty.

        Args:
            branch_name: Name of the branch to switch to.

        Returns:
            True if successful, False otherwise.
        """
        if not self.repo or self.is_dirty():
            return False

        try:
            self.repo.git.checkout(branch_name)
        except Exception:
            return False
        self._invalidate_cache()
        return True

    # ----------------------------------------------------------------- hunks

    def _reverse_hunk_header(self, header: str) -> str:
        """Swap the old and new ranges of a ``@@ -a[,b] +c[,d] @@`` header.

        Omitted counts and any trailing section heading are preserved.
        Headers that don't match are returned unchanged.
        """
        match = self._HUNK_HEADER_RE.match(header)
        if not match:
            return header
        old_start, old_len, new_start, new_len, tail = match.groups()
        return f"@@ -{new_start}{new_len or ''} +{old_start}{old_len or ''} @@{tail}"

    def _create_patch_from_hunk(
        self, file_path: str, hunk: "Hunk", reverse: bool = False
    ) -> str:
        """Build a single-hunk unified diff for ``file_path``.

        Args:
            file_path: Path relative to the repository root.
            hunk: Hunk whose header and lines form the patch body.
            reverse: Swap +/- lines and the header ranges.

        Returns:
            Patch text ending with a newline.
        """
        diff_header = f"--- a/{file_path}\n+++ b/{file_path}\n"

        if reverse:
            swapped = {"+": "-", "-": "+"}
            body = [
                swapped[line[0]] + line[1:] if line[:1] in swapped else line
                for line in hunk.lines
            ]
            lines = [self._reverse_hunk_header(hunk.header)] + body
        else:
            lines = [hunk.header] + hunk.lines

        return diff_header + "\n".join(lines) + "\n"

    def _apply_patch(
        self,
        patch: str,
        cached: bool = False,
        reverse: bool = False,
        index: bool = False,
    ) -> bool:
        """Apply ``patch`` with ``git apply``.

        Args:
            patch: Unified diff text.
            cached: Apply to the index only (``--cached``).
            reverse: Apply in reverse (``-R``).
            index: Apply to both index and working tree (``--index``).

        Returns:
            True on success; False (after printing the reason) on failure.
        """
        if not self.repo:
            return False

        args = []
        if reverse:
            args.append("-R")
        if cached:
            args.append("--cached")
        if index:
            args.append("--index")

        fd, tmp_path = tempfile.mkstemp(suffix=".patch")
        try:
            # Write UTF-8 regardless of locale, and with newline="" so "\n" is
            # not translated to "\r\n" on Windows (which would corrupt it).
            with os.fdopen(
                fd, "w", encoding="utf-8", errors="surrogateescape", newline=""
            ) as tmp:
                tmp.write(patch)
            self.repo.git.apply(*args, tmp_path)
            return True
        except git.GitCommandError as e:
            error_msg = str(e)
            if "error: patch failed" in error_msg:
                print(f"Patch failed: {error_msg}")
            elif "error: unable to write" in error_msg:
                print(f"Unable to write patch: {error_msg}")
            else:
                print(f"Git command error applying patch: {error_msg}")
            return False
        except Exception as e:
            print(f"Unexpected error applying patch: {e}")
            return False
        finally:
            os.unlink(tmp_path)

    def stage_hunk(self, file_path: str, hunk_index: int) -> bool:
        """Stage one unstaged hunk of ``file_path``.

        ``hunk_index`` indexes the list from ``get_diff_hunks``. The patch is
        built from the *unfiltered* hunk at that index, which stays aligned
        because whitespace filtering never drops hunks.
        """
        if not self.repo:
            return False
        try:
            if file_path in self.repo.untracked_files:
                # Its only hunk is the synthetic whole-file one, and
                # `git apply --cached` cannot create a path missing from the
                # index, so staging it means adding the file.
                return hunk_index == 0 and self.stage_file(file_path)
            hunks = self.get_diff_hunks(file_path, staged=False, filter_whitespace=False)
            if not 0 <= hunk_index < len(hunks):
                return False
            patch = self._create_patch_from_hunk(file_path, hunks[hunk_index])
            success = self._apply_patch(patch, cached=True)
            if success:
                self._mark_file_modified(file_path)
            return success
        except Exception as e:
            print(f"Error in stage_hunk: {e}")
            return False

    def unstage_hunk(self, file_path: str, hunk_index: int) -> bool:
        """Remove one staged hunk from the index (reverse-apply with --cached)."""
        if not self.repo:
            return False
        try:
            hunks = self.get_diff_hunks(file_path, staged=True, filter_whitespace=False)
            if not 0 <= hunk_index < len(hunks):
                return False
            patch = self._create_patch_from_hunk(file_path, hunks[hunk_index])
            success = self._apply_patch(patch, cached=True, reverse=True)
            if success:
                self._mark_file_modified(file_path)
            return success
        except Exception as e:
            print(f"Error in unstage_hunk: {e}")
            return False

    def discard_hunk(self, file_path: str, hunk_index: int) -> bool:
        """Revert one unstaged hunk in the working tree.

        Applies the forward patch with ``git apply -R`` so git itself
        computes the reversed ranges.
        """
        if not self.repo:
            return False
        try:
            hunks = self.get_diff_hunks(file_path, staged=False, filter_whitespace=False)
            if not 0 <= hunk_index < len(hunks):
                return False
            patch = self._create_patch_from_hunk(file_path, hunks[hunk_index])
            success = self._apply_patch(patch, reverse=True)
            if success:
                self._mark_file_modified(file_path)
            return success
        except Exception as e:
            print(f"Error in discard_hunk: {e}")
            return False

    # ------------------------------------------------------------ bulk ops

    def stage_all_changes(self) -> Tuple[bool, str]:
        """Stage all modified, deleted and untracked (non-ignored) files.

        Returns:
            Tuple of (success, message).
        """
        if not self.repo:
            return False, "Not in a git repository"

        try:
            files_to_stage = self.get_unstaged_files() + list(self.repo.untracked_files)
            if not files_to_stage:
                return True, "No changes to stage"

            # One `git add --all` stages modifications, deletions and new
            # files together, without putting every untracked path on the
            # command line (which can exceed OS argument-length limits).
            self.repo.git.add("--all")

            self._recently_modified_files.update(files_to_stage)
            self._invalidate_cache()
            return True, f"Staged {len(files_to_stage)} files"
        except Exception as e:
            return False, f"Failed to stage all changes: {str(e)}"

    def unstage_all_changes(self) -> Tuple[bool, str]:
        """Unstage all staged changes in the repository (``git reset``).

        Returns:
            Tuple of (success, message).
        """
        if not self.repo:
            return False, "Not in a git repository"

        try:
            staged_files = self.get_staged_files()
            if not staged_files:
                return True, "No staged changes to unstage"

            self.repo.git.reset("--")

            self._recently_modified_files.update(staged_files)
            self._invalidate_cache()
            return True, f"Unstaged {len(staged_files)} files"
        except Exception as e:
            return False, f"Failed to unstage all changes: {str(e)}"

    # ------------------------------------------------------------ GAC helpers

    def get_git_status(self) -> str:
        """Get ``git status`` output as a string for GAC ("" on error)."""
        if not self.repo:
            return ""

        try:
            return self.repo.git.status()
        except Exception:
            return ""

    def get_staged_diff(self) -> str:
        """Get the staged changes diff for GAC ("" on error)."""
        if not self.repo:
            return ""

        try:
            return self.repo.git.diff("--cached")
        except Exception:
            return ""

    def get_full_diff(self) -> str:
        """Get the full diff (staged + unstaged) for GAC.

        Returns:
            Both diffs under "# Staged changes:" / "# Unstaged changes:"
            headings when both exist, otherwise whichever is non-empty, or ""
            (also on error).
        """
        if not self.repo:
            return ""

        try:
            staged_diff = self.repo.git.diff("--cached")
            unstaged_diff = self.repo.git.diff()
        except Exception:
            return ""

        if staged_diff and unstaged_diff:
            return f"# Staged changes:\n{staged_diff}\n\n# Unstaged changes:\n{unstaged_diff}"
        return staged_diff or unstaged_diff or ""