mcp-github-agent 0.1.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
tests/test_audit.py ADDED
@@ -0,0 +1,181 @@
1
+ """Tests for audit.py — JSONL logging & redaction."""
2
+ import io
3
+ import json
4
+ import os
5
+ from datetime import datetime
6
+ import tempfile
7
+ import unittest
8
+ from unittest.mock import Mock, patch
9
+
10
+ from src.audit import AuditLogger
11
+
12
+
13
class TestAuditLogger(unittest.TestCase):
    """Exercise AuditLogger's JSONL output, redaction and error suppression."""

    # ------------------------------------------------------------------
    # helpers
    # ------------------------------------------------------------------
    @staticmethod
    def _remove_if_present(path: str) -> None:
        """Delete *path* if it still exists (used as a test cleanup)."""
        if os.path.exists(path):
            os.unlink(path)

    def _make_sink(self) -> str:
        """Create an empty temporary ``.jsonl`` file and return its path.

        ``tempfile.mkstemp`` creates the file atomically, unlike the
        deprecated ``tempfile.mktemp`` which only returns a name that another
        process could claim first.  Removal is registered as a cleanup so the
        file is deleted even when the test fails part-way through.
        """
        fd, path = tempfile.mkstemp(suffix=".jsonl")
        os.close(fd)
        self.addCleanup(self._remove_if_present, path)
        return path

    def _open_logger(self, path: str) -> AuditLogger:
        """Build an AuditLogger writing to *path* and schedule its close.

        Registering the close only after construction succeeds means a
        failing constructor surfaces its own exception instead of a
        ``NameError`` from a ``finally`` block.  Cleanups run last-in
        first-out, so the logger is closed before the sink file is removed.
        """
        logger = AuditLogger(sink=path)
        self.addCleanup(logger.close)
        return logger

    @staticmethod
    def _read_first_entry(path: str) -> dict:
        """Return the first JSON line of *path* decoded into a dict."""
        with open(path, "r", encoding="utf-8") as f:
            return json.loads(f.readline())

    # ------------------------------------------------------------------
    # tests
    # ------------------------------------------------------------------
    def test_log_to_file(self):
        path = self._make_sink()
        logger = self._open_logger(path)
        logger.log(
            tool="create_pr", action="pull_request.create",
            repo="owner/repo",
            policy_decision="allow", policy_rule="repo_allowlist:owner/*",
            request_body={"title": "fix bug", "head": "patch", "base": "main"},
            response={"number": 42, "html_url": "https://github.com/owner/repo/pull/42"},
        )
        logger.close()

        entry = self._read_first_entry(path)

        self.assertEqual(entry["tool"], "create_pr")
        self.assertEqual(entry["repo"], "owner/repo")
        self.assertEqual(entry["policy"]["decision"], "allow")
        self.assertIn("matched_rule", entry["policy"])
        self.assertEqual(entry["request"]["title"], "fix bug")
        self.assertEqual(entry["response"]["number"], 42)
        self.assertEqual(entry["dry_run"], False)
        self.assertIn("timestamp", entry)
        self.assertIn("request_id", entry)

    def test_log_dry_run(self):
        path = self._make_sink()
        logger = self._open_logger(path)
        logger.log(
            tool="create_issue", action="issue.create",
            repo="owner/repo", dry_run=True,
            policy_decision="allow", policy_rule="default-allow",
            request_body={"title": "test", "body": "test body"},
        )
        logger.close()

        entry = self._read_first_entry(path)

        self.assertTrue(entry["dry_run"])
        self.assertIsNone(entry["response"])

    def test_log_denied(self):
        path = self._make_sink()
        logger = self._open_logger(path)
        logger.log(
            tool="create_pr", action="pull_request.create",
            repo="forbidden/repo",
            policy_decision="deny",
            policy_rule="repo_allowlist:deny_unlisted",
            error="repo forbidden/repo not in allowlist",
        )
        logger.close()

        entry = self._read_first_entry(path)

        self.assertEqual(entry["policy"]["decision"], "deny")
        self.assertEqual(entry["error"], "repo forbidden/repo not in allowlist")

    def test_redaction(self):
        """Sensitive keys are redacted in audit logs."""
        path = self._make_sink()
        logger = self._open_logger(path)
        logger.log(
            tool="create_pr", action="pull_request.create",
            repo="owner/repo",
            request_body={
                "title": "test",
                "token": "ghp_should_be_redacted",
                "password": "secret123",
            },
            response={
                "status": 201,
                "authorization": "bearer xyz",
            },
        )
        logger.close()

        entry = self._read_first_entry(path)

        self.assertEqual(entry["request"]["token"], "***REDACTED***")
        self.assertEqual(entry["request"]["password"], "***REDACTED***")
        self.assertEqual(entry["response"]["authorization"], "***REDACTED***")

    def test_log_format_redacts_nested_values_and_truncates_long_strings(self):
        stream = io.StringIO()
        with patch("src.audit.sys.stdout", stream):
            logger = AuditLogger(sink="stdout")
            logger.log(
                tool="create_issue",
                action="issue.create",
                repo="owner/repo",
                request_body={
                    "nested": {"api_key": "secret"},
                    "items": [{"auth": "bearer"}],
                    "body": "x" * 250,
                },
            )

        entry = json.loads(stream.getvalue())
        # Raises ValueError (failing the test) if the timestamp is not ISO-8601.
        datetime.fromisoformat(entry["timestamp"])
        self.assertEqual(len(entry["request_id"]), 12)
        self.assertEqual(entry["request"]["nested"]["api_key"], "***REDACTED***")
        self.assertEqual(entry["request"]["items"][0]["auth"], "***REDACTED***")
        self.assertEqual(entry["request"]["body"], "x" * 200 + "...")

    def test_multiple_operations_write_json_lines(self):
        path = self._make_sink()
        logger = self._open_logger(path)
        logger.log("tool1", "action1", "owner/repo")
        logger.log("tool2", "action2", "owner/repo", error="failed")
        logger.close()

        with open(path, "r", encoding="utf-8") as f:
            entries = [json.loads(line) for line in f]

        self.assertEqual([e["tool"] for e in entries], ["tool1", "tool2"])
        self.assertEqual(entries[1]["error"], "failed")

    def test_invalid_relative_sink_and_allowlist_rejection(self):
        with self.assertRaises(ValueError):
            AuditLogger(sink="relative.jsonl")

        with patch.dict(os.environ, {"GITHUB_AUDIT_DIR_ALLOWLIST": "/allowed"}, clear=False):
            with self.assertRaises(ValueError):
                AuditLogger(sink="/tmp/audit.jsonl")

    def test_logging_errors_are_suppressed(self):
        logger = AuditLogger(sink="stdout")
        broken = Mock()
        broken.write.side_effect = RuntimeError("disk full")

        # log() must not propagate the stream's RuntimeError.
        with patch.object(logger, "_get_stream", return_value=broken):
            logger.log("tool", "action", "owner/repo")

        broken.write.assert_called_once()
