stackfix 0.2.0

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -0,0 +1,138 @@
1
+ import os
2
+ import re
3
+ import subprocess
4
+ from typing import List, Tuple
5
+
6
+ from .safety import is_forbidden_path
7
+ from .util import is_git_repo
8
+
9
+
10
+ def _extract_paths_from_diff(diff_text: str) -> List[str]:
11
+ paths = []
12
+ for line in diff_text.splitlines():
13
+ if line.startswith("+++ ") or line.startswith("--- "):
14
+ path = line[4:].strip()
15
+ if path.startswith("a/") or path.startswith("b/"):
16
+ path = path[2:]
17
+ if path == "/dev/null":
18
+ continue
19
+ paths.append(path)
20
+ return paths
21
+
22
+
23
+ def _extract_paths_from_begin_patch(diff_text: str) -> List[str]:
24
+ paths = []
25
+ for line in diff_text.splitlines():
26
+ if line.startswith("*** Update File: "):
27
+ paths.append(line.replace("*** Update File: ", "").strip())
28
+ if line.startswith("*** Add File: "):
29
+ paths.append(line.replace("*** Add File: ", "").strip())
30
+ if line.startswith("*** Delete File: "):
31
+ paths.append(line.replace("*** Delete File: ", "").strip())
32
+ return paths
33
+
34
+
35
+ def _is_begin_patch(diff_text: str) -> bool:
36
+ for line in diff_text.splitlines():
37
+ if line.strip():
38
+ return line.startswith("*** Begin Patch")
39
+ return False
40
+
41
+
42
+ def _is_valid_hunk_header(line: str) -> bool:
43
+ return re.match(r"^@@ -\d+(,\d+)? \+\d+(,\d+)? @@", line) is not None
44
+
45
+
46
def _is_valid_unified_diff(diff_text: str) -> bool:
    """Decide whether *diff_text* looks like a git-style unified diff.

    Three conditions must all hold:

    * the text contains ``diff --git ``, ``+++`` and ``---``;
    * it has at least one line starting with ``@@``;
    * every such line is a well-formed hunk header (``_is_valid_hunk_header``).
    """
    required_markers = ("diff --git ", "+++", "---")
    if not all(marker in diff_text for marker in required_markers):
        return False
    hunk_lines = [row for row in diff_text.splitlines() if row.startswith("@@")]
    if not hunk_lines:
        return False
    return all(_is_valid_hunk_header(row) for row in hunk_lines)
56
+
57
+
58
def validate_patch_paths(diff_text: str, cwd: str) -> List[str]:
    """Extract the paths named in *diff_text* and reject forbidden ones.

    Begin-patch text (``_is_begin_patch``) and unified diffs each use their
    own extractor.  Every path is joined onto *cwd*, made absolute and
    passed to ``is_forbidden_path``.

    Returns:
        The extracted paths, as returned by the extractor.

    Raises:
        RuntimeError: if no path is found, or on the first forbidden path.
    """
    if _is_begin_patch(diff_text):
        extractor = _extract_paths_from_begin_patch
    else:
        extractor = _extract_paths_from_diff
    found = extractor(diff_text)
    if not found:
        raise RuntimeError("Patch contains no file paths")
    for candidate in found:
        target = os.path.abspath(os.path.join(cwd, candidate))
        if not is_forbidden_path(target, cwd):
            continue
        raise RuntimeError(f"Patch touches forbidden path: {candidate}")
    return found
