alita-sdk 0.3.161__py3-none-any.whl → 0.3.163__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- alita_sdk/tools/__init__.py +1 -0
- alita_sdk/tools/carrier/api_wrapper.py +74 -2
- alita_sdk/tools/carrier/cancel_ui_test_tool.py +178 -0
- alita_sdk/tools/carrier/carrier_sdk.py +71 -3
- alita_sdk/tools/carrier/create_ui_excel_report_tool.py +473 -0
- alita_sdk/tools/carrier/create_ui_test_tool.py +199 -0
- alita_sdk/tools/carrier/lighthouse_excel_reporter.py +155 -0
- alita_sdk/tools/carrier/run_ui_test_tool.py +394 -0
- alita_sdk/tools/carrier/tools.py +11 -1
- alita_sdk/tools/carrier/ui_reports_tool.py +6 -2
- alita_sdk/tools/carrier/update_ui_test_schedule_tool.py +278 -0
- alita_sdk/tools/github/__init__.py +1 -0
- alita_sdk/tools/github/api_wrapper.py +4 -2
- alita_sdk/tools/github/github_client.py +124 -1
- alita_sdk/tools/github/schemas.py +15 -0
- alita_sdk/tools/zephyr_squad/__init__.py +62 -0
- alita_sdk/tools/zephyr_squad/api_wrapper.py +135 -0
- alita_sdk/tools/zephyr_squad/zephyr_squad_cloud_client.py +79 -0
- {alita_sdk-0.3.161.dist-info → alita_sdk-0.3.163.dist-info}/METADATA +1 -1
- {alita_sdk-0.3.161.dist-info → alita_sdk-0.3.163.dist-info}/RECORD +23 -14
- {alita_sdk-0.3.161.dist-info → alita_sdk-0.3.163.dist-info}/WHEEL +0 -0
- {alita_sdk-0.3.161.dist-info → alita_sdk-0.3.163.dist-info}/licenses/LICENSE +0 -0
- {alita_sdk-0.3.161.dist-info → alita_sdk-0.3.163.dist-info}/top_level.txt +0 -0
@@ -0,0 +1,278 @@
|
|
1
|
+
import logging
|
2
|
+
import json
|
3
|
+
import traceback
|
4
|
+
import re
|
5
|
+
from typing import Type
|
6
|
+
from langchain_core.tools import BaseTool, ToolException
|
7
|
+
from pydantic.fields import Field
|
8
|
+
from pydantic import create_model, BaseModel
|
9
|
+
from .api_wrapper import CarrierAPIWrapper
|
10
|
+
|
11
|
+
|
12
|
+
# Module-level logger; UpdateUITestScheduleTool logs failures here before re-raising them.
logger = logging.getLogger(__name__)
|
13
|
+
|
14
|
+
|
15
|
+
class UpdateUITestScheduleTool(BaseTool):
    """Add a new cron schedule to an existing Carrier UI test.

    The tool is conversational: when arguments are missing or invalid it
    returns a markdown message describing what is needed (instead of
    raising), so the calling agent can ask the user for the missing data.
    """

    api_wrapper: CarrierAPIWrapper = Field(..., description="Carrier API Wrapper instance")
    name: str = "update_ui_test_schedule"
    description: str = ("Update UI test schedule on the Carrier platform. Use this tool when user wants to update, modify, or change a UI test schedule. "
                        "Provide test_id, schedule_name, and cron_timer, or leave empty to see available tests.")
    args_schema: Type[BaseModel] = create_model(
        "UpdateUITestScheduleInput",
        test_id=(str, Field(default="", description="Test ID to update schedule for")),
        schedule_name=(str, Field(default="", description="Name for the new schedule")),
        cron_timer=(str, Field(default="", description="Cron expression for schedule timing (e.g., '0 2 * * *')")),
    )

    def _run(self, test_id: str = "", schedule_name: str = "", cron_timer: str = ""):
        """Validate the inputs and append a new schedule to the given UI test.

        Args:
            test_id: Numeric ID of the UI test (as a string).
            schedule_name: Name for the new schedule.
            cron_timer: Five-field cron expression.

        Returns:
            A markdown string. It is either guidance (missing or invalid input,
            unknown test) or a success summary.

        Raises:
            ToolException: on any unexpected error (e.g. an API failure). The
                exception carries the formatted traceback.
        """
        try:
            # Normalise once, so surrounding whitespace neither breaks the
            # numeric test-id lookup nor leaks into the stored schedule.
            test_id = (test_id or "").strip()
            schedule_name = (schedule_name or "").strip()
            cron_timer = (cron_timer or "").strip()

            # No parameters at all: show the available tests.
            if not (test_id or schedule_name or cron_timer):
                return self._show_available_tests_and_instructions()

            # test_id is missing, but other params were provided.
            if not test_id:
                return self._show_missing_test_id_message()

            if not schedule_name or not cron_timer:
                return self._show_missing_parameters_message(test_id, schedule_name, cron_timer)

            if not self._validate_cron_timer(cron_timer):
                return self._show_invalid_cron_message(cron_timer)

            # Get the UI tests list to verify the test exists.
            ui_tests = self.api_wrapper.get_ui_tests_list() or []
            test_data = None
            test_id_int = None

            # Only a purely numeric ID can match; anything else falls through to "not found".
            if test_id.isdecimal():
                test_id_int = int(test_id)
                test_data = next((test for test in ui_tests if test.get("id") == test_id_int), None)

            if not test_data:
                available_tests = [f"ID: {test.get('id')}, Name: {test.get('name')}" for test in ui_tests]
                return f"❌ **Test not found for ID: {test_id}**\n\n**Available UI tests:**\n" + "\n".join([f"- {test}" for test in available_tests])

            # Get the detailed test configuration.
            test_details = self.api_wrapper.get_ui_test_details(str(test_id_int))
            if not test_details:
                return f"❌ **Could not retrieve test details for test ID {test_id_int}.**"

            # Rebuild the PUT payload with the new schedule appended, then send it.
            updated_config = self._parse_and_update_test_data(test_details, schedule_name, cron_timer)
            result = self.api_wrapper.update_ui_test(str(test_id_int), updated_config)
            # The response isn't inspected here; keep it in the debug log for troubleshooting.
            logger.debug("update_ui_test response for test %s: %s", test_id_int, result)

            return self._format_success_message(test_data.get('name', 'Unknown'), test_id_int, schedule_name, cron_timer)

        except Exception:
            stacktrace = traceback.format_exc()
            logger.error(f"Error updating UI test schedule: {stacktrace}")
            raise ToolException(stacktrace)

    def _show_available_tests_and_instructions(self):
        """Return a markdown list of the UI tests plus usage instructions.

        Raises:
            ToolException: if fetching the tests list fails.
        """
        try:
            ui_tests = self.api_wrapper.get_ui_tests_list()

            if not ui_tests:
                return "❌ **No UI tests found.**"

            message = ["# 📋 Update UI Test Schedule\n"]
            message.append("## Available UI Tests:")

            for test in ui_tests:
                message.append(f"- **ID: {test.get('id')}**, Name: `{test.get('name')}`, Runner: `{test.get('runner')}`")

            message.append("\n## 📝 Instructions:")
            message.append("For updating UI test schedule, please provide me:")
            message.append("- **`test_id`** - The ID of the test you want to update")
            message.append("- **`schedule_name`** - A name for your new schedule")
            message.append("- **`cron_timer`** - Cron expression for timing (e.g., `0 2 * * *` for daily at 2 AM)")

            message.append("\n## 💡 Example:")
            message.append("```")
            message.append("test_id: 42")
            message.append("schedule_name: Daily Morning Test")
            message.append("cron_timer: 0 2 * * *")
            message.append("```")

            return "\n".join(message)

        except Exception:
            stacktrace = traceback.format_exc()
            logger.error(f"Error fetching UI tests list: {stacktrace}")
            raise ToolException(stacktrace)

    def _show_missing_test_id_message(self):
        """Return the markdown message shown when test_id is missing."""
        return """# ❌ Missing Test ID

**For updating UI test schedule, please provide me:**
- **`test_id`** - The ID of the test you want to update
- **`schedule_name`** - A name for your new schedule
- **`cron_timer`** - Cron expression for timing

Use the tool without parameters to see available tests."""

    def _show_missing_parameters_message(self, test_id: str, schedule_name: str, cron_timer: str):
        """Return a markdown message listing which of schedule_name/cron_timer is missing."""
        missing = []
        if not schedule_name or schedule_name.strip() == "":
            missing.append("**`schedule_name`**")
        if not cron_timer or cron_timer.strip() == "":
            missing.append("**`cron_timer`**")

        message = [f"# ❌ Missing Parameters for Test ID: {test_id}\n"]
        message.append("**Missing parameters:**")
        for param in missing:
            message.append(f"- {param}")

        message.append("\n**For updating UI test schedule, please provide:**")
        message.append("- **`test_id`** ✅ (provided)")
        message.append("- **`schedule_name`** - A name for your new schedule")
        message.append("- **`cron_timer`** - Cron expression for timing (e.g., `0 2 * * *`)")

        return "\n".join(message)

    def _validate_cron_timer(self, cron_timer: str) -> bool:
        """Return True if cron_timer is a valid five-field numeric cron expression.

        Each field may be `*`, a number, or a range `a-b`. Any of these may
        carry a `/step` suffix, and fields may hold comma-separated lists.
        Numbers must lie within the field's bounds, which are the same ones
        listed in the invalid-cron help message.
        """
        # (low, high) bounds for minute, hour, day-of-month, month, day-of-week.
        field_bounds = ((0, 59), (0, 23), (1, 31), (1, 12), (0, 7))
        parts = cron_timer.strip().split()
        if len(parts) != len(field_bounds):
            return False
        return all(
            self._is_valid_cron_field(part, low, high)
            for part, (low, high) in zip(parts, field_bounds)
        )

    @staticmethod
    def _is_valid_cron_field(field: str, low: int, high: int) -> bool:
        """Return True if a single cron field is valid for values in [low, high]."""
        # ASCII-only digits, so every captured number is safe to pass to int().
        item_pattern = re.compile(r"(\*|(\d+)(?:-(\d+))?)(?:/(\d+))?", re.ASCII)
        for item in field.split(","):
            match = item_pattern.fullmatch(item)
            if match is None:
                return False
            _, start, end, step = match.groups()
            if step is not None and int(step) < 1:
                return False
            if start is not None:
                first = int(start)
                last = int(end) if end is not None else first
                # Reject out-of-range values and reversed ranges like "5-1".
                if not low <= first <= last <= high:
                    return False
        return True

    def _show_invalid_cron_message(self, cron_timer: str):
        """Return the markdown help message for an invalid cron expression."""
        return f"""# ❌ Invalid Cron Timer Format

**Provided:** `{cron_timer}`

**Cron format should be:** `minute hour day month weekday`

## Valid Examples:
- `0 2 * * *` - Daily at 2:00 AM
- `30 14 * * 1` - Every Monday at 2:30 PM
- `0 */6 * * *` - Every 6 hours
- `15 10 1 * *` - First day of every month at 10:15 AM
- `0 9 * * 1-5` - Weekdays at 9:00 AM

## Format Rules:
- **Minute:** 0-59
- **Hour:** 0-23
- **Day:** 1-31
- **Month:** 1-12
- **Weekday:** 0-7 (0 and 7 are Sunday)
- Use **`*`** for "any value"
- Use **`,`** for multiple values
- Use **`-`** for ranges
- Use **`/`** for step values"""

    def _parse_and_update_test_data(self, get_data: dict, schedule_name: str, cron_timer: str) -> dict:
        """Transform a GET test-details response into the PUT payload, with a new schedule appended.

        Args:
            get_data: Test details as returned by `get_ui_test_details`.
            schedule_name: Name of the schedule to add.
            cron_timer: Cron expression for the new schedule.

        Returns:
            The PUT payload: common_params, integrations, run_test, schedules,
            and test_parameters.
        """
        # Extract environment and test type from the test parameters.
        env_type = ""
        test_type = ""
        for param in get_data.get("test_parameters") or []:
            if param.get("name") == "env_type":
                env_type = param.get("default", "")
            elif param.get("name") == "test_type":
                test_type = param.get("default", "")

        # Construct common_params from the GET data.
        common_params = {
            "aggregation": get_data.get("aggregation", "max"),
            "cc_env_vars": get_data.get("cc_env_vars", {}),
            "entrypoint": get_data.get("entrypoint", ""),
            "env_type": env_type,
            "env_vars": get_data.get("env_vars", {}),
            "location": get_data.get("location", ""),
            "loops": get_data.get("loops", 1),
            "name": get_data.get("name", ""),
            "parallel_runners": get_data.get("parallel_runners", 1),
            "runner": get_data.get("runner", ""),
            "source": get_data.get("source", {}),
            "test_type": test_type
        }

        # Keep only the required integrations (reporters and system).
        # A null "integrations" value is treated as empty.
        source_integrations = get_data.get("integrations") or {}
        integrations = {
            "reporters": source_integrations.get("reporters", {}),
            "system": source_integrations.get("system", {})
        }

        # Re-send the existing schedules (so they are preserved), then add the new one.
        schedules = []
        for schedule in get_data.get("schedules") or []:
            schedules.append({
                "active": schedule.get("active", False),
                "cron": schedule.get("cron", ""),
                "cron_radio": "custom",
                "errors": {},
                "id": schedule.get("id"),
                "name": schedule.get("name", ""),
                "project_id": schedule.get("project_id"),
                "rpc_kwargs": schedule.get("rpc_kwargs"),
                "test_id": schedule.get("test_id"),
                "test_params": schedule.get("test_params", [])
            })

        schedules.append({
            "active": True,
            "cron": cron_timer,
            "cron_radio": "custom",
            "errors": {},
            "id": None,  # New schedule, no ID yet
            "name": schedule_name,
            "test_params": []
        })

        return {
            "common_params": common_params,
            "integrations": integrations,
            "run_test": False,
            "schedules": schedules,
            "test_parameters": []  # Empty as required in PUT request
        }

    def _format_success_message(self, test_name: str, test_id: int, schedule_name: str, cron_timer: str) -> str:
        """Return the markdown success summary for the newly added schedule."""
        return f"""# ✅ UI Test Schedule Updated Successfully!

## Test Information:
- **Test Name:** `{test_name}`
- **Test ID:** `{test_id}`

## New Schedule Added:
- **Schedule Name:** `{schedule_name}`
- **Cron Timer:** `{cron_timer}`
- **Status:** Active ✅

## 🎯 What happens next:
The test will now run automatically according to the specified schedule. You can view and manage schedules in the Carrier platform UI.

**Schedule will execute:** Based on cron expression `{cron_timer}`"""
|
@@ -23,6 +23,7 @@ def _get_toolkit(tool) -> BaseToolkit:
|
|
23
23
|
github_app_id=tool['settings'].get('app_id', None),
|
24
24
|
github_app_private_key=tool['settings'].get('app_private_key', None),
|
25
25
|
llm=tool['settings'].get('llm', None),
|
26
|
+
alita=tool['settings'].get('alita', None),
|
26
27
|
connection_string=tool['settings'].get('connection_string', None),
|
27
28
|
collection_name=str(tool['id']),
|
28
29
|
doctype='code',
|
@@ -51,7 +51,9 @@ class AlitaGitHubAPIWrapper(BaseCodeToolApiWrapper):
|
|
51
51
|
|
52
52
|
# Add LLM instance
|
53
53
|
llm: Optional[Any] = None
|
54
|
-
|
54
|
+
# Alita instance
|
55
|
+
alita: Optional[Any] = None
|
56
|
+
|
55
57
|
# Vector store configuration
|
56
58
|
connection_string: Optional[SecretStr] = None
|
57
59
|
collection_name: Optional[str] = None
|
@@ -109,7 +111,7 @@ class AlitaGitHubAPIWrapper(BaseCodeToolApiWrapper):
|
|
109
111
|
)
|
110
112
|
|
111
113
|
# Initialize GitHub client with keyword arguments
|
112
|
-
github_client = GitHubClient(auth_config=auth_config, repo_config=repo_config)
|
114
|
+
github_client = GitHubClient(auth_config=auth_config, repo_config=repo_config, alita=values.get("alita"))
|
113
115
|
# Initialize GraphQL client with keyword argument
|
114
116
|
graphql_client = GraphQLClientWrapper(github_graphql_instance=github_client.github_api._Github__requester)
|
115
117
|
# Set client attributes on the class (renamed from _github_client to github_client_instance)
|
@@ -34,10 +34,11 @@ from .schemas import (
|
|
34
34
|
SearchIssues,
|
35
35
|
CreateIssue,
|
36
36
|
UpdateIssue,
|
37
|
-
LoaderSchema,
|
38
37
|
GetCommits,
|
39
38
|
GetCommitChanges,
|
39
|
+
GetCommitsDiff,
|
40
40
|
ApplyGitPatch,
|
41
|
+
ApplyGitPatchFromArtifact,
|
41
42
|
TriggerWorkflow,
|
42
43
|
GetWorkflowStatus,
|
43
44
|
GetWorkflowLogs,
|
@@ -91,6 +92,9 @@ class GitHubClient(BaseModel):
|
|
91
92
|
# Adding auth config and repo config as optional fields for initialization
|
92
93
|
auth_config: Optional[GitHubAuthConfig] = Field(default=None, exclude=True)
|
93
94
|
repo_config: Optional[GitHubRepoConfig] = Field(default=None, exclude=True)
|
95
|
+
|
96
|
+
# Alita instance
|
97
|
+
alita: Optional[Any] = Field(default=None, exclude=True)
|
94
98
|
|
95
99
|
@model_validator(mode='before')
|
96
100
|
def initialize_github_client(cls, values):
|
@@ -388,6 +392,111 @@ class GitHubClient(BaseModel):
|
|
388
392
|
except Exception as e:
|
389
393
|
# Return error as JSON instead of plain text
|
390
394
|
return {"error": str(e), "message": f"Unable to retrieve commit changes due to error: {str(e)}"}
|
395
|
+
|
396
|
+
def get_commits_diff(self, base_sha: str, head_sha: str, repo_name: Optional[str] = None) -> dict:
    """
    Retrieves the diff between two commits.

    Parameters:
        base_sha (str): The base commit SHA to compare from.
        head_sha (str): The head commit SHA to compare to.
        repo_name (Optional[str]): Name of the repository in format 'owner/repo'.

    Returns:
        dict: The comparison between the two commits. It contains base/head
        commit info, status, ahead/behind counts, the commits in between, the
        changed files with their patches, and a summary of additions and
        deletions. If the comparison fails, it is a dict with "error" and
        "message" keys instead.
    """

    def _commit_summary(commit) -> dict:
        # Message, author and date live on the git-level commit object. A
        # missing author should not abort the whole comparison.
        git_commit = commit.commit
        author = git_commit.author
        return {
            "sha": commit.sha,
            "message": git_commit.message,
            "author": author.name if author is not None else None,
            "date": author.date.isoformat() if author is not None and author.date is not None else None,
        }

    try:
        repo = self.github_api.get_repo(repo_name) if repo_name else self.github_repo_instance
        comparison = repo.compare(base_sha, head_sha)

        commits = [
            {**_commit_summary(commit), "url": commit.html_url}
            for commit in comparison.commits
        ]

        files = []
        for file in comparison.files:
            file_info = {
                "filename": file.filename,
                "status": file.status,  # added, modified, removed, renamed
                "additions": file.additions,
                "deletions": file.deletions,
                "changes": file.changes,
                # Empty patches (e.g. for binary files) are normalised to None.
                "patch": getattr(file, "patch", None) or None,
                "blob_url": getattr(file, "blob_url", None),
                "raw_url": getattr(file, "raw_url", None),
            }
            if file.status == "renamed" and hasattr(file, "previous_filename"):
                file_info["previous_filename"] = file.previous_filename
            files.append(file_info)

        return {
            "base_commit": _commit_summary(comparison.base_commit),
            "head_commit": _commit_summary(comparison.head_commit),
            "status": comparison.status,  # ahead, behind, identical, or diverged
            "ahead_by": comparison.ahead_by,
            "behind_by": comparison.behind_by,
            "total_commits": comparison.total_commits,
            "commits": commits,
            "files": files,
            "summary": {
                "total_files_changed": len(files),
                "total_additions": sum(f["additions"] for f in files),
                "total_deletions": sum(f["deletions"] for f in files),
            },
        }

    except Exception as e:
        # Return the error as JSON instead of plain text.
        return {"error": str(e), "message": f"Unable to retrieve diff between commits due to error: {str(e)}"}
|
479
|
+
|
480
|
+
def apply_git_patch_from_file(self, bucket_name: str, file_name: str, commit_message: Optional[str] = "Apply git patch", repo_name: Optional[str] = None) -> str:
    """Applies a git patch from a file stored in a specified bucket.

    Args:
        bucket_name (str): The name of the bucket where the patch file is stored.
        file_name (str): The name of the patch file to apply.
        commit_message (Optional[str], optional): The commit message for the patch application. Defaults to "Apply git patch".
        repo_name (Optional[str], optional): The name of the repository to apply the patch to. Defaults to None.

    Returns:
        str: On success, the result of applying the patch. On failure, a dict
        with "error" and "message" keys. Failures are reported separately for
        a missing client, a download failure, an empty or non-text file, and a
        failed patch application.
    """
    if self.alita is None:
        # The patch is downloaded through the Alita client; without one there is nothing to fetch from.
        return {"error": "Alita client is not configured",
                "message": f"Cannot download patch file '{file_name}' from bucket '{bucket_name}': no Alita client is configured."}

    try:
        patch_content = self.alita.download_artifact(bucket_name, file_name)
    except Exception as e:
        return {"error": str(e), "message": f"Unable to download patch file: {str(e)}"}

    # NOTE(review): the artifact download may return raw bytes; decode them
    # instead of misreporting an existing file as "not found".
    if isinstance(patch_content, (bytes, bytearray)):
        try:
            patch_content = patch_content.decode("utf-8")
        except UnicodeDecodeError as e:
            return {"error": str(e),
                    "message": f"Patch file '{file_name}' in bucket '{bucket_name}' is not valid UTF-8 text."}

    if not patch_content or not isinstance(patch_content, str):
        return {"error": "Patch file not found", "message": f"Patch file '{file_name}' not found in bucket '{bucket_name}'."}

    # Report application failures separately, so they are not mistaken for download errors.
    try:
        return self.apply_git_patch(patch_content, commit_message, repo_name)
    except Exception as e:
        return {"error": str(e), "message": f"Unable to apply patch file '{file_name}': {str(e)}"}
|
391
500
|
|
392
501
|
def apply_git_patch(self, patch_content: str, commit_message: Optional[str] = "Apply git patch", repo_name: Optional[str] = None) -> str:
|
393
502
|
"""
|
@@ -1902,6 +2011,13 @@ class GitHubClient(BaseModel):
|
|
1902
2011
|
"description": self.get_commit_changes.__doc__,
|
1903
2012
|
"args_schema": GetCommitChanges,
|
1904
2013
|
},
|
2014
|
+
{
|
2015
|
+
"ref": self.get_commits_diff,
|
2016
|
+
"name": "get_commits_diff",
|
2017
|
+
"mode": "get_commits_diff",
|
2018
|
+
"description": self.get_commits_diff.__doc__,
|
2019
|
+
"args_schema": GetCommitsDiff,
|
2020
|
+
},
|
1905
2021
|
{
|
1906
2022
|
"ref": self.apply_git_patch,
|
1907
2023
|
"name": "apply_git_patch",
|
@@ -1909,6 +2025,13 @@ class GitHubClient(BaseModel):
|
|
1909
2025
|
"description": self.apply_git_patch.__doc__,
|
1910
2026
|
"args_schema": ApplyGitPatch,
|
1911
2027
|
},
|
2028
|
+
{
|
2029
|
+
"ref": self.apply_git_patch_from_file,
|
2030
|
+
"name": "apply_git_patch_from_file",
|
2031
|
+
"mode": "apply_git_patch_from_file",
|
2032
|
+
"description": self.apply_git_patch_from_file.__doc__,
|
2033
|
+
"args_schema": ApplyGitPatchFromArtifact,
|
2034
|
+
},
|
1912
2035
|
{
|
1913
2036
|
"ref": self.trigger_workflow,
|
1914
2037
|
"name": "trigger_workflow",
|
@@ -159,6 +159,13 @@ GetCommitChanges = create_model(
|
|
159
159
|
repo_name=(Optional[str], Field(default=None, description="Name of the repository (e.g., 'owner/repo'). If None, uses the default repository."))
|
160
160
|
)
|
161
161
|
|
162
|
+
# Argument schema for the `get_commits_diff` tool: compare base_sha against head_sha.
GetCommitsDiff = create_model(
    "GetCommitsDiff",
    base_sha=(
        str,
        Field(..., description="The base commit SHA to compare from"),
    ),
    head_sha=(
        str,
        Field(..., description="The head commit SHA to compare to"),
    ),
    repo_name=(
        Optional[str],
        Field(
            None,
            description="Name of the repository (e.g., 'owner/repo'). If None, uses the default repository.",
        ),
    ),
)
|
168
|
+
|
162
169
|
ApplyGitPatch = create_model(
|
163
170
|
"ApplyGitPatch",
|
164
171
|
patch_content=(str, Field(description="The git patch content in unified diff format")),
|
@@ -166,6 +173,14 @@ ApplyGitPatch = create_model(
|
|
166
173
|
repo_name=(Optional[str], Field(default=None, description="Name of the repository (e.g., 'owner/repo'). If None, uses the default repository."))
|
167
174
|
)
|
168
175
|
|
176
|
+
# Argument schema for the `apply_git_patch_from_file` tool: a patch stored as an artifact in a bucket.
ApplyGitPatchFromArtifact = create_model(
    "ApplyGitPatchFromArtifact",
    bucket_name=(
        str,
        Field(..., description="Name of the artifact bucket containing the patch file"),
    ),
    file_name=(
        str,
        Field(..., description="Name of the patch file to download and apply"),
    ),
    commit_message=(
        Optional[str],
        Field("Apply git patch from artifact", description="Commit message for the patch application"),
    ),
    repo_name=(
        Optional[str],
        Field(
            None,
            description="Name of the repository (e.g., 'owner/repo'). If None, uses the default repository.",
        ),
    ),
)
|
183
|
+
|
169
184
|
TriggerWorkflow = create_model(
|
170
185
|
"TriggerWorkflow",
|
171
186
|
workflow_id=(str, Field(description="The ID or file name of the workflow to trigger (e.g., 'build.yml', '1234567')")),
|
@@ -0,0 +1,62 @@
|
|
1
|
+
from typing import List, Literal, Optional
|
2
|
+
|
3
|
+
from langchain_community.agent_toolkits.base import BaseToolkit
|
4
|
+
from langchain_core.tools import BaseTool
|
5
|
+
from pydantic import create_model, BaseModel, Field, SecretStr
|
6
|
+
|
7
|
+
from .api_wrapper import ZephyrSquadApiWrapper
|
8
|
+
from ..base.tool import BaseAction
|
9
|
+
from ..utils import clean_string, TOOLKIT_SPLITTER, get_max_toolkit_length
|
10
|
+
|
11
|
+
# Module-level toolkit name.
# NOTE(review): toolkit_config_schema() in this module names its schema "zephyr_squad",
# while this value is "zephyr"; confirm the mismatch is intentional.
name = "zephyr"
|
12
|
+
|
13
|
+
def get_tools(tool):
    """Build the Zephyr Squad tools described by a toolkit configuration dict.

    `tool['settings']` must provide `account_id`, `access_key` and
    `secret_key` (a missing one raises KeyError). `selected_tools` is
    optional and defaults to an empty list. `tool['toolkit_name']` is
    optional and is forwarded to the toolkit.
    """
    settings = tool['settings']
    toolkit = ZephyrSquadToolkit().get_toolkit(
        selected_tools=settings.get('selected_tools', []),
        account_id=settings["account_id"],
        access_key=settings["access_key"],
        secret_key=settings["secret_key"],
        toolkit_name=tool.get('toolkit_name'),
    )
    return toolkit.get_tools()
|
21
|
+
|
22
|
+
class ZephyrSquadToolkit(BaseToolkit):
    """Toolkit that wraps each ZephyrSquadApiWrapper operation in a BaseAction tool."""

    tools: List[BaseTool] = []
    # Set by toolkit_config_schema(); passed to clean_string() when building the tool-name prefix.
    toolkit_max_length: int = 0

    @staticmethod
    def toolkit_config_schema() -> BaseModel:
        """Build the pydantic configuration schema for the Zephyr Squad toolkit.

        This also records the maximum toolkit-name length, computed from the
        available tool schemas, on the class.
        """
        available = ZephyrSquadApiWrapper.model_construct().get_available_tools()
        selected_tools = {}
        for entry in available:
            selected_tools[entry['name']] = entry['args_schema'].schema()
        ZephyrSquadToolkit.toolkit_max_length = get_max_toolkit_length(selected_tools)

        metadata = {
            "label": "Zephyr Squad",
            "icon_url": "zephyr.svg",
            "categories": ["test management"],
            "extra_categories": ["test automation", "test case management", "test planning"],
        }
        return create_model(
            "zephyr_squad",
            account_id=(str, Field(description="AccountID for the user that is going to be authenticating")),
            access_key=(str, Field(description="Generated access key")),
            secret_key=(SecretStr, Field(description="Generated secret key")),
            selected_tools=(List[Literal[tuple(selected_tools)]], Field(default=[], json_schema_extra={'args_schemas': selected_tools})),
            __config__={'json_schema_extra': {'metadata': metadata}},
        )

    @classmethod
    def get_toolkit(cls, selected_tools: list[str] | None = None, toolkit_name: Optional[str] = None, **kwargs):
        """Create the toolkit and its tools.

        Args:
            selected_tools: Names of the tools to include. An empty list or
                None includes every available tool.
            toolkit_name: Optional name used to build a prefix for each tool name.
            **kwargs: Passed through to ZephyrSquadApiWrapper.
        """
        api_wrapper = ZephyrSquadApiWrapper(**kwargs)
        prefix = ''
        if toolkit_name:
            prefix = clean_string(toolkit_name, cls.toolkit_max_length) + TOOLKIT_SPLITTER
        tools = [
            BaseAction(
                api_wrapper=api_wrapper,
                name=prefix + spec["name"],
                description=spec["description"],
                args_schema=spec["args_schema"],
            )
            for spec in api_wrapper.get_available_tools()
            if not selected_tools or spec["name"] in selected_tools
        ]
        return cls(tools=tools)

    def get_tools(self):
        """Return the tools built by get_toolkit()."""
        return self.tools
|
62
|
+
|