178
+
179
+
180
+ if __name__ == "__main__":
181
+ unittest.main()
tests/test_config.py ADDED
@@ -0,0 +1,41 @@
1
+ """Tests for config.py."""
2
+ import pytest
3
+
4
+ from src import config
5
+
6
+
7
def test_get_github_token_reads_env(monkeypatch):
    """The token is taken from the GITHUB_TOKEN environment variable."""
    expected_token = "ghp_token"
    monkeypatch.setenv("GITHUB_TOKEN", expected_token)
    actual_token = config.get_github_token()
    assert actual_token == expected_token
10
+
11
+
12
def test_get_github_token_raises_when_missing(monkeypatch):
    """A missing GITHUB_TOKEN surfaces as ValueError with a clear message."""
    monkeypatch.delenv("GITHUB_TOKEN", raising=False)
    with pytest.raises(ValueError) as caught:
        config.get_github_token()
    # ExceptionInfo.match performs the same regex search as raises(match=...).
    caught.match("GITHUB_TOKEN not set")
16
+
17
+
18
def test_config_defaults_and_boolean_envs(monkeypatch):
    """Getters return defaults when unset and honour environment overrides."""
    overrides = {
        "GITHUB_API_BASE": "https://api.example",
        "GITHUB_POLICY_PATH": "/tmp/policy.json",
        "GITHUB_POLICY_REQUIRED": "yes",
        "GITHUB_AUDIT_LOG": "stderr",
        "GITHUB_DRY_RUN": "1",
    }

    # Phase 1: with every variable removed, the defaults apply.
    for env_name in overrides:
        monkeypatch.delenv(env_name, raising=False)

    assert config.get_github_api_base() == "https://api.github.com"
    assert config.get_policy_path() == "policy.json"
    assert config.get_policy_required() is False
    assert config.get_audit_sink() == "stdout"
    assert config.get_dry_run_enabled() is False

    # Phase 2: explicit values win, and truthy strings become True.
    for env_name, env_value in overrides.items():
        monkeypatch.setenv(env_name, env_value)

    assert config.get_github_api_base() == "https://api.example"
    assert config.get_policy_path() == "/tmp/policy.json"
    assert config.get_policy_required() is True
    assert config.get_audit_sink() == "stderr"
    assert config.get_dry_run_enabled() is True