70
+
71
+
72
+ def _parse_simple_blocks(diff_text: str) -> Tuple[List[str], List[str]]:
73
+ old_lines: List[str] = []
74
+ new_lines: List[str] = []
75
+ in_hunk = False
76
+ for line in diff_text.splitlines():
77
+ if line.startswith("diff --git ") or line.startswith("--- ") or line.startswith("+++ "):
78
+ in_hunk = True
79
+ continue
80
+ if line.startswith("*** Begin Patch") or line.startswith("*** End Patch"):
81
+ in_hunk = True
82
+ continue
83
+ if line.startswith("*** Update File: ") or line.startswith("*** Add File: ") or line.startswith("*** Delete File: "):
84
+ in_hunk = True
85
+ continue
86
+ if line.startswith("@@"):
87
+ in_hunk = True
88
+ continue
89
+ if not in_hunk:
90
+ continue
91
+ if line.startswith("+") and not line.startswith("+++ "):
92
+ new_lines.append(line[1:])
93
+ continue
94
+ if line.startswith("-") and not line.startswith("--- "):
95
+ old_lines.append(line[1:])
96
+ continue
97
+ if line.startswith(" "):
98
+ old_lines.append(line[1:])
99
+ new_lines.append(line[1:])
100
+ continue
101
+ return old_lines, new_lines
102
+
103
+
104
+ def _apply_simple_replace(path: str, old_lines: List[str], new_lines: List[str]) -> None:
105
+ if not old_lines and not new_lines:
106
+ raise RuntimeError("Fallback patch has no changes to apply")
107
+ old_block = "\n".join(old_lines)
108
+ new_block = "\n".join(new_lines)
109
+ with open(path, "r", encoding="utf-8", errors="replace") as f:
110
+ content = f.read()
111
+ count = content.count(old_block)
112
+ if count != 1:
113
+ raise RuntimeError("Fallback patch failed; old block not found exactly once")
114
+ content = content.replace(old_block, new_block, 1)
115
+ with open(path, "w", encoding="utf-8") as f:
116
+ f.write(content)
117
+
118
+
119
def apply_patch(diff_text: str, cwd: str) -> None:
    """Apply *diff_text* to the files under *cwd*.

    Steps:

    1. Validate paths with ``validate_patch_paths``, which raises on missing
       or forbidden paths.
    2. If the text is a well-formed git-style unified diff, pipe it to
       ``git apply`` in *cwd*.  If that succeeds, return.
    3. Otherwise use a single-file fallback that replaces the patch's old
       text with its new text.  This also covers begin-patch text, a
       ``git apply`` failure, and a ``git`` executable that cannot be
       started.

    Raises:
        RuntimeError: from path validation; when the fallback would have to
            touch more than one file; or when the fallback cannot apply the
            change.
    """
    paths = validate_patch_paths(diff_text, cwd)

    if _is_valid_unified_diff(diff_text):
        if is_git_repo(cwd):
            cmd = ["git", "apply", "--whitespace=nowarn", "-"]
        else:
            cmd = ["git", "apply", "--no-index", "--whitespace=nowarn", "-"]
        # git apply can reject a patch whose last line lacks a newline
        # ("corrupt patch"), which is common in generated patch text.
        patch_input = diff_text if diff_text.endswith("\n") else diff_text + "\n"
        try:
            # Encode explicitly as UTF-8, matching how the fallback reads files,
            # instead of relying on the locale's preferred encoding.
            result = subprocess.run(
                cmd, cwd=cwd, input=patch_input, encoding="utf-8", check=False
            )
        except OSError:
            # e.g. git is not installed: use the fallback rather than crash.
            result = None
        if result is not None and result.returncode == 0:
            return

    if len(paths) != 1:
        raise RuntimeError("Fallback patch only supports single-file edits")

    abs_path = os.path.abspath(os.path.join(cwd, paths[0]))
    old_lines, new_lines = _parse_simple_blocks(diff_text)
    _apply_simple_replace(abs_path, old_lines, new_lines)
@@ -0,0 +1,60 @@
1
+ import os
2
+ import re
3
+ from typing import Iterable
4
+
5
# Exact file names (compared with the path's basename) that may never be patched.
DENYLIST_NAMES = {
    ".env",
    ".env.local",
    ".env.development",
    ".env.production",
    ".env.test",
}

# Basename patterns for key material and certificates, compiled
# case-insensitively and tried with ``re.match`` by is_forbidden_path.
DENYLIST_PATTERNS = [
    re.compile(r".*\.key$", re.IGNORECASE),
    re.compile(r".*\.pem$", re.IGNORECASE),
    re.compile(r".*\.p12$", re.IGNORECASE),
    re.compile(r".*\.pfx$", re.IGNORECASE),
    re.compile(r".*\.crt$", re.IGNORECASE),
    re.compile(r".*\.cer$", re.IGNORECASE),
    re.compile(r".*id_rsa$", re.IGNORECASE),
    re.compile(r".*id_ed25519$", re.IGNORECASE),
    # The remaining OpenSSH default private-key file names (DSA, ECDSA and
    # the security-key "_sk" variants).  Public halves (*.pub) stay allowed.
    re.compile(r".*id_dsa$", re.IGNORECASE),
    re.compile(r".*id_ecdsa$", re.IGNORECASE),
    re.compile(r".*id_ecdsa_sk$", re.IGNORECASE),
    re.compile(r".*id_ed25519_sk$", re.IGNORECASE),
]

# Path components that make a path forbidden wherever they appear in it
# (is_forbidden_path checks every component relative to the workspace).
DENYLIST_DIRS = {
    ".git",
    "node_modules",
    "dist",
    "build",
    "__pycache__",
    ".venv",
    ".env",
    ".ssh",
    ".gnupg",
}
35
+
36
+
37
def _violates_policy(abs_path: str, root: str) -> bool:
    """Return True if absolute *abs_path* lies outside *root* or is denylisted."""
    try:
        rel = os.path.relpath(abs_path, root)
    except ValueError:
        # Windows: paths on different drives have no relative form.
        return True
    parts = rel.split(os.sep)
    # Compare whole components.  A plain string-prefix test would treat
    # "/proj-secrets/x" as being inside "/proj".
    if parts[0] == os.pardir:
        return True
    if any(part in DENYLIST_DIRS for part in parts):
        return True
    name = os.path.basename(abs_path)
    if name in DENYLIST_NAMES:
        return True
    return any(pat.match(name) for pat in DENYLIST_PATTERNS)


