diagram-to-iac 1.0.2__py3-none-any.whl → 1.0.4__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -4,13 +4,22 @@ import argparse
4
4
  import json
5
5
  import sys
6
6
  import logging
7
+ import os
7
8
  from datetime import datetime
8
9
  from pathlib import Path
10
+ from typing import Optional
9
11
 
10
12
  from diagram_to_iac.agents.supervisor_langgraph import (
11
13
  SupervisorAgent,
12
14
  SupervisorAgentInput,
13
15
  )
16
+ from diagram_to_iac.agents.supervisor_langgraph.github_listener import (
17
+ GitHubListener,
18
+ RetryContext,
19
+ CommentEvent,
20
+ create_github_listener
21
+ )
22
+ from diagram_to_iac.core.registry import RunRegistry, RunStatus
14
23
  from diagram_to_iac.services import get_log_path, generate_step_summary, reset_log_bus
15
24
 
16
25
 
@@ -34,6 +43,16 @@ def create_argument_parser() -> argparse.ArgumentParser:
34
43
  parser.add_argument("--verbose", action="store_true", help="Enable verbose logging")
35
44
  parser.add_argument("--no-interactive", action="store_true", help="Skip interactive prompts")
36
45
  parser.add_argument("--dry-run", action="store_true", help="Print issue text instead of creating it")
46
+
47
+ # Comment listener options
48
+ parser.add_argument("--listen-comments", action="store_true",
49
+ help="Enable GitHub comment listening for retry commands")
50
+ parser.add_argument("--issue-id", type=int, help="Issue ID to monitor for comments")
51
+ parser.add_argument("--poll-interval", type=int, default=30,
52
+ help="Comment polling interval in seconds (default: 30)")
53
+ parser.add_argument("--max-polls", type=int,
54
+ help="Maximum number of polls (default: infinite)")
55
+
37
56
  return parser
38
57
 
39
58
 
@@ -56,13 +75,158 @@ def format_output(result: object) -> str:
56
75
  return str(result)
57
76
 
58
77
 
78
+ def handle_resume_workflow(context: RetryContext) -> bool:
79
+ """
80
+ Handle resuming an existing workflow.
81
+
82
+ Args:
83
+ context: RetryContext with resumption information
84
+
85
+ Returns:
86
+ True if resumption was successful, False otherwise
87
+ """
88
+ logger = logging.getLogger("supervisor_entry")
89
+
90
+ if not context.existing_run:
91
+ logger.error("No existing run to resume")
92
+ return False
93
+
94
+ try:
95
+ logger.info(f"Resuming run {context.existing_run.run_key}")
96
+
97
+ # Initialize registry and SupervisorAgent
98
+ registry = RunRegistry()
99
+ agent = SupervisorAgent(registry=registry)
100
+
101
+ # Update run status to clear wait reason if PAT is now available
102
+ pat_available = os.getenv('TFE_TOKEN') is not None
103
+ if pat_available and context.existing_run.status == RunStatus.WAITING_FOR_PAT:
104
+ logger.info("PAT token now available, clearing wait reason")
105
+ registry.update(context.existing_run.run_key, {
106
+ 'status': RunStatus.IN_PROGRESS,
107
+ 'wait_reason': None
108
+ })
109
+
110
+ # Resume the workflow from where it left off
111
+ reset_log_bus()
112
+ result = agent.resume_workflow(
113
+ context.existing_run.run_key,
114
+ context.target_sha or context.existing_run.commit_sha
115
+ )
116
+
117
+ logger.info(f"Resume workflow result: {result.success}")
118
+ return result.success
119
+
120
+ except Exception as e:
121
+ logger.error(f"Error resuming workflow: {e}")
122
+ return False
123
+
124
+
125
+ def handle_new_workflow(context: RetryContext) -> bool:
126
+ """
127
+ Handle starting a new workflow for manual retry requests.
128
+
129
+ Args:
130
+ context: RetryContext with new workflow information
131
+
132
+ Returns:
133
+ True if new workflow was started successfully, False otherwise
134
+ """
135
+ logger = logging.getLogger("supervisor_entry")
136
+
137
+ try:
138
+ logger.info(f"Starting new workflow for SHA {context.target_sha[:7] if context.target_sha else 'unknown'}")
139
+
140
+ # Initialize SupervisorAgent
141
+ agent = SupervisorAgent()
142
+
143
+ # Start new workflow
144
+ reset_log_bus()
145
+ result = agent.run(SupervisorAgentInput(
146
+ repo_url=context.comment_event.repo_url,
147
+ branch_name="main", # Placeholder - supervisor handles this
148
+ thread_id=f"retry-{context.comment_event.comment_id}",
149
+ commit_sha=context.target_sha
150
+ ))
151
+
152
+ logger.info(f"New workflow result: {result.success}")
153
+ return result.success
154
+
155
+ except Exception as e:
156
+ logger.error(f"Error starting new workflow: {e}")
157
+ return False
158
+
159
+
160
+ def start_comment_listener(repo_url: str, issue_id: int, poll_interval: int = 30,
161
+ max_polls: Optional[int] = None) -> None:
162
+ """
163
+ Start the GitHub comment listener.
164
+
165
+ Args:
166
+ repo_url: Repository URL to monitor
167
+ issue_id: Issue ID to monitor for comments
168
+ poll_interval: Seconds between polls
169
+ max_polls: Maximum number of polls
170
+ """
171
+ logger = logging.getLogger("supervisor_entry")
172
+
173
+ try:
174
+ # Create GitHub listener with callbacks
175
+ github_token = os.getenv('GITHUB_TOKEN')
176
+ registry = RunRegistry()
177
+ listener = create_github_listener(github_token=github_token, registry=registry)
178
+
179
+ # Set up callbacks
180
+ listener.set_callbacks(
181
+ resume_callback=handle_resume_workflow,
182
+ new_run_callback=handle_new_workflow
183
+ )
184
+
185
+ logger.info(f"Starting comment listener for issue #{issue_id} in {repo_url}")
186
+ logger.info(f"Poll interval: {poll_interval}s, Max polls: {max_polls or 'infinite'}")
187
+
188
+ # Start polling
189
+ listener.poll_issue_comments(
190
+ issue_id=issue_id,
191
+ repo_url=repo_url,
192
+ poll_interval=poll_interval,
193
+ max_polls=max_polls
194
+ )
195
+
196
+ except Exception as e:
197
+ logger.error(f"Error in comment listener: {e}")
198
+ raise
199
+
200
+
59
201
  def main() -> int:
60
202
  parser = create_argument_parser()
61
203
  args = parser.parse_args()
62
204
 
63
205
  setup_logging(args.verbose)
64
206
 
65
- # Handle repo url
207
+ # Handle comment listening mode
208
+ if args.listen_comments:
209
+ if not args.repo_url:
210
+ parser.error("--repo-url is required when using --listen-comments")
211
+ if not args.issue_id:
212
+ parser.error("--issue-id is required when using --listen-comments")
213
+
214
+ try:
215
+ start_comment_listener(
216
+ repo_url=args.repo_url,
217
+ issue_id=args.issue_id,
218
+ poll_interval=args.poll_interval,
219
+ max_polls=args.max_polls
220
+ )
221
+ return 0
222
+ except KeyboardInterrupt:
223
+ print("\n⚠️ Comment listener stopped by user")
224
+ return 0
225
+ except Exception as e:
226
+ logging.error(f"Comment listener failed: {e}")
227
+ return 1
228
+
229
+ # Handle normal workflow mode
66
230
  repo_url = args.repo_url
67
231
  if not repo_url and not args.no_interactive:
68
232
  repo_url = prompt_for_repo_url()
@@ -75,7 +239,6 @@ def main() -> int:
75
239
 
76
240
  agent = SupervisorAgent()
77
241
 
78
-
79
242
  while True:
80
243
  reset_log_bus()
81
244
  result = agent.run(
@@ -85,7 +248,6 @@ def main() -> int:
85
248
  thread_id=args.thread_id,
86
249
  dry_run=args.dry_run,
87
250
  )
88
-
89
251
  )
90
252
 
91
253
  print(format_output(result))
@@ -69,17 +69,15 @@ class GitAgentInput(BaseModel):
69
69
 
70
70
  class GitAgentOutput(BaseModel):
71
71
  """Output schema for GitAgent operations."""
72
- result: str = Field(..., description="The result of the DevOps operation")
73
- thread_id: str = Field(..., description="Thread ID used for the conversation")
74
- repo_path: Optional[str] = Field(None, description="Repository path for clone operations")
75
- error_message: Optional[str] = Field(None, description="Error message if the operation failed")
76
- operation_type: Optional[str] = Field(None, description="Type of operation performed (clone, issue, shell)")
77
- pr_url: Optional[str] = Field(None, description="URL of the created pull request, if any.")
72
+ success: bool = Field(..., description="Indicates if the operation was successful")
73
+ created_pr_id: Optional[int] = Field(None, description="ID of the created pull request, if any")
74
+ pr_url: Optional[str] = Field(None, description="URL of the created pull request, if any")
75
+ created_issue_id: Optional[int] = Field(None, description="ID of the created issue, if any")
76
+ issue_url: Optional[str] = Field(None, description="URL of the created issue, if any")
77
+ summary: Optional[str] = Field(None, description="Summary of the operation result")
78
+ artifacts: Optional[Dict[str, Any]] = Field(None, description="Optional artifacts returned by the operation")
78
79
 