@@ -0,0 +1,74 @@
1
+ """Tests for diff_parser.py"""
2
+ from src.diff_parser import parse_diff
3
+
4
+
5
def test_parse_diff_single_file():
    """A one-file diff yields one entry that records the added line number."""
    diff_text = "\n".join([
        "diff --git a/src/app.py b/src/app.py",
        "index 123..456 100644",
        "--- a/src/app.py",
        "+++ b/src/app.py",
        "@@ -1,3 +1,4 @@",
        " def hello():",
        '+    print("debug")',
        "     return True",
    ])
    parsed = parse_diff(diff_text)
    assert len(parsed) == 1
    first = parsed[0]
    assert first.path == "src/app.py"
    assert 2 in first.added_lines  # line 2: print("debug")
18
+
19
+
20
def test_parse_diff_multiple_files():
    """Each `diff --git` section becomes its own parsed file entry."""
    diff_text = "\n".join([
        "diff --git a/a.py b/a.py",
        "--- a/a.py",
        "+++ b/a.py",
        "@@ -1,1 +1,2 @@",
        "+import os",
        "diff --git a/b.py b/b.py",
        "--- a/b.py",
        "+++ b/b.py",
        "@@ -3,0 +4,1 @@",
        '+logger.info("start")',
    ])
    parsed = parse_diff(diff_text)
    assert len(parsed) == 2
    seen_paths = {entry.path for entry in parsed}
    assert seen_paths == {"a.py", "b.py"}
34
+
35
+
36
def test_parse_diff_no_added_lines():
    """A one-line replacement still counts as a single added line.

    Despite the test's name, the hunk below does contain an addition: the
    replacement line ``print("new")``.
    """
    diff_text = "\n".join([
        "diff --git a/x.py b/x.py",
        "--- a/x.py",
        "+++ b/x.py",
        "@@ -1,1 +1,1 @@",
        '-print("old")',
        '+print("new")',
    ])
    parsed = parse_diff(diff_text)
    assert len(parsed) == 1
    added = parsed[0].added_lines
    assert len(added) == 1  # "new" is an addition
46
+
47
+
48
def test_parse_diff_tracks_removed_lines_without_returning_delete_only_file():
    """A file whose hunk only removes lines is omitted from the result."""
    diff_text = "\n".join([
        "diff --git a/x.py b/x.py",
        "--- a/x.py",
        "+++ b/x.py",
        "@@ -10,2 +10,1 @@",
        " context",
        "-old",
        "",  # the diff text ends with a trailing newline
    ])
    parsed = parse_diff(diff_text)
    assert parsed == []
58
+
59
+
60
def test_parse_diff_line_numbers_advance_over_context_and_removals():
    """Added and removed line numbers are tracked across a mixed hunk."""
    diff_text = "\n".join([
        "diff --git a/src/app.py b/src/app.py",
        "--- a/src/app.py",
        "+++ b/src/app.py",
        "@@ -20,4 +20,5 @@",
        " line 20",
        "-old line",
        "+new line",
        " line 22",
        "+another new line",
        "",  # the diff text ends with a trailing newline
    ])
    parsed = parse_diff(diff_text)
    assert len(parsed) == 1
    changed = parsed[0]
    assert changed.added_lines == {21, 23}
    assert changed.removed_lines == {21}