def is_forbidden_path(path: str, cwd: str) -> bool:
    """Return True if *path* must not be touched by a patch rooted at *cwd*.

    A path is forbidden when:

    * it lies outside *cwd*;
    * any component relative to *cwd* is in ``DENYLIST_DIRS``;
    * its basename is in ``DENYLIST_NAMES``;
    * its basename matches one of ``DENYLIST_PATTERNS``.

    A relative *path* is resolved against *cwd*, not the process working
    directory.  An absolute *path* is used as-is.  The check runs on both
    the lexical path and the symlink-resolved path, so a link inside the
    workspace that points outside it, or at a denylisted file, is refused.
    """
    joined = os.path.join(cwd, path)
    if _violates_policy(os.path.abspath(joined), os.path.abspath(cwd)):
        return True
    return _violates_policy(os.path.realpath(joined), os.path.realpath(cwd))
57
+
58
+
59
def filter_allowed_paths(paths: Iterable[str], cwd: str) -> list:
    """Return the entries of *paths*, in order, that is_forbidden_path allows."""
    allowed = []
    for candidate in paths:
        if is_forbidden_path(candidate, cwd):
            continue
        allowed.append(candidate)
    return allowed
@@ -0,0 +1,40 @@
1
+ import os
2
+ import json
3
+ from datetime import datetime
4
+ from typing import Dict, Any, Optional
5
+
6
# Session JSON files are stored under this directory, joined onto the
# ``cwd`` passed to the session helpers.
SESSION_DIR = ".stackfix/sessions"
7
+
8
+
9
def _session_dir(cwd: str) -> str:
    """Return ``cwd``/``SESSION_DIR``, creating the directory if it is missing."""
    target = os.path.join(cwd, SESSION_DIR)
    os.makedirs(target, exist_ok=True)  # idempotent: no error if it already exists
    return target
13
+
14
+
15
def new_session_id() -> str:
    """Return a session ID of the form ``YYYYMMDD_HHMMSS`` for the current UTC time.

    The resolution is one second, so two IDs created within the same second
    are equal.  Fixed-width fields make lexicographic order chronological.
    """
    # datetime.utcnow() is deprecated since Python 3.12; use an aware UTC time.
    from datetime import timezone

    return datetime.now(timezone.utc).strftime("%Y%m%d_%H%M%S")
17
+
18
+
19
def save_session(cwd: str, session_id: str, state: Dict[str, Any]) -> str:
    """Write *state* as indented JSON to ``<session_id>.json`` and return its path.

    The file lives in the session directory for *cwd*, which is created if
    needed.  The JSON is rendered in full before anything is written.  It
    goes to a ``.tmp`` sibling that is then moved over the target with
    ``os.replace``.  An unserializable *state* or an interrupted write
    therefore leaves any previously saved session intact.

    Raises:
        TypeError / ValueError: if *state* is not JSON-serializable.
        OSError: if the file cannot be written.
    """
    path = os.path.join(_session_dir(cwd), f"{session_id}.json")
    # json.dump streams chunks into the file, so an unserializable value
    # would leave a truncated file behind; serialize up front instead.
    payload = json.dumps(state, indent=2)
    tmp_path = f"{path}.tmp"
    with open(tmp_path, "w", encoding="utf-8") as f:
        f.write(payload)
    os.replace(tmp_path, path)
    return path
24
+
25
+
26
def load_session(cwd: str, session_id: str) -> Optional[Dict[str, Any]]:
    """Read the saved session *session_id* for *cwd*.

    Returns the parsed JSON, or None when no regular file
    ``<session_id>.json`` exists in the session directory.  The directory
    itself is created if missing.
    """
    candidate = os.path.join(_session_dir(cwd), f"{session_id}.json")
    if os.path.isfile(candidate):
        with open(candidate, "r", encoding="utf-8") as handle:
            return json.load(handle)
    return None
32
+
33
+
34
def list_sessions(cwd: str) -> list:
    """Return the IDs of saved sessions for *cwd*, sorted lexicographically.

    An ID is the name of a ``.json`` entry in the session directory with
    the extension removed.  The directory is created if missing.
    """
    suffix = ".json"
    session_ids = [
        entry[: -len(suffix)]
        for entry in os.listdir(_session_dir(cwd))
        if entry.endswith(suffix)
    ]
    return sorted(session_ids)