79
- @property
80
- def answer(self) -> str:
81
- """Alias for result to match learning guide tests."""
82
- return self.result
80
+ model_config = {"extra": "ignore"}
83
81
 
84
82
 
85
83
  # --- Agent State Definition ---
@@ -848,7 +846,7 @@ Important: Only use routing tokens if the input contains actionable DevOps reque
848
846
  return "open_issue"
849
847
  elif final_result in ["route_to_shell", "ROUTE_TO_SHELL"]:
850
848
  return "shell_exec"
851
- elif final_result == self.config['routing_keys']['create_pr']:
849
+ elif final_result in ["route_to_create_pr", "ROUTE_TO_CREATE_PR"]:
852
850
  return "create_pr_node"
853
851
  elif final_result in ["route_to_end", "ROUTE_TO_END"]:
854
852
  return END
@@ -959,13 +957,37 @@ Important: Only use routing tokens if the input contains actionable DevOps reque
959
957
  result=final_result,
960
958
  )
961
959
 
960
+ # Determine success based on whether there was an error
961
+ success = error_message is None or error_message == ""
962
+
963
+ # Extract PR/Issue IDs from URLs if available
964
+ created_pr_id = None
965
+ created_issue_id = None
966
+ if pr_url:
967
+ # Extract PR ID from GitHub URL pattern
968
+ import re
969
+ pr_match = re.search(r'/pull/(\d+)', pr_url)
970
+ if pr_match:
971
+ created_pr_id = int(pr_match.group(1))
972
+ else:
973
+ # Check if it's an issue URL
974
+ issue_match = re.search(r'/issues/(\d+)', pr_url)
975
+ if issue_match:
976
+ created_issue_id = int(issue_match.group(1))
977
+
962
978
  output = GitAgentOutput(
963
- result=final_result,
964
- thread_id=current_thread_id,
965
- repo_path=repo_path,
966
- error_message=error_message,
967
- operation_type=operation_type,
968
- pr_url=pr_url
979
+ success=success,
980
+ created_pr_id=created_pr_id,
981
+ pr_url=pr_url if created_pr_id else None,
982
+ created_issue_id=created_issue_id,
983
+ issue_url=pr_url if created_issue_id else None,
984
+ summary=final_result,
985
+ artifacts={
986
+ "thread_id": current_thread_id,
987
+ "repo_path": repo_path,
988
+ "operation_type": operation_type,
989
+ "error_message": error_message
990
+ }
969
991
  )
970
992
  return output
971
993
 
@@ -977,12 +999,18 @@ Important: Only use routing tokens if the input contains actionable DevOps reque
977
999
  error=str(e),
978
1000
  )
979
1001
  return GitAgentOutput(
980
- result="An unexpected error occurred during execution.",
981
- thread_id=current_thread_id,
982
- repo_path=None, # Or more specifically result_state.get("repo_path") if available
983
- error_message=str(e),
984
- operation_type="error",
985
- pr_url=None
1002
+ success=False,
1003
+ created_pr_id=None,
1004
+ pr_url=None,
1005
+ created_issue_id=None,
1006
+ issue_url=None,
1007
+ summary="An unexpected error occurred during execution.",
1008
+ artifacts={
1009
+ "thread_id": current_thread_id,
1010
+ "repo_path": None, # Or more specifically result_state.get("repo_path") if available
1011
+ "operation_type": "error",
1012
+ "error_message": str(e)
1013
+ }
986
1014
  )
987
1015
 
988
1016
  def get_conversation_history(self) -> List[Dict[str, Any]]:
@@ -1042,11 +1070,15 @@ Important: Only use routing tokens if the input contains actionable DevOps reque
1042
1070
  if isinstance(result, GitAgentOutput):
1043
1071
  report = {
1044
1072
  "status": "completed",
1045
- "result": result.result,
1046
- "thread_id": result.thread_id,
1047
- "error": result.error_message,
1048
- "operation_type": result.operation_type,
1049
- "success": result.error_message is None
1073
+ "result": result.summary,
1074
+ "thread_id": result.artifacts.get("thread_id") if result.artifacts else None,
1075
+ "error": result.artifacts.get("error_message") if result.artifacts else None,
1076
+ "operation_type": result.artifacts.get("operation_type") if result.artifacts else None,
1077
+ "success": result.success,
1078
+ "pr_url": result.pr_url,
1079
+ "issue_url": result.issue_url,
1080
+ "created_pr_id": result.created_pr_id,
1081
+ "created_issue_id": result.created_issue_id
1050
1082
  }
1051
1083
  elif isinstance(result, str):
1052
1084
  report = {