@@ -0,0 +1,173 @@
1
+ """Tests for github_client.py."""
2
+ from types import SimpleNamespace
3
+ from unittest.mock import Mock, patch
4
+
5
+ from src.github_client import GitHubClient
6
+
7
+
8
class FakeHttpxClient:
    """Minimal stand-in for ``httpx.Client`` used as a context manager.

    Returns a canned ``response`` from ``get``/``post`` and records every
    call's positional and keyword arguments.  When ``error`` is truthy it is
    raised on entering the ``with`` block instead.
    """

    def __init__(self, response=None, error=None):
        self.get_calls = []
        self.post_calls = []
        self.error = error
        self.response = response

    def __enter__(self):
        if not self.error:
            return self
        raise self.error

    def __exit__(self, exc_type, exc, tb):
        # Returning False lets any exception from the with-body propagate.
        return False

    def get(self, *args, **kwargs):
        """Record the call and hand back the canned response."""
        recorded = (args, kwargs)
        self.get_calls.append(recorded)
        return self.response

    def post(self, *args, **kwargs):
        """Record the call and hand back the canned response."""
        recorded = (args, kwargs)
        self.post_calls.append(recorded)
        return self.response
30
+
31
+
32
def test_client_initializes_headers_and_base_url():
    """The constructor strips the trailing slash and sets the auth headers."""
    gh = GitHubClient("token", "https://api.example/")
    headers = gh.headers
    assert gh.base_url == "https://api.example"
    assert headers["Authorization"] == "Bearer token"
    assert headers["Accept"] == "application/vnd.github+json"
37
+
38
+
39
def test_search_code_formats_items_and_query_scope():
    """search_code reshapes API items and scopes the query to the repo."""
    blob_url = "https://example/blob/src/app.py"
    reply = Mock()
    reply.json.return_value = {
        "items": [
            {
                "path": "src/app.py",
                "repository": {"full_name": "owner/repo"},
                "html_url": blob_url,
            }
        ]
    }
    transport = FakeHttpxClient(response=reply)

    with patch("src.github_client.httpx.Client", return_value=transport):
        gh = GitHubClient("token", "https://api.example")
        outcome = gh.search_code("def main", "owner/repo")

    expected = {
        "items": [
            {"path": "src/app.py", "repo": "owner/repo", "url": blob_url}
        ]
    }
    assert outcome == expected

    call_args, call_kwargs = transport.get_calls[0]
    assert call_args[0] == "https://api.example/search/code"
    assert call_kwargs["params"] == {"q": "repo:owner/repo def main"}
    reply.raise_for_status.assert_called_once()
68
+
69
+
70
def test_list_issues_success_and_error():
    """list_issues returns the JSON payload, or an error dict on failure."""
    client_target = "src.github_client.httpx.Client"
    reply = Mock()
    reply.json.return_value = [{"number": 1}]
    transport = FakeHttpxClient(response=reply)

    # Success path: the payload is returned and the state filter forwarded.
    with patch(client_target, return_value=transport):
        issues = GitHubClient("token").list_issues("owner/repo", "closed")

    assert issues == [{"number": 1}]
    _, first_kwargs = transport.get_calls[0]
    assert first_kwargs["params"] == {"state": "closed"}

    # Failure path: an HTTP error is converted into an error dictionary.
    reply.raise_for_status.side_effect = RuntimeError("401 Unauthorized")
    failing_transport = FakeHttpxClient(response=reply)
    with patch(client_target, return_value=failing_transport):
        failure = GitHubClient("token").list_issues("owner/repo")
    assert failure["error"].startswith("List issues failed: 401 Unauthorized")
