aiverify-moonshot 0.4.8__py3-none-any.whl → 0.4.9__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (29)
  1. {aiverify_moonshot-0.4.8.dist-info → aiverify_moonshot-0.4.9.dist-info}/METADATA +3 -3
  2. {aiverify_moonshot-0.4.8.dist-info → aiverify_moonshot-0.4.9.dist-info}/RECORD +29 -29
  3. {aiverify_moonshot-0.4.8.dist-info → aiverify_moonshot-0.4.9.dist-info}/licenses/LICENSE.md +1 -1
  4. moonshot/__main__.py +93 -49
  5. moonshot/api.py +12 -10
  6. moonshot/integrations/cli/benchmark/metrics.py +8 -2
  7. moonshot/integrations/cli/cli_errors.py +14 -0
  8. moonshot/integrations/cli/common/common.py +14 -8
  9. moonshot/integrations/cli/common/dataset.py +303 -65
  10. moonshot/integrations/cli/redteam/attack_module.py +30 -1
  11. moonshot/integrations/web_api/app.py +1 -1
  12. moonshot/integrations/web_api/routes/dataset.py +52 -18
  13. moonshot/integrations/web_api/schemas/cookbook_response_model.py +2 -0
  14. moonshot/integrations/web_api/schemas/dataset_create_dto.py +14 -4
  15. moonshot/integrations/web_api/schemas/recipe_response_model.py +1 -0
  16. moonshot/integrations/web_api/services/cookbook_service.py +36 -9
  17. moonshot/integrations/web_api/services/dataset_service.py +34 -9
  18. moonshot/integrations/web_api/services/recipe_service.py +33 -3
  19. moonshot/src/api/api_dataset.py +43 -11
  20. moonshot/src/bookmark/bookmark.py +16 -9
  21. moonshot/src/datasets/dataset.py +37 -45
  22. moonshot/src/datasets/dataset_arguments.py +2 -1
  23. moonshot/src/messages_constants.py +1 -0
  24. moonshot/src/redteaming/attack/attack_module.py +40 -0
  25. moonshot/src/storage/io_interface.py +18 -1
  26. moonshot/src/storage/storage.py +57 -1
  27. {aiverify_moonshot-0.4.8.dist-info → aiverify_moonshot-0.4.9.dist-info}/WHEEL +0 -0
  28. {aiverify_moonshot-0.4.8.dist-info → aiverify_moonshot-0.4.9.dist-info}/licenses/AUTHORS.md +0 -0
  29. {aiverify_moonshot-0.4.8.dist-info → aiverify_moonshot-0.4.9.dist-info}/licenses/NOTICES.md +0 -0
moonshot/integrations/web_api/services/cookbook_service.py

@@ -5,7 +5,7 @@ from .... import api as moonshot_api
 from ..schemas.cookbook_create_dto import CookbookCreateDTO, CookbookUpdateDTO
 from ..schemas.cookbook_response_model import CookbookResponseModel
 from ..services.base_service import BaseService
-from ..services.recipe_service import get_total_prompt_in_recipe
+from ..services.recipe_service import get_total_prompt_in_recipe, get_endpoint_dependency_in_recipe
 from ..services.utils.exceptions_handler import exception_handler


@@ -63,24 +63,24 @@ class CookbookService(BaseService):
             if cookbook not in cookbooks_list:
                 cookbooks_list.append(cookbook)
                 if count:
-                    cookbook.total_prompt_in_cookbook = (
-                        get_total_prompt_in_cookbook(cookbook)
+                    cookbook.total_prompt_in_cookbook, cookbook.total_dataset_in_cookbook = (
+                        get_total_prompt_and_dataset_in_cookbook(cookbook)
                     )

             if tags and cookbooks_recipe_has_tags(tags, cookbook):
                 if cookbook not in cookbooks_list:
                     cookbooks_list.append(cookbook)
                     if count:
-                        cookbook.total_prompt_in_cookbook = (
-                            get_total_prompt_in_cookbook(cookbook)
+                        cookbook.total_prompt_in_cookbook, cookbook.total_dataset_in_cookbook = (
+                            get_total_prompt_and_dataset_in_cookbook(cookbook)
                         )

             if categories and cookbooks_recipe_has_categories(categories, cookbook):
                 if cookbook not in cookbooks_list:
                     cookbooks_list.append(cookbook)
                     if count:
-                        cookbook.total_prompt_in_cookbook = (
-                            get_total_prompt_in_cookbook(cookbook)
+                        cookbook.total_prompt_in_cookbook, cookbook.total_dataset_in_cookbook = (
+                            get_total_prompt_and_dataset_in_cookbook(cookbook)
                         )

             if categories_excluded and cookbooks_recipe_has_categories(
@@ -88,6 +88,9 @@ class CookbookService(BaseService):
             ):
                 cookbooks_list.remove(cookbook)

+        for cookbook in cookbooks_list:
+            cookbook.endpoint_required = cookbook_endpoint_dependency(cookbook)
+
         return cookbooks_list

     @exception_handler
@@ -131,7 +134,7 @@ class CookbookService(BaseService):


 @staticmethod
-def get_total_prompt_in_cookbook(cookbook: Cookbook) -> int:
+def get_total_prompt_and_dataset_in_cookbook(cookbook: Cookbook) -> tuple[int, int]:
     """
     Calculate the total number of prompts in a cookbook.

@@ -144,7 +147,8 @@ def get_total_prompt_in_cookbook(cookbook: Cookbook) -> int:
         int: The total count of prompts within the cookbook.
     """
     recipes = moonshot_api.api_read_recipes(cookbook.recipes)
-    return sum(get_total_prompt_in_recipe(Recipe(**recipe)) for recipe in recipes)
+    total_prompts, total_datasets = zip(*(get_total_prompt_in_recipe(Recipe(**recipe)) for recipe in recipes))
+    return sum(total_prompts), sum(total_datasets)


 @staticmethod
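Note on the aggregation change above: each recipe now reports a (prompt_count, dataset_count) pair, and the zip(*...) transpose idiom regroups those pairs into two sequences that are summed independently. A minimal standalone illustration (the sample counts are made up):

    # Each tuple is (total_prompts, total_datasets) for one recipe.
    per_recipe_counts = [(10, 1), (20, 2), (5, 1)]

    # zip(*...) transposes to (10, 20, 5) and (1, 2, 1).
    total_prompts, total_datasets = zip(*per_recipe_counts)
    print(sum(total_prompts), sum(total_datasets))  # 35 4

One caveat: zip(*...) over an empty sequence yields nothing to unpack, so a cookbook with zero recipes would raise ValueError here; presumably cookbooks always reference at least one recipe.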
@@ -192,3 +196,26 @@ def cookbooks_recipe_has_categories(categories: str, cookbook: Cookbook) -> bool
     ):
         return True
     return False
+
+@staticmethod
+def cookbook_endpoint_dependency(cookbook: Cookbook) -> list[str] | None:
+    """
+    Retrieve a list of endpoint dependencies for all recipes in a given cookbook.
+
+    Args:
+        cookbook (Cookbook): The cookbook object containing the recipe IDs.
+
+    Returns:
+        list[str] | None: A list of endpoint dependencies if any are found, otherwise None.
+    """
+    recipes_in_cookbook = cookbook.recipes
+    recipes = moonshot_api.api_read_recipes(recipes_in_cookbook)
+    list_of_endpoints = set()
+
+    for recipe in recipes:
+        recipe = Recipe(**recipe)
+        endpoints = get_endpoint_dependency_in_recipe(recipe)
+        if endpoints:
+            list_of_endpoints.update(endpoints)
+
+    return list(list_of_endpoints) if list_of_endpoints else None
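The new helper de-duplicates endpoint IDs with a set and deliberately returns None rather than an empty list when nothing is found; since both None and [] are falsy, callers can gate on the result directly. A hedged sketch of the consuming side (the printed message is illustrative):

    endpoints = cookbook_endpoint_dependency(cookbook)
    if endpoints:  # covers both None and an empty list
        print("Endpoints required:", ", ".join(sorted(endpoints)))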
moonshot/integrations/web_api/services/dataset_service.py

@@ -1,5 +1,5 @@
 from .... import api as moonshot_api
-from ..schemas.dataset_create_dto import DatasetCreateDTO
+from ..schemas.dataset_create_dto import CSV_Dataset_DTO, HF_Dataset_DTO
 from ..schemas.dataset_response_dto import DatasetResponseDTO
 from ..services.base_service import BaseService
 from ..services.utils.exceptions_handler import exception_handler
@@ -8,24 +8,49 @@ from .utils.file_manager import copy_file

 class DatasetService(BaseService):
     @exception_handler
-    def create_dataset(self, dataset_data: DatasetCreateDTO, method: str) -> str:
+    def convert_dataset(self, dataset_data: CSV_Dataset_DTO) -> str:
         """
-        Create a dataset using the specified method.
+        Convert a dataset using the provided dataset data.

         Args:
-            dataset_data (DatasetCreateDTO): The data required to create the dataset.
-            method (str): The method to use for creating the dataset.
-            Supported methods are "hf" and "csv".
+            dataset_data (CSV_Dataset_DTO): The data required to convert the dataset.
+
+        Returns:
+            str: The path to the newly created dataset.
+
+        Raises:
+            Exception: If an error occurs during dataset conversion.
+        """
+
+        new_ds_path = moonshot_api.api_convert_dataset(
+            name=dataset_data.name,
+            description=dataset_data.description,
+            reference=dataset_data.reference,
+            license=dataset_data.license,
+            csv_file_path=dataset_data.csv_file_path,
+        )
+        return copy_file(new_ds_path)
+
+    @exception_handler
+    def download_dataset(self, dataset_data: HF_Dataset_DTO) -> str:
+        """
+        Download a dataset using the provided dataset data.
+
+        Args:
+            dataset_data (HF_Dataset_DTO): The data required to download the dataset.
+
+        Returns:
+            str: The path to the newly downloaded dataset.

         Raises:
-            Exception: If an error occurs during dataset creation.
+            Exception: If an error occurs during dataset download.
         """
-        new_ds_path = moonshot_api.api_create_datasets(
+
+        new_ds_path = moonshot_api.api_download_dataset(
             name=dataset_data.name,
             description=dataset_data.description,
             reference=dataset_data.reference,
             license=dataset_data.license,
-            method=method,
             **dataset_data.params,
         )
         return copy_file(new_ds_path)
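With create_dataset split into convert_dataset and download_dataset, web-API callers now pick the operation by method name rather than a method string. A rough usage sketch (the field values are placeholders, and the service is assumed to be constructed the way the web API wires up its other services):

    service = DatasetService()

    csv_dto = CSV_Dataset_DTO(
        name="my-dataset",
        description="Converted from a local CSV",
        reference="https://example.com/source",
        license="Apache-2.0",
        csv_file_path="/path/to/data.csv",  # placeholder path
    )
    dataset_path = service.convert_dataset(csv_dto)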
moonshot/integrations/web_api/services/recipe_service.py

@@ -60,7 +60,7 @@ class RecipeService(BaseService):
         for recipe_dict in recipes:
             recipe = RecipeResponseModel(**recipe_dict)
             if count:
-                recipe.total_prompt_in_recipe = get_total_prompt_in_recipe(recipe)
+                recipe.total_prompt_in_recipe, _ = get_total_prompt_in_recipe(recipe)
             filtered_recipes.append(recipe)

         # TODO - do all filtering in 1 pass
@@ -84,6 +84,9 @@ class RecipeService(BaseService):
         if sort_by == "id":
             filtered_recipes.sort(key=lambda x: x.id)

+        for recipe in filtered_recipes:
+            recipe.endpoint_required = get_endpoint_dependency_in_recipe(recipe)
+
         return filtered_recipes

     @exception_handler
@@ -126,7 +129,7 @@ class RecipeService(BaseService):


 @staticmethod
-def get_total_prompt_in_recipe(recipe: Recipe) -> int:
+def get_total_prompt_in_recipe(recipe: Recipe) -> tuple[int, int]:
     """
     Calculate the total number of prompts in a recipe.

@@ -139,6 +142,7 @@ def get_total_prompt_in_recipe(recipe: Recipe) -> int:

     Returns:
         int: The total count of prompts within the recipe.
+        int: The total count of datasets within the recipe.
     """
     # Initialize total prompt count
     total_prompt_count = 0
@@ -151,4 +155,30 @@ def get_total_prompt_in_recipe(recipe: Recipe) -> tuple[int, int]:
     if recipe.prompt_templates:
         total_prompt_count *= len(recipe.prompt_templates)

-    return total_prompt_count
+    return total_prompt_count, int(recipe.stats.get("num_of_datasets", 0))
+
+@staticmethod
+def get_endpoint_dependency_in_recipe(recipe: Recipe) -> list[str] | None:
+    """
+    Retrieve the list of endpoint dependencies for a given recipe.
+
+    This function fetches all metrics and their associated endpoints, then
+    matches the metrics in the provided recipe to find and compile a list
+    of endpoint dependencies.
+
+    Args:
+        recipe (Recipe): The recipe object containing the metrics information.
+
+    Returns:
+        list[str] | None: A list of endpoint dependencies if found, otherwise None.
+    """
+    metrics = recipe.metrics
+    all_metrics = moonshot_api.api_get_all_metric()
+
+    endpoints = set()
+    for metric in metrics:
+        for m in all_metrics:
+            if m['id'] == metric:
+                endpoints.update(m['endpoints'])
+
+    return list(endpoints) if endpoints else None
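get_endpoint_dependency_in_recipe rescans all_metrics once per metric in the recipe, which is O(len(metrics) x len(all_metrics)). Behaviour aside, an equivalent formulation with a one-off dict index keeps each lookup constant-time; this is a sketch of the same logic, not the package's code:

    # Build an id -> endpoints index once, then look up each metric.
    metric_endpoints = {m["id"]: m.get("endpoints", []) for m in all_metrics}

    endpoints: set[str] = set()
    for metric in metrics:
        endpoints.update(metric_endpoints.get(metric, []))

    result = list(endpoints) if endpoints else None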
moonshot/src/api/api_dataset.py

@@ -47,35 +47,67 @@ def api_get_all_datasets_name() -> list[str]:
     return datasets_name


-def api_create_datasets(
-    name: str, description: str, reference: str, license: str, method: str, **kwargs
+def api_download_dataset(
+    name: str, description: str, reference: str, license: str, **kwargs
 ) -> str:
     """
-    This function creates a new dataset.
+    Downloads a dataset from Hugging Face and creates a new dataset with the provided details.

-    This function takes the name, description, reference, and license for a new dataset as input. It then creates a new
-    DatasetArguments object with these details and an empty id. The id is left empty because it will be generated
-    from the name during the creation process. The function then calls the Dataset's create method to
-    create the new dataset.
+    This function takes the name, description, reference, and license for a new dataset as input, along with additional
+    keyword arguments for downloading the dataset from Hugging Face. It then creates a new DatasetArguments object with
+    these details and an empty id. The id is left empty because it will be generated from the name during the creation
+    process. The function then calls the Dataset's create method to create the new dataset.

     Args:
         name (str): The name of the new dataset.
         description (str): A brief description of the new dataset.
        reference (str): A reference link for the new dataset.
         license (str): The license of the new dataset.
-        method (str): The method to create new dataset. (csv/hf)
-        kwargs: Additional keyword arguments for the Dataset's create method.
+        kwargs: Additional keyword arguments for downloading the dataset from Hugging Face.

     Returns:
         str: The ID of the newly created dataset.
     """
+    examples = Dataset.download_hf(**kwargs)
     ds_args = DatasetArguments(
         id="",
         name=name,
         description=description,
         reference=reference,
         license=license,
-        examples=None,
+        examples=examples,
     )
+    return Dataset.create(ds_args)

-    return Dataset.create(ds_args, method, **kwargs)
+
+def api_convert_dataset(
+    name: str, description: str, reference: str, license: str, csv_file_path: str
+) -> str:
+    """
+    Converts a CSV file to a dataset and creates a new dataset with the provided details.
+
+    This function takes the name, description, reference, and license for a new dataset as input, along with the file
+    path to a CSV file. It then creates a new DatasetArguments object with these details and an empty id. The id is left
+    empty because it will be generated from the name during the creation process. The function then calls the Dataset's
+    create method to create the new dataset.
+
+    Args:
+        name (str): The name of the new dataset.
+        description (str): A brief description of the new dataset.
+        reference (str): A reference link for the new dataset.
+        license (str): The license of the new dataset.
+        csv_file_path (str): The file path to the CSV file.
+
+    Returns:
+        str: The ID of the newly created dataset.
+    """
+    examples = Dataset.convert_data(csv_file_path)
+    ds_args = DatasetArguments(
+        id="",
+        name=name,
+        description=description,
+        reference=reference,
+        license=license,
+        examples=examples,
+    )
+    return Dataset.create(ds_args)
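Taken together, the two replacement API functions would be invoked roughly as follows. The Hugging Face keyword arguments mirror the keys documented in Dataset.download_hf below; every concrete value here is a placeholder, and the import path is inferred from the file layout:

    from moonshot.src.api.api_dataset import api_convert_dataset, api_download_dataset

    # CSV route
    ds = api_convert_dataset(
        name="my-csv-dataset",
        description="Example dataset",
        reference="https://example.com/source",
        license="Apache-2.0",
        csv_file_path="/path/to/data.csv",
    )

    # Hugging Face route
    ds = api_download_dataset(
        name="my-hf-dataset",
        description="Example dataset",
        reference="https://example.com/source",
        license="Apache-2.0",
        dataset_name="squad",
        dataset_config="plain_text",
        split="train",
        input_col=["question", "context"],
        target_col="answers",
    )

Note that download_hf and convert_data return generators, so the examples handed to DatasetArguments are single-use: they are consumed exactly once when the storage layer streams them to disk.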
moonshot/src/bookmark/bookmark.py

@@ -11,6 +11,7 @@ from moonshot.src.messages_constants import (
     BOOKMARK_ADD_BOOKMARK_VALIDATION_ERROR,
     BOOKMARK_DELETE_ALL_BOOKMARK_ERROR,
     BOOKMARK_DELETE_ALL_BOOKMARK_SUCCESS,
+    BOOKMARK_DELETE_BOOKMARK_FAIL,
     BOOKMARK_DELETE_BOOKMARK_ERROR,
     BOOKMARK_DELETE_BOOKMARK_ERROR_1,
     BOOKMARK_DELETE_BOOKMARK_SUCCESS,
@@ -209,15 +210,21 @@ class Bookmark:
         """
         if isinstance(bookmark_name, str) and bookmark_name:
             try:
-                sql_delete_bookmark_record = textwrap.dedent(
-                    f"""
-                    DELETE FROM bookmark WHERE name = '{bookmark_name}';
-                    """
-                )
-                Storage.delete_database_record_in_table(
-                    self.db_instance, sql_delete_bookmark_record
-                )
-                return {"success": True, "message": BOOKMARK_DELETE_BOOKMARK_SUCCESS}
+
+                bookmark_info = Storage.read_database_record(
+                    self.db_instance, (bookmark_name,), Bookmark.sql_select_bookmark_record)
+                if bookmark_info is not None:
+                    sql_delete_bookmark_record = textwrap.dedent(
+                        f"""
+                        DELETE FROM bookmark WHERE name = '{bookmark_name}';
+                        """
+                    )
+                    Storage.delete_database_record_in_table(
+                        self.db_instance, sql_delete_bookmark_record
+                    )
+                    return {"success": True, "message": BOOKMARK_DELETE_BOOKMARK_SUCCESS}
+                else:
+                    return {"success": False, "message": BOOKMARK_DELETE_BOOKMARK_FAIL}
             except Exception as e:
                 return {
                     "success": False,
moonshot/src/datasets/dataset.py

@@ -1,6 +1,7 @@
 from __future__ import annotations

 from pathlib import Path
+from typing import Iterator

 import pandas as pd
 from datasets import load_dataset
@@ -22,34 +23,21 @@ class Dataset:

     @staticmethod
     @validate_call
-    def create(ds_args: DatasetArguments, method: str, **kwargs) -> str:
+    def create(ds_args: DatasetArguments) -> str:
         """
         Creates a new dataset based on the provided arguments and method.

         This method generates a unique dataset ID using the dataset name,
         checks if a dataset with the same ID already exists, and then
-        creates the dataset using the specified method (either 'csv' or
-        'hf'). The dataset information is then stored as a JSON object.
+        creates the dataset using the specified method. The dataset information
+        is then stored as a JSON object.

         Args:
             ds_args (DatasetArguments): The arguments containing dataset
-                details such as name, description, reference, and license.
-            method (str): The method to create the dataset. It can be either
-                'csv' or 'hf'.
-            **kwargs: Additional keyword arguments required for the specified
-                method.
-                - For 'csv' method: 'csv_file_path' (str): The file path to
-                    the CSV file.
-                - For 'hf' method: 'dataset_name' (str): The name of the
-                    Hugging Face dataset.
-                    'dataset_config' (str): The configuration of the Hugging
-                    Face dataset.
-                    'split' (str): The split of the dataset to load.
-                    'input_col' (list[str]): The list of input columns.
-                    'target_col' (str): The target column.
+                details such as name, description, reference, license, and examples.

         Returns:
-            str: The unique ID of the created dataset.
+            str: The file path of the created dataset JSON object.

         Raises:
             RuntimeError: If a dataset with the same ID already exists.
@@ -63,56 +51,58 @@ class Dataset:
             if Storage.is_object_exists(EnvVariables.DATASETS.name, ds_id, "json"):
                 raise RuntimeError(f"Dataset with ID '{ds_id}' already exists.")

-            examples = [{}]
-            if method == "csv":
-                examples = Dataset._convert_csv(kwargs["csv_file_path"])
-            elif method == "hf":
-                examples = Dataset._download_hf(kwargs)
-
             ds_info = {
                 "id": ds_id,
                 "name": ds_args.name,
                 "description": ds_args.description,
                 "reference": ds_args.reference,
                 "license": ds_args.license,
-                "examples": examples,
             }

+            examples = ds_args.examples
+
             # Write as JSON output
-            file_path = Storage.create_object(
-                EnvVariables.DATASETS.name, ds_id, ds_info, "json"
+            file_path = Storage.create_object_with_iterator(
+                EnvVariables.DATASETS.name,
+                ds_id,
+                ds_info,
+                "json",
+                iterator_keys=["examples"],
+                iterator_data=examples,
             )
-            return file_path

+            return file_path
         except Exception as e:
             logger.error(f"Failed to create dataset: {str(e)}")
             raise e

     @staticmethod
-    def _convert_csv(csv_file: str) -> list[dict]:
+    @validate_call
+    def convert_data(csv_file_path: str) -> Iterator[dict]:
         """
-        Converts a CSV file to a list of dictionaries.
+        Converts a CSV file to an iterator of dictionaries.

-        This method reads a CSV file and converts its contents into a list of dictionaries,
+        This method reads a CSV file and converts its contents into an iterator of dictionaries,
         where each dictionary represents a row in the CSV file.

         Args:
-            csv_file (str): The file path to the CSV file.
+            csv_file_path (str): The file path to the CSV file.

         Returns:
-            list[dict]: A list of dictionaries representing the CSV data.
+            Iterator[dict]: An iterator of dictionaries representing the CSV data.
         """
-        df = pd.read_csv(csv_file)
-        data = df.to_dict("records")
-        return data
+        df = pd.read_csv(csv_file_path, chunksize=1)
+        for chunk in df:
+            yield chunk.to_dict("records")[0]

     @staticmethod
-    def _download_hf(hf_args) -> list[dict]:
+    @validate_call
+    def download_hf(**hf_args) -> Iterator[dict]:
         """
-        Downloads a dataset from Hugging Face and converts it to a list of dictionaries.
+        Downloads a dataset from Hugging Face and converts it to an iterator of dictionaries.

         This method loads a dataset from Hugging Face based on the provided arguments and converts
-        its contents into a list of dictionaries, where each dictionary contains 'input' and 'target' keys.
+        its contents into an iterator of dictionaries, where each dictionary contains 'input' and 'target' keys.

         Args:
             hf_args (dict): A dictionary containing the following keys:
@@ -123,15 +113,17 @@ class Dataset:
             - 'target_col' (str): The target column.

         Returns:
-            list[dict]: A list of dictionaries representing the dataset.
+            Iterator[dict]: An iterator of dictionaries representing the dataset.
         """
-        dataset = load_dataset(hf_args["dataset_name"], hf_args["dataset_config"])
-        data = []
-        for example in dataset[hf_args["split"]]:
+
+        dataset = load_dataset(
+            hf_args["dataset_name"], hf_args["dataset_config"], split=hf_args["split"]
+        )
+
+        for example in dataset:
             input_data = " ".join([str(example[col]) for col in hf_args["input_col"]])
             target_data = str(example[hf_args["target_col"]])
-            data.append({"input": input_data, "target": target_data})
-        return data
+            yield {"input": input_data, "target": target_data}

     @staticmethod
     @validate_call
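On the convert_data rewrite: pd.read_csv(..., chunksize=1) returns a TextFileReader that yields one-row DataFrames, so rows can be emitted without loading the whole file into memory; to_dict("records")[0] then turns each single-row frame into a plain dict. A self-contained illustration:

    import io

    import pandas as pd

    csv_text = "input,target\nhello,world\nfoo,bar\n"
    for chunk in pd.read_csv(io.StringIO(csv_text), chunksize=1):
        print(chunk.to_dict("records")[0])
    # {'input': 'hello', 'target': 'world'}
    # {'input': 'foo', 'target': 'bar'}

chunksize=1 keeps memory flat but pays per-row pandas overhead; a larger chunk size with an inner loop over each chunk's records is the usual compromise.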
moonshot/src/datasets/dataset_arguments.py

@@ -1,5 +1,6 @@
+from typing import Iterator
+
 from pydantic import BaseModel
-from pyparsing import Iterator


 class DatasetArguments(BaseModel):
moonshot/src/messages_constants.py

@@ -15,6 +15,7 @@ BOOKMARK_GET_BOOKMARK_ERROR_1 = "[Bookmark] Invalid bookmark name: {message}"
 # BOOKMARK - delete_bookmark
 # ------------------------------------------------------------------------------
 BOOKMARK_DELETE_BOOKMARK_SUCCESS = "[Bookmark] Bookmark record deleted."
+BOOKMARK_DELETE_BOOKMARK_FAIL = "[Bookmark] Bookmark record not found. Unable to delete."
 BOOKMARK_DELETE_BOOKMARK_ERROR = (
     "[Bookmark] Failed to delete bookmark record: {message}"
 )
moonshot/src/redteaming/attack/attack_module.py

@@ -34,6 +34,7 @@ class AttackModule:

     def __init__(self, am_id: str, am_arguments: AttackModuleArguments | None = None):
         self.id = am_id
+        self.req_and_config = self.get_attack_module_req_and_config()
         if am_arguments is not None:
             self.connector_ids = am_arguments.connector_ids
             self.prompt_templates = am_arguments.prompt_templates
@@ -529,6 +530,45 @@ class AttackModule:

         return am_metadata, cache_updated

+    def get_attack_module_req_and_config(self) -> dict:
+        """
+        Retrieves the configuration for a specific attack module by its identifier.
+
+        Returns:
+            dict: The attack module configuration as a dictionary. Returns an empty dict if the configuration
+            is not found.
+
+        Raises:
+            Exception: If reading the attack module configuration fails or if the configuration cannot be created.
+        """
+        attack_module_config = "attack_modules_config"
+        try:
+            obj_results = Storage.read_object(
+                EnvVariables.ATTACK_MODULES.name, attack_module_config, "json"
+            )
+            return obj_results.get(self.id, {})
+        except Exception as e:
+            logger.warning(
+                f"[AttackModule] Failed to read attack module configuration: {str(e)}"
+            )
+            logger.info("Attempting to create empty attack module configuration...")
+            try:
+                Storage.create_object(
+                    obj_type=EnvVariables.ATTACK_MODULES.name,
+                    obj_id=attack_module_config,
+                    obj_info={},
+                    obj_extension="json",
+                )
+                # After creation, attempt to read it again to ensure it was created successfully
+                obj_results = Storage.read_object(
+                    EnvVariables.ATTACK_MODULES.name, attack_module_config, "json"
+                )
+                return obj_results.get(self.id, {})
+            except Exception as e:
+                raise Exception(
+                    f"[AttackModule] Failed to retrieve attack modules configuration: {str(e)}"
+                )
+
     @staticmethod
     def delete(am_id: str) -> bool:
         """
moonshot/src/storage/io_interface.py

@@ -1,5 +1,5 @@
 from abc import abstractmethod
-from typing import Any
+from typing import Any, Iterator


 class IOInterface:
@@ -16,6 +16,23 @@ class IOInterface:
         """
         pass

+    @abstractmethod
+    def create_file_with_iterator(
+        self, data: dict, iterator_keys: list[str], iterator_data: Iterator[dict]
+    ) -> bool:
+        """
+        Creates a file using an iterator to provide the data for specified keys.
+
+        Args:
+            data (dict): The data to be serialized into JSON and written to the file.
+            iterator_keys (list[str]): A list of keys for which the values will be written using iterators.
+            iterator_data (Iterator[dict]): An iterator for the data to be written for the specified keys.
+
+        Returns:
+            bool: Always returns True to indicate the operation was executed without raising an exception.
+        """
+        pass
+
     @abstractmethod
     def read_file(self, filepath: str) -> dict | None:
         """