aiverify-moonshot 0.4.5__py3-none-any.whl → 0.4.6__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -13,6 +13,26 @@ from moonshot.api import (
13
13
  api_read_endpoint,
14
14
  api_update_endpoint,
15
15
  )
16
+ from moonshot.integrations.cli.cli_errors import (
17
+ ERROR_COMMON_ADD_ENDPOINT_CONNECTOR_TYPE_VALIDATION,
18
+ ERROR_COMMON_ADD_ENDPOINT_MAX_CALLS_PER_SECOND_VALIDATION,
19
+ ERROR_COMMON_ADD_ENDPOINT_MAX_CONCURRENCY_VALIDATION,
20
+ ERROR_COMMON_ADD_ENDPOINT_NAME_VALIDATION,
21
+ ERROR_COMMON_ADD_ENDPOINT_PARAMS_VALIDATION,
22
+ ERROR_COMMON_ADD_ENDPOINT_TOKEN_VALIDATION,
23
+ ERROR_COMMON_ADD_ENDPOINT_URI_VALIDATION,
24
+ ERROR_COMMON_DELETE_ENDPOINT_ENDPOINT_VALIDATION,
25
+ ERROR_COMMON_LIST_CONNECTOR_TYPES_FIND_VALIDATION,
26
+ ERROR_COMMON_LIST_CONNECTOR_TYPES_PAGINATION_VALIDATION,
27
+ ERROR_COMMON_LIST_CONNECTOR_TYPES_PAGINATION_VALIDATION_1,
28
+ ERROR_COMMON_LIST_ENDPOINTS_FIND_VALIDATION,
29
+ ERROR_COMMON_LIST_ENDPOINTS_PAGINATION_VALIDATION,
30
+ ERROR_COMMON_LIST_ENDPOINTS_PAGINATION_VALIDATION_1,
31
+ ERROR_COMMON_UPDATE_ENDPOINT_ENDPOINT_VALIDATION,
32
+ ERROR_COMMON_UPDATE_ENDPOINT_VALUES_VALIDATION,
33
+ ERROR_COMMON_UPDATE_ENDPOINT_VALUES_VALIDATION_1,
34
+ ERROR_COMMON_VIEW_ENDPOINT_ENDPOINT_VALIDATION,
35
+ )
16
36
  from moonshot.integrations.cli.utils.process_data import filter_data
17
37
 
18
38
  console = Console()
@@ -42,7 +62,44 @@ def add_endpoint(args) -> None:
42
62
  None
43
63
  """
44
64
  try:
45
- params_dict = literal_eval(args.params)
65
+ if not isinstance(args.name, str) or not args.name or args.name is None:
66
+ raise TypeError(ERROR_COMMON_ADD_ENDPOINT_NAME_VALIDATION)
67
+
68
+ if (
69
+ not isinstance(args.connector_type, str)
70
+ or not args.connector_type
71
+ or args.connector_type is None
72
+ ):
73
+ raise TypeError(ERROR_COMMON_ADD_ENDPOINT_CONNECTOR_TYPE_VALIDATION)
74
+
75
+ if not isinstance(args.uri, str) or not args.uri or args.uri is None:
76
+ raise TypeError(ERROR_COMMON_ADD_ENDPOINT_URI_VALIDATION)
77
+
78
+ if not isinstance(args.token, str) or not args.token or args.token is None:
79
+ raise TypeError(ERROR_COMMON_ADD_ENDPOINT_TOKEN_VALIDATION)
80
+
81
+ if (
82
+ not isinstance(args.max_calls_per_second, int)
83
+ or not args.max_calls_per_second
84
+ or args.max_calls_per_second is None
85
+ or args.max_calls_per_second < 0
86
+ ):
87
+ raise TypeError(ERROR_COMMON_ADD_ENDPOINT_MAX_CALLS_PER_SECOND_VALIDATION)
88
+
89
+ if (
90
+ not isinstance(args.max_concurrency, int)
91
+ or not args.max_concurrency
92
+ or args.max_concurrency is None
93
+ or args.max_calls_per_second < 0
94
+ ):
95
+ raise TypeError(ERROR_COMMON_ADD_ENDPOINT_MAX_CONCURRENCY_VALIDATION)
96
+
97
+ try:
98
+ params_dict = literal_eval(args.params)
99
+ except Exception:
100
+ raise SyntaxError(ERROR_COMMON_ADD_ENDPOINT_PARAMS_VALIDATION)
101
+ if not isinstance(params_dict, dict):
102
+ raise ValueError(ERROR_COMMON_ADD_ENDPOINT_PARAMS_VALIDATION)
46
103
 
47
104
  new_endpoint_id = api_create_endpoint(
48
105
  args.name,
@@ -74,9 +131,30 @@ def list_endpoints(args) -> list | None:
74
131
  list | None: A list of ConnectorEndpoint or None if there is no result.
75
132
  """
