aiverify-moonshot 0.4.8__py3-none-any.whl → 0.4.9__py3-none-any.whl

This diff shows the content of publicly available package versions as released to a supported public registry. It is provided for informational purposes only.
Files changed (29)
  1. {aiverify_moonshot-0.4.8.dist-info → aiverify_moonshot-0.4.9.dist-info}/METADATA +3 -3
  2. {aiverify_moonshot-0.4.8.dist-info → aiverify_moonshot-0.4.9.dist-info}/RECORD +29 -29
  3. {aiverify_moonshot-0.4.8.dist-info → aiverify_moonshot-0.4.9.dist-info}/licenses/LICENSE.md +1 -1
  4. moonshot/__main__.py +93 -49
  5. moonshot/api.py +12 -10
  6. moonshot/integrations/cli/benchmark/metrics.py +8 -2
  7. moonshot/integrations/cli/cli_errors.py +14 -0
  8. moonshot/integrations/cli/common/common.py +14 -8
  9. moonshot/integrations/cli/common/dataset.py +303 -65
  10. moonshot/integrations/cli/redteam/attack_module.py +30 -1
  11. moonshot/integrations/web_api/app.py +1 -1
  12. moonshot/integrations/web_api/routes/dataset.py +52 -18
  13. moonshot/integrations/web_api/schemas/cookbook_response_model.py +2 -0
  14. moonshot/integrations/web_api/schemas/dataset_create_dto.py +14 -4
  15. moonshot/integrations/web_api/schemas/recipe_response_model.py +1 -0
  16. moonshot/integrations/web_api/services/cookbook_service.py +36 -9
  17. moonshot/integrations/web_api/services/dataset_service.py +34 -9
  18. moonshot/integrations/web_api/services/recipe_service.py +33 -3
  19. moonshot/src/api/api_dataset.py +43 -11
  20. moonshot/src/bookmark/bookmark.py +16 -9
  21. moonshot/src/datasets/dataset.py +37 -45
  22. moonshot/src/datasets/dataset_arguments.py +2 -1
  23. moonshot/src/messages_constants.py +1 -0
  24. moonshot/src/redteaming/attack/attack_module.py +40 -0
  25. moonshot/src/storage/io_interface.py +18 -1
  26. moonshot/src/storage/storage.py +57 -1
  27. {aiverify_moonshot-0.4.8.dist-info → aiverify_moonshot-0.4.9.dist-info}/WHEEL +0 -0
  28. {aiverify_moonshot-0.4.8.dist-info → aiverify_moonshot-0.4.9.dist-info}/licenses/AUTHORS.md +0 -0
  29. {aiverify_moonshot-0.4.8.dist-info → aiverify_moonshot-0.4.9.dist-info}/licenses/NOTICES.md +0 -0
moonshot/integrations/cli/common/dataset.py

@@ -2,118 +2,356 @@ from ast import literal_eval
 
 import cmd2
 from rich.console import Console
+from rich.table import Table
+
+from moonshot.api import (
+    api_delete_dataset,
+    api_get_all_datasets,
+    api_get_all_datasets_name,
+    api_convert_dataset,
+    api_download_dataset
+)
 
-from moonshot.api import api_create_datasets
 from moonshot.integrations.cli.cli_errors import (
-    ERROR_COMMON_ADD_DATASET_DESC_VALIDATION,
-    ERROR_COMMON_ADD_DATASET_LICENSE_VALIDATION,
-    ERROR_COMMON_ADD_DATASET_METHOD_VALIDATION,
-    ERROR_COMMON_ADD_DATASET_NAME_VALIDATION,
-    ERROR_COMMON_ADD_DATASET_PARAMS_VALIDATION,
-    ERROR_COMMON_ADD_DATASET_REFERENCE_VALIDATION,
+    ERROR_BENCHMARK_DELETE_DATASET_DATASET_VALIDATION,
+    ERROR_BENCHMARK_LIST_DATASETS_FIND_VALIDATION,
+    ERROR_BENCHMARK_LIST_DATASETS_PAGINATION_VALIDATION,
+    ERROR_BENCHMARK_LIST_DATASETS_PAGINATION_VALIDATION_1,
+    ERROR_BENCHMARK_VIEW_DATASET_DATASET_FILENAME_VALIDATION,
 )
 
-console = Console()
+from moonshot.integrations.cli.common.display_helper import display_view_str_format
+from moonshot.integrations.cli.utils.process_data import filter_data
 
+console = Console()
 
-def add_dataset(args) -> None:
+def list_datasets(args) -> list | None:
     """
-    Create a new dataset using the provided arguments and log the result.
+    List all available datasets.
 
-    This function attempts to create a new dataset by calling the `api_create_datasets`
-    function with the necessary parameters extracted from `args`. If successful, it logs
-    the creation of the dataset with its ID. If an exception occurs, it logs the error.
+    This function retrieves all available datasets by calling the api_get_all_datasets function from the
+    moonshot.api module. It then filters the datasets based on the provided keyword and pagination arguments.
+    If there are no datasets, it prints a message indicating that no datasets were found.
 
     Args:
-        args: An argparse.Namespace object containing the following attributes:
-            - name (str): Name of the new dataset.
-            - description (str): Description of the new dataset.
-            - reference (str): Reference URL for the new dataset.
-            - license (str): License type for the new dataset.
-            - method (str): Method to convert the new dataset ('hf' or 'csv').
-            - params (dict): Additional parameters for dataset creation.
+        args: A namespace object from argparse. It should have optional attributes:
+            find (str): Optional keyword to filter datasets.
+            pagination (str): Optional tuple to paginate datasets.
+
+    Returns:
+        list | None: A list of datasets or None if there are no datasets.
     """
     try:
-        if not isinstance(args.name, str) or not args.name or args.name is None:
-            raise TypeError(ERROR_COMMON_ADD_DATASET_NAME_VALIDATION)
+        print("Listing datasets may take a while...")
+        if args.find is not None:
+            if not isinstance(args.find, str) or not args.find:
+                raise TypeError(ERROR_BENCHMARK_LIST_DATASETS_FIND_VALIDATION)
 
-        if (
-            not isinstance(args.description, str)
-            or not args.description
-            or args.description is None
-        ):
-            raise TypeError(ERROR_COMMON_ADD_DATASET_DESC_VALIDATION)
+        if args.pagination is not None:
+            if not isinstance(args.pagination, str) or not args.pagination:
+                raise TypeError(ERROR_BENCHMARK_LIST_DATASETS_PAGINATION_VALIDATION)
+            try:
+                pagination = literal_eval(args.pagination)
+                if not (
+                    isinstance(pagination, tuple)
+                    and len(pagination) == 2
+                    and all(isinstance(i, int) for i in pagination)
+                ):
+                    raise ValueError(
+                        ERROR_BENCHMARK_LIST_DATASETS_PAGINATION_VALIDATION_1
+                    )
+            except (ValueError, SyntaxError):
+                raise ValueError(ERROR_BENCHMARK_LIST_DATASETS_PAGINATION_VALIDATION_1)
+        else:
+            pagination = ()
 
-        if (
-            not isinstance(args.reference, str)
-            or not args.reference
-            or args.reference is None
-        ):
-            raise TypeError(ERROR_COMMON_ADD_DATASET_REFERENCE_VALIDATION)
+        datasets_list = api_get_all_datasets()
+        keyword = args.find.lower() if args.find else ""
+
+        if datasets_list:
+            filtered_datasets_list = filter_data(datasets_list, keyword, pagination)
+            if filtered_datasets_list:
+                _display_datasets(filtered_datasets_list)
+                return filtered_datasets_list
+
+        console.print("[red]There are no datasets found.[/red]")
+        return None
 
+    except Exception as e:
+        print(f"[list_datasets]: {str(e)}")
+        return None
+
+
+def view_dataset(args) -> None:
+    """
+    View a specific dataset.
+
+    This function retrieves all available datasets and their names by calling the api_get_all_datasets and
+    api_get_all_datasets_name functions. It then finds the dataset with the name specified in args.dataset_filename
+    and displays it using the _display_datasets function. If an exception occurs, it prints an error message.
+
+    Args:
+        args: A namespace object from argparse. It should have the following attribute:
+            dataset_filename (str): The name of the dataset to view.
+
+    Returns:
+        None
+    """
+    try:
+        print("Viewing datasets may take a while...")
         if (
-            not isinstance(args.license, str)
-            or not args.license
-            or args.license is None
+            not isinstance(args.dataset_filename, str)
+            or not args.dataset_filename
+            or args.dataset_filename is None
         ):
-            raise TypeError(ERROR_COMMON_ADD_DATASET_LICENSE_VALIDATION)
+            raise TypeError(ERROR_BENCHMARK_VIEW_DATASET_DATASET_FILENAME_VALIDATION)
 
+        datasets_list = api_get_all_datasets()
+        datasets_name_list = api_get_all_datasets_name()
+
+        # Find the index of the dataset with the name args.dataset_filename
+        dataset_index = datasets_name_list.index(args.dataset_filename)
+        # Pass the corresponding dataset from datasets_list to _display_datasets
+        _display_datasets([datasets_list[dataset_index]])
+
+    except Exception as e:
+        print(f"[view_dataset]: {str(e)}")
+
+
+def delete_dataset(args) -> None:
+    """
+    Delete a dataset.
+
+    This function deletes a dataset with the specified name. It prompts the user for confirmation before proceeding
+    with the deletion. If the user confirms, it calls the api_delete_dataset function from the moonshot.api module to
+    delete the dataset. If the deletion is successful, it prints a confirmation message. If an exception occurs, it
+    prints an error message.
+
+    Args:
+        args: A namespace object from argparse. It should have the following attribute:
+            dataset (str): The name of the dataset to delete.
+
+    Returns:
+        None
+    """
+    # Confirm with the user before deleting a dataset
+    confirmation = console.input(
+        "[bold red]Are you sure you want to delete the dataset (y/N)? [/]"
+    )
+    if confirmation.lower() != "y":
+        console.print("[bold yellow]Dataset deletion cancelled.[/]")
+        return
+
+    try:
         if (
-            not isinstance(args.method, str)
-            or not args.method
-            or args.method is None
-            or args.method.lower() not in ["hf", "csv"]
+            args.dataset is None
+            or not isinstance(args.dataset, str)
+            or not args.dataset
         ):
-            raise TypeError(ERROR_COMMON_ADD_DATASET_METHOD_VALIDATION)
+            raise ValueError(ERROR_BENCHMARK_DELETE_DATASET_DATASET_VALIDATION)
 
-        if not isinstance(args.params, dict) or not args.params or args.params is None:
-            raise TypeError(ERROR_COMMON_ADD_DATASET_PARAMS_VALIDATION)
+        api_delete_dataset(args.dataset)
+        print("[delete_dataset]: Dataset deleted.")
+    except Exception as e:
+        print(f"[delete_dataset]: {str(e)}")
+
+def convert_dataset(args) -> None:
+    """
+    Convert an existing dataset to a new format.
 
-        new_dataset_id = api_create_datasets(
+    Args:
+        args: A namespace object from argparse with the following attributes:
+            - name (str): Name of the new dataset.
+            - description (str): Description of the new dataset.
+            - reference (str): Reference of the new dataset.
+            - license (str): License of the new dataset.
+            - csv_file_path (str): Path to the existing dataset file.
+
+    Returns:
+        None
+    """
+    try:
+        new_dataset_id = api_convert_dataset(
+            args.name,
+            args.description,
+            args.reference,
+            args.license,
+            args.csv_file_path,
+        )
+        print(f"[convert_dataset]: Dataset ({new_dataset_id}) created.")
+    except Exception as e:
+        print(f"[convert_dataset]: {str(e)}")
+
+
+def download_dataset(args) -> None:
+    """
+    Download a dataset from Hugging Face.
+
+    Args:
+        args: A namespace object from argparse with the following attributes:
+            - name (str): Name of the new dataset.
+            - description (str): Description of the new dataset.
+            - reference (str): Reference of the new dataset.
+            - license (str): License of the new dataset.
+            - params (dict): Parameters for the dataset in dictionary format.
+
+    Returns:
+        None
+    """
+    try:
+        new_dataset_id = api_download_dataset(
             args.name,
             args.description,
             args.reference,
             args.license,
-            args.method,
             **args.params,
         )
-        print(f"[add_dataset]: Dataset ({new_dataset_id}) created.")
+        print(f"[download_dataset]: Dataset ({new_dataset_id}) created.")
     except Exception as e:
-        print(f"[add_dataset]: {str(e)}")
+        print(f"[download_dataset]: {str(e)}")
+
+
+# ------------------------------------------------------------------------------
+# Helper functions: Display on cli
+# ------------------------------------------------------------------------------
+def _display_datasets(datasets_list: list):
+    """
+    Displays a list of datasets in a table format.
+
+    This function takes a list of datasets and displays them in a table format with each dataset's name, description,
+    and other relevant details. If the list is empty, it prints a message indicating that no datasets are found.
+
+    Args:
+        datasets_list (list): A list of dictionaries, where each dictionary contains the details of a dataset.
+
+    Returns:
+        None
+    """
+    table = Table(
+        title="List of Datasets", show_lines=True, expand=True, header_style="bold"
+    )
+    table.add_column("No.", width=2)
+    table.add_column("Dataset", justify="left", width=78)
+    for idx, dataset in enumerate(datasets_list, 1):
+        (
+            id,
+            name,
+            description,
+            _,
+            num_of_dataset_prompts,
+            created_date,
+            reference,
+            license,
+            *other_args,
+        ) = dataset.values()
+
+        idx = dataset.get("idx", idx)
+        prompt_info = display_view_str_format("Prompts", num_of_dataset_prompts)
+        created_date_info = display_view_str_format("Created Date", created_date)
+        license_info = display_view_str_format("License", license)
+        reference_info = display_view_str_format("Reference", reference)
+
+        dataset_info = (
+            f"[red]{id}[/red]\n\n[blue]{name}[/blue]\n{description}\n\n"
+            f"{prompt_info}\n\n{created_date_info}\n\n{license_info}\n\n{reference_info}"
+        )
+
+        table.add_section()
+        table.add_row(str(idx), dataset_info)
+    console.print(table)
 
 
 # ------------------------------------------------------------------------------
 # Cmd2 Arguments Parsers
 # ------------------------------------------------------------------------------
-# Add dataset arguments
-add_dataset_args = cmd2.Cmd2ArgumentParser(
-    description="Add a new dataset. The 'name' argument will be slugified to create a unique identifier.",
+# View dataset arguments
+view_dataset_args = cmd2.Cmd2ArgumentParser(
+    description="View a dataset file.",
+    epilog="Example:\n view_dataset bbq-lite-age-ambiguous",
+)
+view_dataset_args.add_argument(
+    "dataset_filename", type=str, help="Name of the dataset file"
+)
+
+# Delete dataset arguments
+delete_dataset_args = cmd2.Cmd2ArgumentParser(
+    description="Delete a dataset.",
+    epilog="Example:\n delete_dataset bbq-lite-age-ambiguous",
+)
+delete_dataset_args.add_argument("dataset", type=str, help="Name of the dataset")
+
+# List dataset arguments
+list_datasets_args = cmd2.Cmd2ArgumentParser(
+    description="List all datasets.",
+    epilog='Example:\n list_datasets -f "bbq"',
+)
+
+list_datasets_args.add_argument(
+    "-f",
+    "--find",
+    type=str,
+    help="Optional field to find dataset(s) with keyword",
+    nargs="?",
+)
+
+list_datasets_args.add_argument(
+    "-p",
+    "--pagination",
+    type=str,
+    help="Optional tuple to paginate dataset(s). E.g. (2,10) returns 2nd page with 10 items in each page.",
+    nargs="?",
+)
+
+# Convert dataset arguments
+convert_dataset_args = cmd2.Cmd2ArgumentParser(
+    description="Convert your dataset. The 'name' argument will be slugified to create a unique identifier.",
     epilog=(
         "Examples:\n"
-        "1. add_dataset 'dataset-name' 'A brief description' 'http://reference.com' 'MIT' 'csv' \"{'csv_file_path': '/path/to/your/file.csv'}\"\n"  # noqa: E501
-        "2. add_dataset 'dataset-name' 'A brief description' 'http://reference.com' 'MIT' 'hf' \"{'dataset_name': 'cais/mmlu', 'dataset_config': 'college_biology', 'split': 'test', 'input_col': ['question','choices'], 'target_col': 'answer'}\""  # noqa: E501
+        "convert_dataset 'dataset-name' 'A brief description' 'http://reference.com' 'MIT' '/path/to/your/file.csv'"
     ),
 )
-add_dataset_args.add_argument("name", type=str, help="Name of the new dataset")
-add_dataset_args.add_argument(
+convert_dataset_args.add_argument("name", type=str, help="Name of the new dataset")
+convert_dataset_args.add_argument(
     "description", type=str, help="Description of the new dataset"
 )
-add_dataset_args.add_argument(
+convert_dataset_args.add_argument(
     "reference", type=str, help="Reference of the new dataset"
 )
-add_dataset_args.add_argument("license", type=str, help="License of the new dataset")
-add_dataset_args.add_argument(
-    "method",
-    type=str,
-    choices=["hf", "csv"],
-    help="Method to convert the new dataset. Choose either 'hf' or 'csv'.",
+convert_dataset_args.add_argument(
+    "license", type=str, help="License of the new dataset"
+)
+convert_dataset_args.add_argument(
+    "csv_file_path", type=str, help="Path to your existing dataset"
+)
+
+
+# Download dataset arguments
+download_dataset_args = cmd2.Cmd2ArgumentParser(
+    description="Download dataset from Hugging Face. The 'name' argument will be slugified to create a unique ID.",
+    epilog=(
+        "Examples:\n"
+        "download_dataset 'dataset-name' 'A brief description' 'http://reference.com' 'MIT' "
+        "\"{'dataset_name': 'cais/mmlu', 'dataset_config': 'college_biology', 'split': 'dev', "
+        "'input_col': ['question','choices'], 'target_col': 'answer'}\""
    ),
)
+download_dataset_args.add_argument("name", type=str, help="Name of the new dataset")
+download_dataset_args.add_argument(
+    "description", type=str, help="Description of the new dataset"
+)
+download_dataset_args.add_argument(
+    "reference", type=str, help="Reference of the new dataset"
+)
+download_dataset_args.add_argument(
+    "license", type=str, help="License of the new dataset"
 )
-add_dataset_args.add_argument(
+download_dataset_args.add_argument(
     "params",
     type=literal_eval,
     help=(
         "Params of the new dataset in dictionary format. For example: \n"
         "1. For 'csv' method: \"{'csv_file_path': '/path/to/your/file.csv'}\"\n"
-        "2. For 'hf' method: \"{'dataset_name': 'cais_mmlu', 'dataset_config': 'college_biology', 'split': 'test', 'input_col': ['questions','choices'], 'target_col': 'answer'}\""  # noqa: E501
+        "2. For 'hf' method: \"{'dataset_name': 'cais_mmlu', 'dataset_config': 'college_biology', 'split': 'test', "
+        "'input_col': ['questions','choices'], 'target_col': 'answer'}\""
     ),
 )
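
The hunk above replaces the single add_dataset command with five dataset commands: list_datasets, view_dataset, delete_dataset, convert_dataset (CSV) and download_dataset (Hugging Face). A minimal sketch of driving two of the new handlers programmatically, assuming only the signatures visible in the diff; the dataset name and file path are placeholders:

    from argparse import Namespace

    from moonshot.integrations.cli.common.dataset import convert_dataset, list_datasets

    # Equivalent to the CLI call:
    #   convert_dataset 'my-dataset' 'A brief description' 'http://reference.com' 'MIT' '/path/to/your/file.csv'
    convert_dataset(
        Namespace(
            name="my-dataset",
            description="A brief description",
            reference="http://reference.com",
            license="MIT",
            csv_file_path="/path/to/your/file.csv",  # placeholder path
        )
    )

    # Equivalent to: list_datasets -f "bbq" -p "(1,10)"
    list_datasets(Namespace(find="bbq", pagination="(1,10)"))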
moonshot/integrations/cli/redteam/attack_module.py

@@ -5,6 +5,11 @@ from rich.console import Console
 from rich.table import Table
 
 from moonshot.api import api_delete_attack_module, api_get_all_attack_module_metadata
+from moonshot.integrations.cli.cli_errors import (
+    ERROR_RED_TEAMING_LIST_ATTACK_MODULES_FIND_VALIDATION,
+    ERROR_RED_TEAMING_LIST_ATTACK_MODULES_PAGINATION_VALIDATION,
+    ERROR_RED_TEAMING_LIST_ATTACK_MODULES_PAGINATION_VALIDATION_1,
+)
 from moonshot.integrations.cli.utils.process_data import filter_data
 
 console = Console()
@@ -28,6 +33,31 @@ def list_attack_modules(args) -> list | None:
     try:
         print("Listing attack modules may take a while...")
         attack_module_metadata_list = api_get_all_attack_module_metadata()
+
+        if args.find is not None:
+            if not isinstance(args.find, str) or not args.find:
+                raise TypeError(ERROR_RED_TEAMING_LIST_ATTACK_MODULES_FIND_VALIDATION)
+
+        if args.pagination is not None:
+            if not isinstance(args.pagination, str) or not args.pagination:
+                raise TypeError(
+                    ERROR_RED_TEAMING_LIST_ATTACK_MODULES_PAGINATION_VALIDATION
+                )
+            try:
+                pagination = literal_eval(args.pagination)
+                if not (
+                    isinstance(pagination, tuple)
+                    and len(pagination) == 2
+                    and all(isinstance(i, int) for i in pagination)
+                ):
+                    raise ValueError(
+                        ERROR_RED_TEAMING_LIST_ATTACK_MODULES_PAGINATION_VALIDATION_1
+                    )
+            except (ValueError, SyntaxError):
+                raise ValueError(
+                    ERROR_RED_TEAMING_LIST_ATTACK_MODULES_PAGINATION_VALIDATION_1
+                )
+
         keyword = args.find.lower() if args.find else ""
         pagination = literal_eval(args.pagination) if args.pagination else ()
 
@@ -41,7 +71,6 @@ def list_attack_modules(args) -> list | None:
 
         console.print("[red]There are no attack modules found.[/red]")
         return None
-
     except Exception as e:
         print(f"[list_attack_modules]: {str(e)}")
 
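Both listing commands now validate -p/--pagination the same way: the string must literal_eval to a 2-tuple of ints before it is handed to filter_data. A standalone sketch of that check, using a hypothetical helper name (parse_pagination is not part of the package):

    from ast import literal_eval

    def parse_pagination(raw: str) -> tuple:
        """Parse a pagination string such as "(2,10)": page 2, 10 items per page."""
        pagination = literal_eval(raw)  # "(2,10)" -> (2, 10)
        if not (
            isinstance(pagination, tuple)
            and len(pagination) == 2
            and all(isinstance(i, int) for i in pagination)
        ):
            raise ValueError("pagination must be a tuple of two ints, e.g. (2,10)")
        return pagination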
moonshot/integrations/web_api/app.py

@@ -71,7 +71,7 @@ def create_app(cfg: providers.Configuration) -> CustomFastAPI:
     }
 
     app: CustomFastAPI = CustomFastAPI(
-        title="Project Moonshot", version="0.4.8", **app_kwargs
+        title="Project Moonshot", version="0.4.9", **app_kwargs
     )
 
     if cfg.cors.enabled():
moonshot/integrations/web_api/routes/dataset.py

@@ -1,8 +1,8 @@
 from dependency_injector.wiring import Provide, inject
-from fastapi import APIRouter, Depends, HTTPException, Query
+from fastapi import APIRouter, Depends, HTTPException
 
 from ..container import Container
-from ..schemas.dataset_create_dto import DatasetCreateDTO
+from ..schemas.dataset_create_dto import CSV_Dataset_DTO, HF_Dataset_DTO
 from ..schemas.dataset_response_dto import DatasetResponseDTO
 from ..services.dataset_service import DatasetService
 from ..services.utils.exceptions_handler import ServiceException
@@ -10,27 +10,22 @@ from ..services.utils.exceptions_handler import ServiceException
 router = APIRouter(tags=["Datasets"])
 
 
-@router.post("/api/v1/datasets")
+@router.post("/api/v1/datasets/csv")
 @inject
-def create_dataset(
-    dataset_data: DatasetCreateDTO,
-    method: str = Query(
-        ...,
-        description="The method to use for creating the dataset. Supported methods are 'hf' and 'csv'.",
-    ),
+def convert_dataset(
+    dataset_data: CSV_Dataset_DTO,
     dataset_service: DatasetService = Depends(Provide[Container.dataset_service]),
 ) -> str:
     """
-    Create a new dataset using the specified method.
+    Convert a CSV dataset to the desired format.
 
     Args:
-        dataset_data (DatasetCreateDTO): The data required to create the dataset.
-        method (str): The method to use for creating the dataset. Supported methods are "hf" and "csv".
-        dataset_service (DatasetService, optional): The service responsible for creating the dataset.
+        dataset_data (CSV_Dataset_DTO): The data required to convert the dataset.
+        dataset_service (DatasetService, optional): The service responsible for converting the dataset.
            Defaults to Depends(Provide[Container.dataset_service]).
 
     Returns:
-        dict: A message indicating the dataset was created successfully.
+        str: The path to the newly created dataset.
 
     Raises:
         HTTPException: An error with status code 404 if the dataset file is not found.
@@ -38,19 +33,58 @@ def create_dataset(
             An error with status code 500 for any other server-side error.
     """
     try:
-        return dataset_service.create_dataset(dataset_data, method)
+        return dataset_service.convert_dataset(dataset_data)
     except ServiceException as e:
         if e.error_code == "FileNotFound":
             raise HTTPException(
-                status_code=404, detail=f"Failed to retrieve datasets: {e.msg}"
+                status_code=404, detail=f"Failed to convert dataset: {e.msg}"
             )
         elif e.error_code == "ValidationError":
             raise HTTPException(
-                status_code=400, detail=f"Failed to retrieve datasets: {e.msg}"
+                status_code=400, detail=f"Failed to convert dataset: {e.msg}"
             )
         else:
             raise HTTPException(
-                status_code=500, detail=f"Failed to retrieve datasets: {e.msg}"
+                status_code=500, detail=f"Failed to convert dataset: {e.msg}"
+            )
+
+
+@router.post("/api/v1/datasets/hf")
+@inject
+def download_dataset(
+    dataset_data: HF_Dataset_DTO,
+    dataset_service: DatasetService = Depends(Provide[Container.dataset_service]),
+) -> str:
+    """
+    Download a dataset from Hugging Face using the provided dataset data.
+
+    Args:
+        dataset_data (HF_Dataset_DTO): The data required to download the dataset.
+        dataset_service (DatasetService, optional): The service responsible for downloading the dataset.
+            Defaults to Depends(Provide[Container.dataset_service]).
+
+    Returns:
+        str: The path to the newly downloaded dataset.
+
+    Raises:
+        HTTPException: An error with status code 404 if the dataset file is not found.
+            An error with status code 400 if there is a validation error.
+            An error with status code 500 for any other server-side error.
+    """
+    try:
+        return dataset_service.download_dataset(dataset_data)
+    except ServiceException as e:
+        if e.error_code == "FileNotFound":
+            raise HTTPException(
+                status_code=404, detail=f"Failed to download dataset: {e.msg}"
+            )
+        elif e.error_code == "ValidationError":
+            raise HTTPException(
+                status_code=400, detail=f"Failed to download dataset: {e.msg}"
+            )
+        else:
+            raise HTTPException(
+                status_code=500, detail=f"Failed to download dataset: {e.msg}"
             )
 
 
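The old POST /api/v1/datasets?method=... endpoint is thus split into /csv and /hf routes: the method query parameter disappears and each route validates its own DTO. A hedged sketch of calling both routes with requests; the host and port are assumptions, only the paths and JSON fields come from this diff:

    import requests

    BASE = "http://localhost:5000"  # assumed address of a locally running web API

    # CSV conversion (fields of CSV_Dataset_DTO)
    resp = requests.post(
        f"{BASE}/api/v1/datasets/csv",
        json={
            "name": "my-dataset",
            "description": "A brief description",
            "reference": "http://reference.com",
            "license": "MIT",
            "csv_file_path": "/path/to/your/file.csv",  # placeholder path
        },
    )
    print(resp.json())  # path to the newly created dataset

    # Hugging Face download (fields of HF_Dataset_DTO)
    resp = requests.post(
        f"{BASE}/api/v1/datasets/hf",
        json={
            "name": "mmlu-college-biology",
            "description": "MMLU college biology split",
            "reference": "http://reference.com",
            "license": "MIT",
            "params": {
                "dataset_name": "cais/mmlu",
                "dataset_config": "college_biology",
                "split": "test",
                "input_col": ["question", "choices"],
                "target_col": "answer",
            },
        },
    )
    print(resp.json())  # path to the newly downloaded dataset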
moonshot/integrations/web_api/schemas/cookbook_response_model.py

@@ -7,3 +7,5 @@ from moonshot.src.cookbooks.cookbook_arguments import (
 
 class CookbookResponseModel(CookbookPydanticModel):
     total_prompt_in_cookbook: Optional[int] = None
+    total_dataset_in_cookbook: Optional[int] = None
+    endpoint_required: Optional[list[str]] = None
moonshot/integrations/web_api/schemas/dataset_create_dto.py

@@ -8,11 +8,21 @@ from moonshot.src.datasets.dataset_arguments import (
 )
 
 
-class DatasetCreateDTO(DatasetPydanticModel):
-    id: Optional[str] = None
-    examples: Iterator[dict] = None
+class CSV_Dataset_DTO(DatasetPydanticModel):
+    id: Optional[str] = None  # Not required from user
+    examples: Optional[Iterator[dict]] = None  # Not required from user
     name: str = Field(..., min_length=1)
     description: str = Field(default="", min_length=1)
     license: Optional[str] = ""
     reference: Optional[str] = ""
-    params: dict
+    csv_file_path: str = Field(..., min_length=1)
+
+
+class HF_Dataset_DTO(DatasetPydanticModel):
+    id: Optional[str] = None  # Not required from user
+    examples: Optional[Iterator[dict]] = None  # Not required from user
+    name: str = Field(..., min_length=1)
+    description: str = Field(default="", min_length=1)
+    license: Optional[str] = ""
+    reference: Optional[str] = ""
+    params: dict = Field(..., min_length=1)
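
Splitting DatasetCreateDTO into two models makes the required fields explicit: csv_file_path for CSV conversion, params for Hugging Face downloads. A quick sketch of validating payloads against them, assuming the pydantic models behave as declared above:

    from moonshot.integrations.web_api.schemas.dataset_create_dto import (
        CSV_Dataset_DTO,
        HF_Dataset_DTO,
    )

    # Passes: every required CSV field is present
    csv_dto = CSV_Dataset_DTO(
        name="my-dataset",
        description="A brief description",
        csv_file_path="/path/to/your/file.csv",  # placeholder path
    )

    # Fails: params is required for the Hugging Face variant
    try:
        HF_Dataset_DTO(name="my-dataset", description="A brief description")
    except Exception as err:  # pydantic ValidationError
        print(err)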
moonshot/integrations/web_api/schemas/recipe_response_model.py

@@ -5,3 +5,4 @@ from moonshot.src.recipes.recipe_arguments import RecipeArguments as RecipePydan
 
 class RecipeResponseModel(RecipePydanticModel):
     total_prompt_in_recipe: Optional[int] = None
+    endpoint_required: Optional[list[str]] = None