85
+
86
+
87
def test_create_issue_posts_payload_and_handles_error():
    """create_issue posts title/body and reports transport failures."""
    client_target = "src.github_client.httpx.Client"
    reply = Mock()
    reply.json.return_value = {"number": 2}
    transport = FakeHttpxClient(response=reply)

    with patch(client_target, return_value=transport):
        created = GitHubClient("token").create_issue("owner/repo", "Title", "Body")

    assert created == {"number": 2}
    _, post_kwargs = transport.post_calls[0]
    assert post_kwargs["json"] == {"title": "Title", "body": "Body"}

    # A client that fails on entering the context simulates an outage.
    unreachable = FakeHttpxClient(error=RuntimeError("down"))
    with patch(client_target, return_value=unreachable):
        failure = GitHubClient("token").create_issue("owner/repo", "Title", "Body")
    assert failure == {"error": "Create issue failed: down"}
101
+
102
+
103
def test_get_pr_diff_uses_diff_accept_header_and_handles_error():
    """get_pr_diff requests the diff media type and wraps HTTP errors."""
    client_target = "src.github_client.httpx.Client"
    reply = Mock()
    reply.text = "diff text"
    transport = FakeHttpxClient(response=reply)

    with patch(client_target, return_value=transport):
        gh = GitHubClient("token", "https://api.example")
        fetched = gh.get_pr_diff("owner/repo", 3)

    assert fetched == {"diff": "diff text"}
    call_args, call_kwargs = transport.get_calls[0]
    assert call_args[0] == "https://api.example/repos/owner/repo/pulls/3"
    assert call_kwargs["headers"]["Accept"] == "application/vnd.github.v3.diff"

    reply.raise_for_status.side_effect = RuntimeError("404")
    failing_transport = FakeHttpxClient(response=reply)
    with patch(client_target, return_value=failing_transport):
        failure = GitHubClient("token").get_pr_diff("owner/repo", 3)
    assert failure == {"error": "Get PR diff failed: 404"}
120
+
121
+
122
def test_create_pr_posts_payload_and_handles_error():
    """create_pr posts the full PR payload and reports transport failures."""
    client_target = "src.github_client.httpx.Client"
    reply = Mock()
    reply.json.return_value = {"number": 4}
    transport = FakeHttpxClient(response=reply)

    with patch(client_target, return_value=transport):
        gh = GitHubClient("token")
        created = gh.create_pr("owner/repo", "Title", "Body", "feature", "main")

    assert created == {"number": 4}
    expected_payload = {
        "title": "Title",
        "body": "Body",
        "head": "feature",
        "base": "main",
    }
    _, post_kwargs = transport.post_calls[0]
    assert post_kwargs["json"] == expected_payload

    forbidden = FakeHttpxClient(error=RuntimeError("403"))
    with patch(client_target, return_value=forbidden):
        failure = GitHubClient("token").create_pr("owner/repo", "Title", "Body", "h", "b")
    assert failure == {"error": "Create PR failed: 403"}
143
+
144
+
145
def test_create_review_comment_payload_variants_and_error():
    """create_review_comment sends optional fields only when supplied."""
    client_target = "src.github_client.httpx.Client"

    # Variant 1: commit, path and line are all included in the payload.
    full_reply = Mock()
    full_reply.json.return_value = {"id": 10}
    full_transport = FakeHttpxClient(response=full_reply)

    with patch(client_target, return_value=full_transport):
        created = GitHubClient("token").create_review_comment(
            "owner/repo", 5, "body", commit_id="abc", path="src/app.py", line=7
        )

    assert created == {"id": 10}
    _, full_kwargs = full_transport.post_calls[0]
    assert full_kwargs["json"] == {
        "body": "body",
        "commit_id": "abc",
        "path": "src/app.py",
        "line": 7,
    }

    # Variant 2: with only a body, nothing else is sent.
    bare_reply = Mock()
    bare_reply.json.return_value = {"id": 11}
    bare_transport = FakeHttpxClient(response=bare_reply)
    with patch(client_target, return_value=bare_transport):
        GitHubClient("token").create_review_comment("owner/repo", 5, "body")
    _, bare_kwargs = bare_transport.post_calls[0]
    assert bare_kwargs["json"] == {"body": "body"}

    # Variant 3: an HTTP error is converted into an error dictionary.
    bare_reply.raise_for_status.side_effect = RuntimeError("validation")
    failing_transport = FakeHttpxClient(response=bare_reply)
    with patch(client_target, return_value=failing_transport):
        failure = GitHubClient("token").create_review_comment("owner/repo", 5, "body")
    assert failure == {"error": "Review comment failed: validation"}