76
133
  try:
134
+ if args.find is not None:
135
+ if not isinstance(args.find, str) or not args.find:
136
+ raise TypeError(ERROR_COMMON_LIST_ENDPOINTS_FIND_VALIDATION)
137
+
138
+ if args.pagination is not None:
139
+ if not isinstance(args.pagination, str) or not args.pagination:
140
+ raise TypeError(ERROR_COMMON_LIST_ENDPOINTS_PAGINATION_VALIDATION)
141
+ try:
142
+ pagination = literal_eval(args.pagination)
143
+ if not (
144
+ isinstance(pagination, tuple)
145
+ and len(pagination) == 2
146
+ and all(isinstance(i, int) for i in pagination)
147
+ ):
148
+ raise ValueError(
149
+ ERROR_COMMON_LIST_ENDPOINTS_PAGINATION_VALIDATION_1
150
+ )
151
+ except (ValueError, SyntaxError):
152
+ raise ValueError(ERROR_COMMON_LIST_ENDPOINTS_PAGINATION_VALIDATION_1)
153
+ else:
154
+ pagination = ()
155
+
77
156
  endpoints_list = api_get_all_endpoint()
78
157
  keyword = args.find.lower() if args.find else ""
79
- pagination = literal_eval(args.pagination) if args.pagination else ()
80
158
 
81
159
  if endpoints_list:
82
160
  filtered_endpoints_list = filter_data(endpoints_list, keyword, pagination)
@@ -86,7 +164,6 @@ def list_endpoints(args) -> list | None:
86
164
 
87
165
  console.print("[red]There are no endpoints found.[/red]")
88
166
  return None
89
-
90
167
  except Exception as e:
91
168
  print(f"[list_endpoints]: {str(e)}")
92
169
 
@@ -107,9 +184,32 @@ def list_connector_types(args) -> list | None:
107
184
  list | None: A list of Connector or None if there is no result.
108
185
  """
109
186
  try:
187
+ if args.find is not None:
188
+ if not isinstance(args.find, str) or not args.find:
189
+ raise TypeError(ERROR_COMMON_LIST_CONNECTOR_TYPES_FIND_VALIDATION)
190
+
191
+ if args.pagination is not None:
192
+ if not isinstance(args.pagination, str) or not args.pagination:
193
+ raise TypeError(ERROR_COMMON_LIST_CONNECTOR_TYPES_PAGINATION_VALIDATION)
194
+ try:
195
+ pagination = literal_eval(args.pagination)
196
+ if not (
197
+ isinstance(pagination, tuple)
198
+ and len(pagination) == 2
199
+ and all(isinstance(i, int) for i in pagination)
200
+ ):
201
+ raise ValueError(
202
+ ERROR_COMMON_LIST_CONNECTOR_TYPES_PAGINATION_VALIDATION_1
203
+ )
204
+ except (ValueError, SyntaxError):
205
+ raise ValueError(
206
+ ERROR_COMMON_LIST_CONNECTOR_TYPES_PAGINATION_VALIDATION_1
207
+ )
208
+ else:
209
+ pagination = ()
210
+
110
211
  connector_type_list = api_get_all_connector_type()
111
212
  keyword = args.find.lower() if args.find else ""
112
- pagination = literal_eval(args.pagination) if args.pagination else ()
113
213
 
114
214
  if connector_type_list:
115
215
  filtered_connector_type_list = filter_data(
@@ -121,7 +221,6 @@ def list_connector_types(args) -> list | None:
121
221
 
122
222
  console.print("[red]There are no connector types found.[/red]")
123
223
  return None
124
-
125
224
  except Exception as e:
126
225
  print(f"[list_connector_types]: {str(e)}")
127
226
 
@@ -142,6 +241,13 @@ def view_endpoint(args) -> None:
142
241
  None
143
242
  """
144
243
  try:
244
+ if (
245
+ not isinstance(args.endpoint, str)
246
+ or not args.endpoint
247
+ or args.endpoint is None
248
+ ):
249
+ raise TypeError(ERROR_COMMON_VIEW_ENDPOINT_ENDPOINT_VALIDATION)
250
+
145
251
  endpoint_info = api_read_endpoint(args.endpoint)
146
252
  _display_endpoints([endpoint_info])
147
253
  except Exception as e:
@@ -165,8 +271,27 @@ def update_endpoint(args) -> None:
165
271
  None
166
272
  """
167
273
  try:
274
+ if (
275
+ args.endpoint is None
276
+ or not isinstance(args.endpoint, str)
277
+ or not args.endpoint
278
+ ):
279
+ raise ValueError(ERROR_COMMON_UPDATE_ENDPOINT_ENDPOINT_VALIDATION)
168
280
  endpoint = args.endpoint
169
- update_values = dict(literal_eval(args.update_kwargs))
281
+
282
+ if (
283
+ args.update_values is None
284
+ or not isinstance(args.update_values, str)
285
+ or not args.update_values
286
+ ):
287
+ raise ValueError(ERROR_COMMON_UPDATE_ENDPOINT_VALUES_VALIDATION)
288
+
289
+ if literal_eval(args.update_values) and all(
290
+ isinstance(i, tuple) for i in literal_eval(args.update_values)
291
+ ):
292
+ update_values = dict(literal_eval(args.update_values))
293
+ else:
294
+ raise ValueError(ERROR_COMMON_UPDATE_ENDPOINT_VALUES_VALIDATION_1)
170
295
  api_update_endpoint(endpoint, **update_values)
171
296
  print("[update_endpoint]: Endpoint updated.")
172
297
  except Exception as e:
@@ -196,6 +321,12 @@ def delete_endpoint(args) -> None:
196
321
  console.print("[bold yellow]Endpoint deletion cancelled.[/]")
197
322
  return
198
323
  try:
324
+ if (
325
+ args.endpoint is None
326
+ or not isinstance(args.endpoint, str)
327
+ or not args.endpoint
328
+ ):
329
+ raise ValueError(ERROR_COMMON_DELETE_ENDPOINT_ENDPOINT_VALIDATION)
199
330
  api_delete_endpoint(args.endpoint)
200
331
  print("[delete_endpoint]: Endpoint deleted.")
201
332
  except Exception as e:
@@ -3,11 +3,19 @@ from ast import literal_eval
3
3
  import cmd2
4
4
  from rich.console import Console
5
5
 
6
- from moonshot.api import (
7
- api_create_datasets,
6
+ from moonshot.api import api_create_datasets
7
+ from moonshot.integrations.cli.cli_errors import (
8
+ ERROR_COMMON_ADD_DATASET_DESC_VALIDATION,
9
+ ERROR_COMMON_ADD_DATASET_LICENSE_VALIDATION,
10
+ ERROR_COMMON_ADD_DATASET_METHOD_VALIDATION,
11
+ ERROR_COMMON_ADD_DATASET_NAME_VALIDATION,
12
+ ERROR_COMMON_ADD_DATASET_PARAMS_VALIDATION,
13
+ ERROR_COMMON_ADD_DATASET_REFERENCE_VALIDATION,
8
14
  )
9
15
 
10
16
  console = Console()
17
+
18
+
11
19
  def add_dataset(args) -> None:
12
20
  """
13
21
  Create a new dataset using the provided arguments and log the result.
@@ -26,6 +34,41 @@ def add_dataset(args) -> None:
26
34
  - params (dict): Additional parameters for dataset creation.
