alita-sdk 0.3.161__py3-none-any.whl → 0.3.163__py3-none-any.whl

This diff compares the contents of two publicly released versions of the package as they appear in their public registry. It is provided for informational purposes only.
@@ -0,0 +1,278 @@
+ import logging
+ import json
+ import traceback
+ import re
+ from typing import Type
+ from langchain_core.tools import BaseTool, ToolException
+ from pydantic.fields import Field
+ from pydantic import create_model, BaseModel
+ from .api_wrapper import CarrierAPIWrapper
+
+
+ logger = logging.getLogger(__name__)
+
+
+ class UpdateUITestScheduleTool(BaseTool):
+     api_wrapper: CarrierAPIWrapper = Field(..., description="Carrier API Wrapper instance")
+     name: str = "update_ui_test_schedule"
+     description: str = ("Update UI test schedule on the Carrier platform. Use this tool when user wants to update, modify, or change a UI test schedule. "
+                         "Provide test_id, schedule_name, and cron_timer, or leave empty to see available tests.")
+     args_schema: Type[BaseModel] = create_model(
+         "UpdateUITestScheduleInput",
+         test_id=(str, Field(default="", description="Test ID to update schedule for")),
+         schedule_name=(str, Field(default="", description="Name for the new schedule")),
+         cron_timer=(str, Field(default="", description="Cron expression for schedule timing (e.g., '0 2 * * *')")),
+     )
+
+     def _run(self, test_id: str = "", schedule_name: str = "", cron_timer: str = ""):
+         try:
+             # Check if no parameters provided - show available tests
+             if (not test_id or test_id.strip() == "") and (not schedule_name or schedule_name.strip() == "") and (not cron_timer or cron_timer.strip() == ""):
+                 return self._show_available_tests_and_instructions()
+
+             # Check if test_id is missing but other params provided
+             if (not test_id or test_id.strip() == ""):
+                 return self._show_missing_test_id_message()
+
+             # Check if schedule_name or cron_timer is missing
+             if (not schedule_name or schedule_name.strip() == "") or (not cron_timer or cron_timer.strip() == ""):
+                 return self._show_missing_parameters_message(test_id, schedule_name, cron_timer)
+
+             # Validate cron timer format
+             if not self._validate_cron_timer(cron_timer):
+                 return self._show_invalid_cron_message(cron_timer)
+
+             # Get UI tests list to verify test exists
+             ui_tests = self.api_wrapper.get_ui_tests_list()
+             test_data = None
+             test_id_int = None
+
+             # Try to find test by ID
+             if test_id.isdigit():
+                 test_id_int = int(test_id)
+                 for test in ui_tests:
+                     if test.get("id") == test_id_int:
+                         test_data = test
+                         break
+
+             if not test_data:
+                 available_tests = []
+                 for test in ui_tests:
+                     available_tests.append(f"ID: {test.get('id')}, Name: {test.get('name')}")
+
+                 return f"❌ **Test not found for ID: {test_id}**\n\n**Available UI tests:**\n" + "\n".join([f"- {test}" for test in available_tests])
+
+             # Get detailed test configuration
+             test_details = self.api_wrapper.get_ui_test_details(str(test_id_int))
+
+             if not test_details:
+                 return f"❌ **Could not retrieve test details for test ID {test_id_int}.**"
+
+             # Parse and update the test configuration
+             updated_config = self._parse_and_update_test_data(test_details, schedule_name, cron_timer)
+
+             # Execute the PUT request to update the test
+             result = self.api_wrapper.update_ui_test(str(test_id_int), updated_config)
+
+             return self._format_success_message(test_data.get('name', 'Unknown'), test_id_int, schedule_name, cron_timer)
+
+         except Exception:
+             stacktrace = traceback.format_exc()
+             logger.error(f"Error updating UI test schedule: {stacktrace}")
+             raise ToolException(stacktrace)
+
+     def _show_available_tests_and_instructions(self):
+         """Show available tests and instructions when no parameters provided."""
+         try:
+             ui_tests = self.api_wrapper.get_ui_tests_list()
+
+             if not ui_tests:
+                 return "❌ **No UI tests found.**"
+
+             message = ["# 📋 Update UI Test Schedule\n"]
+             message.append("## Available UI Tests:")
+
+             for test in ui_tests:
+                 message.append(f"- **ID: {test.get('id')}**, Name: `{test.get('name')}`, Runner: `{test.get('runner')}`")
+
+             message.append("\n## 📝 Instructions:")
+             message.append("For updating UI test schedule, please provide me:")
+             message.append("- **`test_id`** - The ID of the test you want to update")
+             message.append("- **`schedule_name`** - A name for your new schedule")
+             message.append("- **`cron_timer`** - Cron expression for timing (e.g., `0 2 * * *` for daily at 2 AM)")
+
+             message.append("\n## 💡 Example:")
+             message.append("```")
+             message.append("test_id: 42")
+             message.append("schedule_name: Daily Morning Test")
+             message.append("cron_timer: 0 2 * * *")
+             message.append("```")
+
+             return "\n".join(message)
+
+         except Exception:
+             stacktrace = traceback.format_exc()
+             logger.error(f"Error fetching UI tests list: {stacktrace}")
+             raise ToolException(stacktrace)
+
+     def _show_missing_test_id_message(self):
+         """Show message when test_id is missing."""
+         return """# ❌ Missing Test ID
+
+ **For updating UI test schedule, please provide me:**
+ - **`test_id`** - The ID of the test you want to update
+ - **`schedule_name`** - A name for your new schedule
+ - **`cron_timer`** - Cron expression for timing
+
+ Use the tool without parameters to see available tests."""
+
+     def _show_missing_parameters_message(self, test_id: str, schedule_name: str, cron_timer: str):
+         """Show message when some parameters are missing."""
+         missing = []
+         if not schedule_name or schedule_name.strip() == "":
+             missing.append("**`schedule_name`**")
+         if not cron_timer or cron_timer.strip() == "":
+             missing.append("**`cron_timer`**")
+
+         message = [f"# ❌ Missing Parameters for Test ID: {test_id}\n"]
+         message.append("**Missing parameters:**")
+         for param in missing:
+             message.append(f"- {param}")
+
+         message.append("\n**For updating UI test schedule, please provide:**")
+         message.append("- **`test_id`** ✅ (provided)")
+         message.append("- **`schedule_name`** - A name for your new schedule")
+         message.append("- **`cron_timer`** - Cron expression for timing (e.g., `0 2 * * *`)")
+
+         return "\n".join(message)
+
+     def _validate_cron_timer(self, cron_timer: str) -> bool:
+         """Validate cron timer format."""
+         # Basic cron validation - should have 5 parts separated by spaces
+         parts = cron_timer.strip().split()
+         if len(parts) != 5:
+             return False
+
+         # Each part should contain only digits, *, /, -, or ,
+         cron_pattern = re.compile(r'^[0-9*,/-]+$')
+         return all(cron_pattern.match(part) for part in parts)
+
+     def _show_invalid_cron_message(self, cron_timer: str):
+         """Show message for invalid cron timer."""
+         return f"""# ❌ Invalid Cron Timer Format
+
+ **Provided:** `{cron_timer}`
+
+ **Cron format should be:** `minute hour day month weekday`
+
+ ## Valid Examples:
+ - `0 2 * * *` - Daily at 2:00 AM
+ - `30 14 * * 1` - Every Monday at 2:30 PM
+ - `0 */6 * * *` - Every 6 hours
+ - `15 10 1 * *` - First day of every month at 10:15 AM
+ - `0 9 * * 1-5` - Weekdays at 9:00 AM
+
+ ## Format Rules:
+ - **Minute:** 0-59
+ - **Hour:** 0-23
+ - **Day:** 1-31
+ - **Month:** 1-12
+ - **Weekday:** 0-7 (0 and 7 are Sunday)
+ - Use **`*`** for "any value"
+ - Use **`,`** for multiple values
+ - Use **`-`** for ranges
+ - Use **`/`** for step values"""
+
+     def _parse_and_update_test_data(self, get_data: dict, schedule_name: str, cron_timer: str) -> dict:
+         """Parse GET response data and transform it into the required format for PUT request."""
+
+         # Extract environment and test type from test parameters
+         env_type = ""
+         test_type = ""
+         for param in get_data.get("test_parameters", []):
+             if param.get("name") == "env_type":
+                 env_type = param.get("default", "")
+             elif param.get("name") == "test_type":
+                 test_type = param.get("default", "")
+
+         # Construct common_params from GET data
+         common_params = {
+             "aggregation": get_data.get("aggregation", "max"),
+             "cc_env_vars": get_data.get("cc_env_vars", {}),
+             "entrypoint": get_data.get("entrypoint", ""),
+             "env_type": env_type,
+             "env_vars": get_data.get("env_vars", {}),
+             "location": get_data.get("location", ""),
+             "loops": get_data.get("loops", 1),
+             "name": get_data.get("name", ""),
+             "parallel_runners": get_data.get("parallel_runners", 1),
+             "runner": get_data.get("runner", ""),
+             "source": get_data.get("source", {}),
+             "test_type": test_type
+         }
+
+         # Extract only required integrations (reporters and system)
+         integrations = {
+             "reporters": get_data.get("integrations", {}).get("reporters", {}),
+             "system": get_data.get("integrations", {}).get("system", {})
+         }
+
+         # Process existing schedules and add the new one
+         schedules = []
+
+         # Keep existing schedules
+         for schedule in get_data.get("schedules", []):
+             existing_schedule = {
+                 "active": schedule.get("active", False),
+                 "cron": schedule.get("cron", ""),
+                 "cron_radio": "custom",
+                 "errors": {},
+                 "id": schedule.get("id"),
+                 "name": schedule.get("name", ""),
+                 "project_id": schedule.get("project_id"),
+                 "rpc_kwargs": schedule.get("rpc_kwargs"),
+                 "test_id": schedule.get("test_id"),
+                 "test_params": schedule.get("test_params", [])
+             }
+             schedules.append(existing_schedule)
+
+         # Add the new schedule
+         new_schedule = {
+             "active": True,
+             "cron": cron_timer,
+             "cron_radio": "custom",
+             "errors": {},
+             "id": None,  # New schedule, no ID yet
+             "name": schedule_name,
+             "test_params": []
+         }
+         schedules.append(new_schedule)
+
+         # Assemble the final PUT request data
+         put_data = {
+             "common_params": common_params,
+             "integrations": integrations,
+             "run_test": False,
+             "schedules": schedules,
+             "test_parameters": []  # Empty as required in PUT request
+         }
+
+         return put_data
+
+     def _format_success_message(self, test_name: str, test_id: int, schedule_name: str, cron_timer: str) -> str:
+         """Format success message in markdown."""
+         return f"""# ✅ UI Test Schedule Updated Successfully!
+
+ ## Test Information:
+ - **Test Name:** `{test_name}`
+ - **Test ID:** `{test_id}`
+
+ ## New Schedule Added:
+ - **Schedule Name:** `{schedule_name}`
+ - **Cron Timer:** `{cron_timer}`
+ - **Status:** Active ✅
+
+ ## 🎯 What happens next:
+ The test will now run automatically according to the specified schedule. You can view and manage schedules in the Carrier platform UI.
+
+ **Schedule will execute:** Based on cron expression `{cron_timer}`"""
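For orientation, here is a minimal sketch of driving the new tool directly. Only `UpdateUITestScheduleTool` and its `test_id` / `schedule_name` / `cron_timer` arguments come from the diff above; the import paths and the `CarrierAPIWrapper` constructor arguments are assumptions for illustration.

```python
# Illustrative sketch only: import paths and wrapper settings are assumed, not taken from this diff.
from alita_sdk.tools.carrier.api_wrapper import CarrierAPIWrapper                      # assumed path
from alita_sdk.tools.carrier.update_ui_test_schedule import UpdateUITestScheduleTool   # assumed path

wrapper = CarrierAPIWrapper(
    url="https://carrier.example.com",  # hypothetical Carrier endpoint
    project_id="1",                     # hypothetical project
    private_token="<token>",            # hypothetical credential
)
tool = UpdateUITestScheduleTool(api_wrapper=wrapper)

# With no arguments the tool lists available UI tests and prints usage instructions.
print(tool._run())

# With all three arguments it appends a new active schedule to the chosen test.
print(tool._run(test_id="42", schedule_name="Daily Morning Test", cron_timer="0 2 * * *"))
```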
@@ -23,6 +23,7 @@ def _get_toolkit(tool) -> BaseToolkit:
          github_app_id=tool['settings'].get('app_id', None),
          github_app_private_key=tool['settings'].get('app_private_key', None),
          llm=tool['settings'].get('llm', None),
+         alita=tool['settings'].get('alita', None),
          connection_string=tool['settings'].get('connection_string', None),
          collection_name=str(tool['id']),
          doctype='code',
@@ -51,7 +51,9 @@ class AlitaGitHubAPIWrapper(BaseCodeToolApiWrapper):
  
      # Add LLM instance
      llm: Optional[Any] = None
-
+     # Alita instance
+     alita: Optional[Any] = None
+
      # Vector store configuration
      connection_string: Optional[SecretStr] = None
      collection_name: Optional[str] = None
@@ -109,7 +111,7 @@ class AlitaGitHubAPIWrapper(BaseCodeToolApiWrapper):
          )
  
          # Initialize GitHub client with keyword arguments
-         github_client = GitHubClient(auth_config=auth_config, repo_config=repo_config)
+         github_client = GitHubClient(auth_config=auth_config, repo_config=repo_config, alita=values.get("alita"))
          # Initialize GraphQL client with keyword argument
          graphql_client = GraphQLClientWrapper(github_graphql_instance=github_client.github_api._Github__requester)
          # Set client attributes on the class (renamed from _github_client to github_client_instance)
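The hunks above thread an optional `alita` client from the toolkit settings through `AlitaGitHubAPIWrapper` into `GitHubClient`, where the new `apply_git_patch_from_file` method (further down in this diff) calls `alita.download_artifact(...)`. A hedged sketch of the assumed settings shape and client interface follows; the stub class and placeholder values are illustrative, not part of the package.

```python
# Hedged sketch: the tool['settings'] shape consumed by _get_toolkit above, plus the one
# Alita method the new GitHub code relies on. All values are placeholders.
class StubAlita:
    """Stand-in for the Alita client; only the call used by apply_git_patch_from_file."""
    def download_artifact(self, bucket_name: str, file_name: str) -> str:
        # A real client would fetch the artifact; here we return placeholder patch text.
        return "diff --git a/README.md b/README.md\n"

tool = {
    "id": "github-1",                    # used as the vector store collection name
    "toolkit_name": "github",
    "settings": {
        "app_id": "12345",               # hypothetical GitHub App id
        "app_private_key": "<pem>",      # elided credential
        "llm": None,
        "alita": StubAlita(),            # new setting, forwarded as GitHubClient(alita=...)
        "connection_string": None,
    },
}
```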
@@ -34,10 +34,11 @@ from .schemas import (
      SearchIssues,
      CreateIssue,
      UpdateIssue,
-     LoaderSchema,
      GetCommits,
      GetCommitChanges,
+     GetCommitsDiff,
      ApplyGitPatch,
+     ApplyGitPatchFromArtifact,
      TriggerWorkflow,
      GetWorkflowStatus,
      GetWorkflowLogs,
@@ -91,6 +92,9 @@ class GitHubClient(BaseModel):
      # Adding auth config and repo config as optional fields for initialization
      auth_config: Optional[GitHubAuthConfig] = Field(default=None, exclude=True)
      repo_config: Optional[GitHubRepoConfig] = Field(default=None, exclude=True)
+
+     # Alita instance
+     alita: Optional[Any] = Field(default=None, exclude=True)
  
      @model_validator(mode='before')
      def initialize_github_client(cls, values):
@@ -388,6 +392,111 @@ class GitHubClient(BaseModel):
          except Exception as e:
              # Return error as JSON instead of plain text
              return {"error": str(e), "message": f"Unable to retrieve commit changes due to error: {str(e)}"}
+
+     def get_commits_diff(self, base_sha: str, head_sha: str, repo_name: Optional[str] = None) -> str:
+         """
+         Retrieves the diff between two commits.
+
+         Parameters:
+             base_sha (str): The base commit SHA to compare from.
+             head_sha (str): The head commit SHA to compare to.
+             repo_name (Optional[str]): Name of the repository in format 'owner/repo'.
+
+         Returns:
+             str: A detailed diff comparison between the two commits or an error message.
+         """
+         try:
+             # Get the repository
+             repo = self.github_api.get_repo(repo_name) if repo_name else self.github_repo_instance
+
+             # Get the comparison between the two commits
+             comparison = repo.compare(base_sha, head_sha)
+
+             # Extract comparison information
+             diff_info = {
+                 "base_commit": {
+                     "sha": comparison.base_commit.sha,
+                     "message": comparison.base_commit.commit.message,
+                     "author": comparison.base_commit.commit.author.name,
+                     "date": comparison.base_commit.commit.author.date.isoformat()
+                 },
+                 "head_commit": {
+                     "sha": comparison.head_commit.sha,
+                     "message": comparison.head_commit.commit.message,
+                     "author": comparison.head_commit.commit.author.name,
+                     "date": comparison.head_commit.commit.author.date.isoformat()
+                 },
+                 "status": comparison.status, # ahead, behind, identical, or diverged
+                 "ahead_by": comparison.ahead_by,
+                 "behind_by": comparison.behind_by,
+                 "total_commits": comparison.total_commits,
+                 "commits": [],
+                 "files": []
+             }
+
+             # Get commits in the comparison
+             for commit in comparison.commits:
+                 commit_info = {
+                     "sha": commit.sha,
+                     "message": commit.commit.message,
+                     "author": commit.commit.author.name,
+                     "date": commit.commit.author.date.isoformat(),
+                     "url": commit.html_url
+                 }
+                 diff_info["commits"].append(commit_info)
+
+             # Get changed files information
+             for file in comparison.files:
+                 file_info = {
+                     "filename": file.filename,
+                     "status": file.status, # added, modified, removed, renamed
+                     "additions": file.additions,
+                     "deletions": file.deletions,
+                     "changes": file.changes,
+                     "patch": file.patch if hasattr(file, 'patch') and file.patch else None,
+                     "blob_url": file.blob_url if hasattr(file, 'blob_url') else None,
+                     "raw_url": file.raw_url if hasattr(file, 'raw_url') else None
+                 }
+
+                 # Add previous filename for renamed files
+                 if file.status == "renamed" and hasattr(file, 'previous_filename'):
+                     file_info["previous_filename"] = file.previous_filename
+
+                 diff_info["files"].append(file_info)
+
+             # Add summary statistics
+             diff_info["summary"] = {
+                 "total_files_changed": len(diff_info["files"]),
+                 "total_additions": sum(f["additions"] for f in diff_info["files"]),
+                 "total_deletions": sum(f["deletions"] for f in diff_info["files"])
+             }
+
+             return diff_info
+
+         except Exception as e:
+             # Return error as JSON instead of plain text
+             return {"error": str(e), "message": f"Unable to retrieve diff between commits due to error: {str(e)}"}
+
+     def apply_git_patch_from_file(self, bucket_name: str, file_name: str, commit_message: Optional[str] = "Apply git patch", repo_name: Optional[str] = None) -> str:
+         """Applies a git patch from a file stored in a specified bucket.
+
+         Args:
+             bucket_name (str): The name of the bucket where the patch file is stored.
+             file_name (str): The name of the patch file to apply.
+             commit_message (Optional[str], optional): The commit message for the patch application. Defaults to "Apply git patch".
+             repo_name (Optional[str], optional): The name of the repository to apply the patch to. Defaults to None.
+
+         Returns:
+             str: A summary of the applied changes or an error message.
+         """
+         try:
+             patch_content = self.alita.download_artifact(bucket_name, file_name)
+             if not patch_content or not isinstance(patch_content, str):
+                 return {"error": "Patch file not found", "message": f"Patch file '{file_name}' not found in bucket '{bucket_name}'."}
+             # Apply the git patch using the content
+             return self.apply_git_patch(patch_content, commit_message, repo_name)
+         except Exception as e:
+             return {"error": str(e), "message": f"Unable to download patch file: {str(e)}"}
  
      def apply_git_patch(self, patch_content: str, commit_message: Optional[str] = "Apply git patch", repo_name: Optional[str] = None) -> str:
          """
@@ -1902,6 +2011,13 @@ class GitHubClient(BaseModel):
                  "description": self.get_commit_changes.__doc__,
                  "args_schema": GetCommitChanges,
              },
+             {
+                 "ref": self.get_commits_diff,
+                 "name": "get_commits_diff",
+                 "mode": "get_commits_diff",
+                 "description": self.get_commits_diff.__doc__,
+                 "args_schema": GetCommitsDiff,
+             },
              {
                  "ref": self.apply_git_patch,
                  "name": "apply_git_patch",
@@ -1909,6 +2025,13 @@ class GitHubClient(BaseModel):
                  "description": self.apply_git_patch.__doc__,
                  "args_schema": ApplyGitPatch,
              },
+             {
+                 "ref": self.apply_git_patch_from_file,
+                 "name": "apply_git_patch_from_file",
+                 "mode": "apply_git_patch_from_file",
+                 "description": self.apply_git_patch_from_file.__doc__,
+                 "args_schema": ApplyGitPatchFromArtifact,
+             },
              {
                  "ref": self.trigger_workflow,
                  "name": "trigger_workflow",
@@ -159,6 +159,13 @@ GetCommitChanges = create_model(
      repo_name=(Optional[str], Field(default=None, description="Name of the repository (e.g., 'owner/repo'). If None, uses the default repository."))
  )
  
+ GetCommitsDiff = create_model(
+     "GetCommitsDiff",
+     base_sha=(str, Field(description="The base commit SHA to compare from")),
+     head_sha=(str, Field(description="The head commit SHA to compare to")),
+     repo_name=(Optional[str], Field(default=None, description="Name of the repository (e.g., 'owner/repo'). If None, uses the default repository."))
+ )
+
  ApplyGitPatch = create_model(
      "ApplyGitPatch",
      patch_content=(str, Field(description="The git patch content in unified diff format")),
@@ -166,6 +173,14 @@ ApplyGitPatch = create_model(
      repo_name=(Optional[str], Field(default=None, description="Name of the repository (e.g., 'owner/repo'). If None, uses the default repository."))
  )
  
+ ApplyGitPatchFromArtifact = create_model(
+     "ApplyGitPatchFromArtifact",
+     bucket_name=(str, Field(description="Name of the artifact bucket containing the patch file")),
+     file_name=(str, Field(description="Name of the patch file to download and apply")),
+     commit_message=(Optional[str], Field(description="Commit message for the patch application", default="Apply git patch from artifact")),
+     repo_name=(Optional[str], Field(default=None, description="Name of the repository (e.g., 'owner/repo'). If None, uses the default repository."))
+ )
+
  TriggerWorkflow = create_model(
      "TriggerWorkflow",
      workflow_id=(str, Field(description="The ID or file name of the workflow to trigger (e.g., 'build.yml', '1234567')")),
@@ -0,0 +1,62 @@
+ from typing import List, Literal, Optional
+
+ from langchain_community.agent_toolkits.base import BaseToolkit
+ from langchain_core.tools import BaseTool
+ from pydantic import create_model, BaseModel, Field, SecretStr
+
+ from .api_wrapper import ZephyrSquadApiWrapper
+ from ..base.tool import BaseAction
+ from ..utils import clean_string, TOOLKIT_SPLITTER, get_max_toolkit_length
+
+ name = "zephyr"
+
+ def get_tools(tool):
+     return ZephyrSquadToolkit().get_toolkit(
+         selected_tools=tool['settings'].get('selected_tools', []),
+         account_id=tool['settings']["account_id"],
+         access_key=tool['settings']["access_key"],
+         secret_key=tool['settings']["secret_key"],
+         toolkit_name=tool.get('toolkit_name')
+     ).get_tools()
+
+ class ZephyrSquadToolkit(BaseToolkit):
+     tools: List[BaseTool] = []
+     toolkit_max_length: int = 0
+
+     @staticmethod
+     def toolkit_config_schema() -> BaseModel:
+         selected_tools = {x['name']: x['args_schema'].schema() for x in ZephyrSquadApiWrapper.model_construct().get_available_tools()}
+         ZephyrSquadToolkit.toolkit_max_length = get_max_toolkit_length(selected_tools)
+         return create_model(
+             "zephyr_squad",
+             account_id=(str, Field(description="AccountID for the user that is going to be authenticating")),
+             access_key=(str, Field(description="Generated access key")),
+             secret_key=(SecretStr, Field(description="Generated secret key")),
+             selected_tools=(List[Literal[tuple(selected_tools)]], Field(default=[], json_schema_extra={'args_schemas': selected_tools})),
+             __config__={'json_schema_extra': {'metadata': {"label": "Zephyr Squad", "icon_url": "zephyr.svg",
+                                                            "categories": ["test management"],
+                                                            "extra_categories": ["test automation", "test case management", "test planning"]
+                                                            }}}
+         )
+
+     @classmethod
+     def get_toolkit(cls, selected_tools: list[str] | None = None, toolkit_name: Optional[str] = None, **kwargs):
+         zephyr_api_wrapper = ZephyrSquadApiWrapper(**kwargs)
+         prefix = clean_string(toolkit_name, cls.toolkit_max_length) + TOOLKIT_SPLITTER if toolkit_name else ''
+         available_tools = zephyr_api_wrapper.get_available_tools()
+         tools = []
+         for tool in available_tools:
+             if selected_tools:
+                 if tool["name"] not in selected_tools:
+                     continue
+             tools.append(BaseAction(
+                 api_wrapper=zephyr_api_wrapper,
+                 name=prefix + tool["name"],
+                 description=tool["description"],
+                 args_schema=tool["args_schema"]
+             ))
+         return cls(tools=tools)
+
+     def get_tools(self):
+         return self.tools
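A minimal sketch of wiring the new toolkit through its module-level `get_tools` entry point (defined above); the account id and keys are placeholders.

```python
# Hypothetical configuration for the Zephyr Squad toolkit; credential values are placeholders
# and get_tools is the function defined in the module above.
zephyr_tool_config = {
    "toolkit_name": "zephyr_squad",
    "settings": {
        "account_id": "5b10a2844c20165700ede21g",  # placeholder Atlassian account id
        "access_key": "<access-key>",
        "secret_key": "<secret-key>",
        "selected_tools": [],                      # empty list keeps every available tool
    },
}

tools = get_tools(zephyr_tool_config)              # returns the BaseAction-wrapped Zephyr tools
print([t.name for t in tools])
```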
+