tests/test_main.py ADDED
@@ -0,0 +1,48 @@
1
+ """Tests for main.py."""
2
+ from unittest.mock import patch
3
+
4
+ import pytest
5
+
6
+ from src import main as main_module
7
+
8
+
9
def test_main_runs_mcp_stdio_transport():
    """main() registers one signal handler and runs the server over stdio."""
    with patch("src.main.signal.signal") as signal_spy:
        with patch.object(main_module.mcp, "run") as run_spy:
            main_module.main()

    signal_spy.assert_called_once()
    run_spy.assert_called_once_with(transport="stdio")
16
+
17
+
18
def test_main_exits_cleanly_on_keyboard_interrupt():
    """Ctrl-C while the server is running results in ``sys.exit(0)``.

    ``signal.signal`` is patched as well: ``main()`` registers a handler on
    start-up, and without the patch that real handler would stay installed in
    the test-runner process after this test finishes.
    """
    with patch("src.main.signal.signal"), \
            patch.object(main_module.mcp, "run", side_effect=KeyboardInterrupt), \
            patch("src.main.sys.exit") as mock_exit:
        main_module.main()

    mock_exit.assert_called_once_with(0)
24
+
25
+
26
def test_main_prints_and_exits_on_startup_error(capsys):
    """A start-up failure is reported on stderr and exits with status 1.

    ``signal.signal`` is patched as well: ``main()`` registers a handler on
    start-up, and without the patch that real handler would stay installed in
    the test-runner process after this test finishes.
    """
    with patch("src.main.signal.signal"), \
            patch.object(main_module.mcp, "run", side_effect=RuntimeError("boom")), \
            patch("src.main.sys.exit") as mock_exit:
        main_module.main()

    assert "Error starting MCP server: boom" in capsys.readouterr().err
    mock_exit.assert_called_once_with(1)
33
+
34
+
35
def test_sigterm_handler_exits_cleanly():
    """The handler main() registers calls ``sys.exit(0)`` when invoked."""
    registered = []

    def remember_handler(_signum, handler):
        # Stand-in for signal.signal: keep the callback instead of installing it.
        registered.append(handler)

    with patch("src.main.signal.signal", side_effect=remember_handler):
        with patch.object(main_module.mcp, "run"):
            with patch("src.main.sys.exit", side_effect=SystemExit(0)) as exit_spy:
                main_module.main()
                with pytest.raises(SystemExit):
                    registered[-1](15, None)

    exit_spy.assert_called_once_with(0)