27
35
  """
28
36
  try:
37
+ if not isinstance(args.name, str) or not args.name or args.name is None:
38
+ raise TypeError(ERROR_COMMON_ADD_DATASET_NAME_VALIDATION)
39
+
40
+ if (
41
+ not isinstance(args.description, str)
42
+ or not args.description
43
+ or args.description is None
44
+ ):
45
+ raise TypeError(ERROR_COMMON_ADD_DATASET_DESC_VALIDATION)
46
+
47
+ if (
48
+ not isinstance(args.reference, str)
49
+ or not args.reference
50
+ or args.reference is None
51
+ ):
52
+ raise TypeError(ERROR_COMMON_ADD_DATASET_REFERENCE_VALIDATION)
53
+
54
+ if (
55
+ not isinstance(args.license, str)
56
+ or not args.license
57
+ or args.license is None
58
+ ):
59
+ raise TypeError(ERROR_COMMON_ADD_DATASET_LICENSE_VALIDATION)
60
+
61
+ if (
62
+ not isinstance(args.method, str)
63
+ or not args.method
64
+ or args.method is None
65
+ or args.method.lower() not in ["hf", "csv"]
66
+ ):
67
+ raise TypeError(ERROR_COMMON_ADD_DATASET_METHOD_VALIDATION)
68
+
69
+ if not isinstance(args.params, dict) or not args.params or args.params is None:
70
+ raise TypeError(ERROR_COMMON_ADD_DATASET_PARAMS_VALIDATION)
71
+
29
72
  new_dataset_id = api_create_datasets(
30
73
  args.name,
31
74
  args.description,
@@ -38,7 +81,8 @@ def add_dataset(args) -> None:
38
81
  except Exception as e:
39
82
  print(f"[add_dataset]: {str(e)}")
40
83
 
41
- # ------------------------------------------------------------------------------
84
+
85
+ # ------------------------------------------------------------------------------
42
86
  # Cmd2 Arguments Parsers
43
87
  # ------------------------------------------------------------------------------
44
88
  # Add dataset arguments
@@ -46,21 +90,30 @@ add_dataset_args = cmd2.Cmd2ArgumentParser(
46
90
  description="Add a new dataset. The 'name' argument will be slugified to create a unique identifier.",
47
91
  epilog=(
48
92
  "Examples:\n"
49
- "1. add_dataset 'dataset-name' 'A brief description' 'http://reference.com' 'MIT' 'csv' \"{'csv_file_path': '/path/to/your/file.csv'}\"\n"
50
- "2. add_dataset 'dataset-name' 'A brief description' 'http://reference.com' 'MIT' 'hf' \"{'dataset_name': 'cais/mmlu', 'dataset_config': 'college_biology', 'split': 'test', 'input_col': ['question','choices'], 'target_col': 'answer'}\""
93
+ "1. add_dataset 'dataset-name' 'A brief description' 'http://reference.com' 'MIT' 'csv' \"{'csv_file_path': '/path/to/your/file.csv'}\"\n" # noqa: E501
94
+ "2. add_dataset 'dataset-name' 'A brief description' 'http://reference.com' 'MIT' 'hf' \"{'dataset_name': 'cais/mmlu', 'dataset_config': 'college_biology', 'split': 'test', 'input_col': ['question','choices'], 'target_col': 'answer'}\"" # noqa: E501
51
95
  ),
52
96
  )
53
97
  add_dataset_args.add_argument("name", type=str, help="Name of the new dataset")
54
- add_dataset_args.add_argument("description", type=str, help="Description of the new dataset")
55
- add_dataset_args.add_argument("reference", type=str, help="Reference of the new dataset")
98
+ add_dataset_args.add_argument(
99
+ "description", type=str, help="Description of the new dataset"
100
+ )
101
+ add_dataset_args.add_argument(
102
+ "reference", type=str, help="Reference of the new dataset"
103
+ )
56
104
  add_dataset_args.add_argument("license", type=str, help="License of the new dataset")
57
- add_dataset_args.add_argument("method", type=str, choices=['hf', 'csv'], help="Method to convert the new dataset. Choose either 'hf' or 'csv'.")
58
105
  add_dataset_args.add_argument(
59
- "params",
60
- type=literal_eval,
106
+ "method",
107
+ type=str,
108
+ choices=["hf", "csv"],
109
+ help="Method to convert the new dataset. Choose either 'hf' or 'csv'.",
110
+ )
111
+ add_dataset_args.add_argument(
112
+ "params",
113
+ type=literal_eval,
61
114
  help=(
62
115
  "Params of the new dataset in dictionary format. For example: \n"
63
116
  "1. For 'csv' method: \"{'csv_file_path': '/path/to/your/file.csv'}\"\n"
64
- "2. For 'hf' method: \"{'dataset_name': 'cais_mmlu', 'dataset_config': 'college_biology', 'split': 'test', 'input_col': ['questions','choices'], 'target_col': 'answer'}\""
65
- )
66
- )
117
+ "2. For 'hf' method: \"{'dataset_name': 'cais_mmlu', 'dataset_config': 'college_biology', 'split': 'test', 'input_col': ['questions','choices'], 'target_col': 'answer'}\"" # noqa: E501
118
+ ),
119
+ )
@@ -5,6 +5,12 @@ from rich.console import Console
5
5
  from rich.table import Table
6
6
 
7
7
  from moonshot.api import api_delete_prompt_template, api_get_all_prompt_template_detail
8
+ from moonshot.integrations.cli.cli_errors import (
9
+ ERROR_COMMON_DELETE_PROMPT_TEMPLATE_PROMPT_TEMPLATE_VALIDATION,
10
+ ERROR_COMMON_LIST_CONNECTOR_TYPES_PAGINATION_VALIDATION,
11
+ ERROR_COMMON_LIST_CONNECTOR_TYPES_PAGINATION_VALIDATION_1,
12
+ ERROR_COMMON_LIST_PROMPT_TEMPLATES_FIND_VALIDATION,
13
+ )
8
14
  from moonshot.integrations.cli.utils.process_data import filter_data
9
15
 
10
16
  console = Console()
@@ -26,9 +32,32 @@ def list_prompt_templates(args) -> list | None:
26
32
  list | None: A list of PromptTemplate or None if there is no result.
27
33
  """
28
34
  try:
35
+ if args.find is not None:
36
+ if not isinstance(args.find, str) or not args.find:
37
+ raise TypeError(ERROR_COMMON_LIST_PROMPT_TEMPLATES_FIND_VALIDATION)
38
+
39
+ if args.pagination is not None:
40
+ if not isinstance(args.pagination, str) or not args.pagination:
41
+ raise TypeError(ERROR_COMMON_LIST_CONNECTOR_TYPES_PAGINATION_VALIDATION)
42
+ try:
43
+ pagination = literal_eval(args.pagination)
44
+ if not (
45
+ isinstance(pagination, tuple)
46
+ and len(pagination) == 2
47
+ and all(isinstance(i, int) for i in pagination)
48
+ ):
49
+ raise ValueError(
50
+ ERROR_COMMON_LIST_CONNECTOR_TYPES_PAGINATION_VALIDATION_1
51
+ )
52
+ except (ValueError, SyntaxError):
53
+ raise ValueError(
54
+ ERROR_COMMON_LIST_CONNECTOR_TYPES_PAGINATION_VALIDATION_1
55
+ )
56
+ else:
57
+ pagination = ()
58
+
29
59
  prompt_templates_list = api_get_all_prompt_template_detail()
30
60
  keyword = args.find.lower() if args.find else ""
31
- pagination = literal_eval(args.pagination) if args.pagination else ()
32
61
 
33
62
  if prompt_templates_list:
34
63
  filtered_prompt_templates_list = filter_data(
@@ -40,7 +69,6 @@ def list_prompt_templates(args) -> list | None:
40
69
 
41
70
  console.print("[red]There are no prompt templates found.[/red]")
42
71
  return None
43
-
44
72
  except Exception as e:
45
73
  print(f"[list_prompt_templates]: {str(e)}")
46
74
 
@@ -61,6 +89,14 @@ def delete_prompt_template(args) -> None:
61
89
  console.print("[bold yellow]Prompt template deletion cancelled.[/]")
62
90
  return
63
91
  try:
92
+ if (
93
+ args.prompt_template is None
94
+ or not isinstance(args.prompt_template, str)
95
+ or not args.prompt_template
96
+ ):
97
+ raise ValueError(
98
+ ERROR_COMMON_DELETE_PROMPT_TEMPLATE_PROMPT_TEMPLATE_VALIDATION
99
+ )
64
100
  api_delete_prompt_template(args.prompt_template)
65
101
  print("[delete_prompt_template]: Prompt template deleted.")
66
102
  except Exception as e:
@@ -22,6 +22,23 @@ from moonshot.api import (
22
22
  api_load_session,
23
23
  )
24
24
  from moonshot.integrations.cli.active_session_cfg import active_session
25
+ from moonshot.integrations.cli.cli_errors import (
26
+ ERROR_RED_TEAMING_ADD_BOOKMARK_ENDPOINT_VALIDATION,
27
+ ERROR_RED_TEAMING_ADD_BOOKMARK_ENDPOINT_VALIDATION_1,
28
+ ERROR_RED_TEAMING_ADD_BOOKMARK_NO_ACTIVE_SESSION,
29
+ ERROR_RED_TEAMING_LIST_SESSIONS_FIND_VALIDATION,
30
+ ERROR_RED_TEAMING_LIST_SESSIONS_PAGINATION_VALIDATION,
31
+ ERROR_RED_TEAMING_LIST_SESSIONS_PAGINATION_VALIDATION_1,
32
+ ERROR_RED_TEAMING_NEW_SESSION_ENDPOINTS_VALIDATION,
33
+ ERROR_RED_TEAMING_NEW_SESSION_FAILED_TO_USE_SESSION,
34
+ ERROR_RED_TEAMING_NEW_SESSION_PARAMS_VALIDATION,
35
+ ERROR_RED_TEAMING_NEW_SESSION_PARAMS_VALIDATION_1,
36
+ ERROR_RED_TEAMING_SHOW_PROMPTS_NO_ACTIVE_SESSION_VALIDATION,
37
+ ERROR_RED_TEAMING_USE_BOOKMARK_NO_ACTIVE_SESSION,
38
+ ERROR_RED_TEAMING_USE_SESSION_NO_SESSION_METADATA_VALIDATION,
39
+ ERROR_RED_TEAMING_USE_SESSION_RUNNER_ID_TYPE_VALIDATION,
40
+ ERROR_RED_TEAMING_USE_SESSION_RUNNER_ID_VALIDATION,
41
+ )
25
42
  from moonshot.integrations.cli.utils.process_data import filter_data
26
43
  from moonshot.src.redteaming.session.session import Session
27
44
 
@@ -43,38 +60,73 @@ def new_session(args) -> None:
43
60
  - endpoints (str, optional): The list of endpoints for the runner."""
44
61
  global active_session
45
62
 
46
- runner_id = args.runner_id
47
- context_strategy = args.context_strategy if args.context_strategy else ""
48
- prompt_template = args.prompt_template if args.prompt_template else ""
49
- endpoints = literal_eval(args.endpoints) if args.endpoints else []
50
-
51
- # create new runner and session
52
- if endpoints:
53
- runner = api_create_runner(runner_id, endpoints)
54
- # load existing runner
55
- else:
56
- runner = api_load_runner(runner_id)
63
+ try:
64
+ required_parameters = [("runner_id", str)]
65
+ optional_parameters = [("context_strategy", str), ("prompt_template", str)]
66
+ # Check if required parameters exist in args
67
+ for param, param_type in required_parameters:
68
+ param_value = getattr(args, param, None)
69
+ if not param_value:
70
+ raise ValueError(
71
+ ERROR_RED_TEAMING_NEW_SESSION_PARAMS_VALIDATION.format(param=param)
72
+ )
73
+ if not isinstance(param_value, param_type):
74
+ raise TypeError(
75
+ ERROR_RED_TEAMING_NEW_SESSION_PARAMS_VALIDATION_1.format(
76
+ param=param, param_type=param_type.__name__
77
+ )
78
+ )
57
79
 
58
- runner_args = {}
59
- runner_args["context_strategy"] = context_strategy
60
- runner_args["prompt_template"] = prompt_template
80
+ # Check the type of optional parameters if they exist
81
+ for param, param_type in optional_parameters:
82
+ param_value = getattr(args, param, None)
83
+ if param_value is not None and not isinstance(param_value, param_type):
84
+ raise TypeError(
85
+ ERROR_RED_TEAMING_NEW_SESSION_PARAMS_VALIDATION_1.format(
86
+ param=param, param_type=param_type.__name__
87
+ )
88
+ )
61
89
 
62
- # create new session in runner
63
- if runner.database_instance:
64
- api_create_session(
65
- runner.id, runner.database_instance, runner.endpoints, runner_args
66
- )
67
- session_metadata = api_load_session(runner.id)
68
- if session_metadata:
69
- active_session.update(session_metadata)
70
- if active_session["context_strategy"]:
71
- active_session[
72
- "cs_num_of_prev_prompts"
73
- ] = Session.DEFAULT_CONTEXT_STRATEGY_PROMPT
74
- print(f"Using session: {active_session['session_id']}")
75
- update_chat_display()
90
+ runner_id = args.runner_id
91
+ context_strategy = args.context_strategy if args.context_strategy else ""
92
+ prompt_template = args.prompt_template if args.prompt_template else ""
93
+ endpoints = []
94
+
95
+ # Check if literal eval param is correct type after eval
96
+ if hasattr(args, "endpoints") and args.endpoints:
97
+ endpoints = literal_eval(args.endpoints)
98
+ if not isinstance(endpoints, list):
99
+ raise TypeError(ERROR_RED_TEAMING_NEW_SESSION_ENDPOINTS_VALIDATION)
100
+
101
+ # create new runner and session
102
+ if endpoints:
103
+ runner = api_create_runner(runner_id, endpoints)
104
+ # load existing runner
76
105
  else:
77
- raise RuntimeError("Unable to use session")
106
+ runner = api_load_runner(runner_id)
107
+
108
+ runner_args = {}
109
+ runner_args["context_strategy"] = context_strategy
110
+ runner_args["prompt_template"] = prompt_template
111
+
112
+ # create new session in runner
113
+ if runner.database_instance:
114
+ api_create_session(
115
+ runner.id, runner.database_instance, runner.endpoints, runner_args
116
+ )
117
+ session_metadata = api_load_session(runner.id)
118
+ if session_metadata:
119
+ active_session.update(session_metadata)
120
+ if active_session["context_strategy"]:
121
+ active_session[
122
+ "cs_num_of_prev_prompts"
123
+ ] = Session.DEFAULT_CONTEXT_STRATEGY_PROMPT
124
+ print(f"[new_session] Using session: {active_session['session_id']}")
125
+ update_chat_display()
126
+ else:
127
+ raise RuntimeError(ERROR_RED_TEAMING_NEW_SESSION_FAILED_TO_USE_SESSION)
128
+ except Exception as e:
129
+ print(f"[new_session]: {str(e)}")
78
130
 
79
131
 
80
132
  def use_session(args) -> None:
@@ -85,15 +137,19 @@ def use_session(args) -> None:
85
137
  args (Namespace): The arguments passed to the function.
86
138
  """
87
139
  global active_session
88
- runner_id = args.runner_id
89
140
 
90
141
  # Load session metadata
91
142
  try:
143
+ if not args.runner_id or args.runner_id is None:
144
+ raise ValueError(ERROR_RED_TEAMING_USE_SESSION_RUNNER_ID_VALIDATION)
145
+
146
+ if not isinstance(args.runner_id, str):
147
+ raise TypeError(ERROR_RED_TEAMING_USE_SESSION_RUNNER_ID_TYPE_VALIDATION)
148
+
149
+ runner_id = args.runner_id
92
150
  session_metadata = api_load_session(runner_id)
93
151
  if not session_metadata:
94
- print(
95
- "[Session] Cannot find a session with the existing Runner ID. Please try again."
96
- )
152
+ print(ERROR_RED_TEAMING_USE_SESSION_NO_SESSION_METADATA_VALIDATION)
97
153
  return
98
154
 
99
155
  # Set the current session
@@ -115,7 +171,7 @@ def show_prompts() -> None:
115
171
  global active_session
116
172
 
117
173
  if not active_session:
118
- print("There is no active session. Activate a session to show a chat table.")
174
+ print(ERROR_RED_TEAMING_SHOW_PROMPTS_NO_ACTIVE_SESSION_VALIDATION)
119
175
  return
120
176
 
121
177
  update_chat_display()
@@ -146,6 +202,28 @@ def list_sessions(args) -> list | None:
146
202
  """
147
203
  try:
148
204
  session_metadata_list = api_get_all_session_metadata()
205
+ if args.find is not None:
206
+ if not isinstance(args.find, str) or not args.find:
207
+ raise TypeError(ERROR_RED_TEAMING_LIST_SESSIONS_FIND_VALIDATION)
208
+
209
+ if args.pagination is not None:
210
+ if not isinstance(args.pagination, str) or not args.pagination:
211
+ raise TypeError(ERROR_RED_TEAMING_LIST_SESSIONS_PAGINATION_VALIDATION)
212
+ try:
213
+ pagination = literal_eval(args.pagination)
214
+ if not (
215
+ isinstance(pagination, tuple)
216
+ and len(pagination) == 2
217
+ and all(isinstance(i, int) for i in pagination)
218
+ ):
219
+ raise ValueError(
220
+ ERROR_RED_TEAMING_LIST_SESSIONS_PAGINATION_VALIDATION_1
221
+ )
222
+ except (ValueError, SyntaxError):
223
+ raise ValueError(
224
+ ERROR_RED_TEAMING_LIST_SESSIONS_PAGINATION_VALIDATION_1
225
+ )
226
+
149
227
  keyword = args.find.lower() if args.find else ""
150
228
  pagination = literal_eval(args.pagination) if args.pagination else ()
151
229
 
@@ -159,7 +237,6 @@ def list_sessions(args) -> list | None:
159
237
 
160
238
  console.print("[red]There are no sessions found.[/red]")
161
239
  return None
162
-
163
240
  except Exception as e:
164
241
  print(f"[list_sessions]: {str(e)}")
165
242
 
@@ -214,7 +291,6 @@ def update_chat_display() -> None:
214
291
  title_align="left",
215
292
  )
216
293
  console.print(panel)
217
-
218
294
  else:
219
295
  console.print("[red]There are no active session.[/red]")
220
296
 
@@ -249,9 +325,7 @@ def add_bookmark(args) -> None:
249
325
  target_endpoint_chats = list_of_target_endpoint_chat.get(endpoint, None)
250
326
  target_endpoint_chat_record = {}
251
327
  if not target_endpoint_chats:
252
- print(
253
- "Incorrect endpoint. Please select a valid endpoint in this session."
254
- )
328
+ print(ERROR_RED_TEAMING_ADD_BOOKMARK_ENDPOINT_VALIDATION)
255
329
  return
256
330
  for endpoint_chat in target_endpoint_chats:
257
331
  if endpoint_chat["chat_record_id"] == prompt_id:
@@ -273,12 +347,14 @@ def add_bookmark(args) -> None:
273
347
  print("[bookmark_prompt]:", bookmark_message["message"])
274
348
  else:
275
349
  print(
276
- f"Unable to find prompt ID in the of prompts for endpoint {endpoint}. Please select a valid ID."
350
+ ERROR_RED_TEAMING_ADD_BOOKMARK_ENDPOINT_VALIDATION_1.format(
351
+ endpoint=endpoint
352
+ )
277
353
  )
278
354
  except Exception as e:
279
355
  print(f"[bookmark_prompt]: ({str(e)})")
280
356
  else:
281
- print("There is no active session. Activate a session to bookmark a prompt.")
357
+ print(ERROR_RED_TEAMING_ADD_BOOKMARK_NO_ACTIVE_SESSION)
282
358
  return
283
359
 
284
360
 
@@ -323,7 +399,7 @@ def use_bookmark(args) -> None:
323
399
  except Exception as e:
324
400
  print(f"[use_bookmark]: {str(e)}")
325
401
  else:
326
- print("There is no active session. Activate a session to use a bookmark.")
402
+ print(ERROR_RED_TEAMING_USE_BOOKMARK_NO_ACTIVE_SESSION)
327
403
  return
328
404
 
329
405
 
@@ -693,6 +769,7 @@ def delete_session(args) -> None:
693
769
  which is the ID of the session to delete.
694
770
  """
695
771
  # Confirm with the user before deleting a session
772
+
696
773
  confirmation = console.input(
697
774
  "[bold red]Are you sure you want to delete the session (y/N)? [/]"
698
775
  )
@@ -700,6 +777,12 @@ def delete_session(args) -> None:
700
777
  console.print("[bold yellow]Session deletion cancelled.[/]")
701
778
  return
702
779
  try:
780
+ if not args.session or args.session is None:
781
+ raise ValueError("Invalid or missing required parameter: session")
782
+
783
+ if not isinstance(args.session, str):
784
+ raise TypeError("Invalid type for parameter: session. Expecting type str.")
785
+
703
786
  api_delete_session(args.session)
704
787
  print("[delete_session]: Session deleted.")
705
788
  except Exception as e:
@@ -844,7 +927,7 @@ automated_rt_session_args.add_argument(
844
927
  type=str,
845
928
  help=(
846
929
  "The number of previous prompts to use with the context strategy. If this is set, it will overwrite the"
847
- " number of previous promtps set in the session while running this attack module."
930
+ " number of previous prompts set in the session while running this attack module."
848
931
  ),
849
932
  nargs="?",
850
933
  )
@@ -71,7 +71,7 @@ def create_app(cfg: providers.Configuration) -> CustomFastAPI:
71
71
  }
72
72
 
73
73
  app: CustomFastAPI = CustomFastAPI(
74
- title="Project Moonshot", version="0.4.5", **app_kwargs
74
+ title="Project Moonshot", version="0.4.6", **app_kwargs
75
75
  )
76
76
 
77
77
  if cfg.cors.enabled():