tests/test_policy.py ADDED
@@ -0,0 +1,133 @@
1
+ """Tests for policy.py — allowlist & branch protection."""
2
+ import json
3
+ import os
4
+ import tempfile
5
+ import unittest
6
+
7
+ from src.policy import PolicyConfig, resolve_dry_run
8
+
9
+
10
class TestPolicyConfig(unittest.TestCase):
    """Exercise PolicyConfig loading, repo allowlists and branch protection."""

    def _write_policy(self, data: dict) -> str:
        """Write a temporary policy.json and return its path.

        The file is closed by the ``with`` block even if serialisation fails,
        and its removal is registered as a cleanup so individual tests need
        no ``try``/``finally``.
        """
        with tempfile.NamedTemporaryFile(
            mode="w", suffix=".json", delete=False, encoding="utf-8"
        ) as f:
            # Register removal before writing so a failing dump cannot leak
            # the file.
            self.addCleanup(os.unlink, f.name)
            json.dump(data, f)
        return f.name

    def test_empty_config_default_allow(self):
        cfg = PolicyConfig()
        # not loaded → allows everything
        self.assertEqual(cfg.check_repo("anything/here").action, "allow")
        self.assertEqual(cfg.check_branch_for_pr("main").action, "allow")

    def test_load_missing_file_not_required(self):
        cfg = PolicyConfig().load("/nonexistent/path.json", required=False)
        # no crash, default allow
        self.assertEqual(cfg.check_repo("x/y").action, "allow")

    def test_load_missing_file_required_raises(self):
        with self.assertRaises(FileNotFoundError):
            PolicyConfig().load("/nonexistent/path.json", required=True)

    def test_repo_allowlist_exact_match(self):
        path = self._write_policy({"repo_allowlist": ["FMorgan-111/test"]})
        cfg = PolicyConfig().load(path)
        self.assertEqual(cfg.check_repo("FMorgan-111/test").action, "allow")
        self.assertEqual(cfg.check_repo("other/repo").action, "deny")

    def test_repo_allowlist_wildcard(self):
        path = self._write_policy({"repo_allowlist": ["FMorgan-111/*"]})
        cfg = PolicyConfig().load(path)
        self.assertEqual(cfg.check_repo("FMorgan-111/foo").action, "allow")
        self.assertEqual(cfg.check_repo("FMorgan-111/bar").action, "allow")
        self.assertEqual(cfg.check_repo("other/repo").action, "deny")

    def test_wildcard_match_is_anchored(self):
        path = self._write_policy({"repo_allowlist": ["org/*-service"]})
        cfg = PolicyConfig().load(path)
        self.assertEqual(cfg.check_repo("org/api-service").action, "allow")
        # Extra text before or after the pattern must not match.
        self.assertEqual(cfg.check_repo("prefix/org/api-service").action, "deny")
        self.assertEqual(cfg.check_repo("org/api-service-extra").action, "deny")

    def test_single_string_policy_values_are_wrapped(self):
        path = self._write_policy({
            "repo_allowlist": "owner/repo",
            "protected_branches": {"deny_pr_base": "release/*"},
        })
        cfg = PolicyConfig().load(path)
        self.assertEqual(cfg.repo_allowlist, ["owner/repo"])
        self.assertEqual(cfg.deny_pr_base, ["release/*"])
        self.assertEqual(cfg.check_branch_for_pr("release/1.0").action, "deny")

    def test_invalid_policy_required_raises_and_optional_keeps_defaults(self):
        with tempfile.NamedTemporaryFile(
            mode="w", suffix=".json", delete=False, encoding="utf-8"
        ) as f:
            self.addCleanup(os.unlink, f.name)
            f.write("{invalid")

        with self.assertRaises(RuntimeError):
            PolicyConfig().load(f.name, required=True)
        cfg = PolicyConfig().load(f.name, required=False)
        # NOTE(review): the test name says "keeps defaults", yet an unloaded
        # config allows every repo (see test_empty_config_default_allow) while
        # this asserts "deny". Presumably a failed parse fails closed —
        # confirm against PolicyConfig.load.
        self.assertEqual(cfg.check_repo("owner/repo").action, "deny")

    def test_non_list_values_are_wrapped(self):
        path = self._write_policy({
            "repo_allowlist": 123,
            "protected_branches": {"deny_force_push": False},
        })
        cfg = PolicyConfig().load(path)
        self.assertEqual(cfg.repo_allowlist, [123])
        self.assertFalse(cfg.deny_force_push)

    def test_allowlist_empty_means_allow_all(self):
        path = self._write_policy({"repo_allowlist": []})
        cfg = PolicyConfig().load(path)
        self.assertEqual(cfg.check_repo("anything/goes").action, "allow")

    def test_branch_protection_deny_main_master(self):
        path = self._write_policy({
            "repo_allowlist": [],
            "protected_branches": {"deny_pr_base": ["main", "master"]},
        })
        cfg = PolicyConfig().load(path)
        self.assertEqual(cfg.check_branch_for_pr("main").action, "deny")
        self.assertEqual(cfg.check_branch_for_pr("master").action, "deny")
        self.assertEqual(cfg.check_branch_for_pr("develop").action, "allow")

    def test_resolve_dry_run(self):
        # explicit arg wins
        self.assertTrue(resolve_dry_run(True, False))
        self.assertFalse(resolve_dry_run(False, True))
        # None falls back to env
        self.assertTrue(resolve_dry_run(None, True))
        self.assertFalse(resolve_dry_run(None, False))
130
+
131
+
132
+ if __name__ == "__main__":
133
+ unittest.main()