aiverify-moonshot 0.6.1__py3-none-any.whl → 0.6.2__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.4
2
2
  Name: aiverify-moonshot
3
- Version: 0.6.1
3
+ Version: 0.6.2
4
4
  Summary: AI Verify advances Gen AI testing with Project Moonshot.
5
5
  Project-URL: Repository, https://github.com/aiverify-foundation/moonshot
6
6
  Project-URL: Documentation, https://aiverify-foundation.github.io/moonshot/
@@ -23,6 +23,7 @@ Requires-Dist: pandas>=2.2.2
23
23
  Requires-Dist: pydantic==2.8.2
24
24
  Requires-Dist: pyparsing>=3.1.4
25
25
  Requires-Dist: python-dotenv>=1.0.1
26
+ Requires-Dist: python-multipart>=0.0.9
26
27
  Requires-Dist: python-slugify>=8.0.4
27
28
  Requires-Dist: tenacity>=8.5.0
28
29
  Requires-Dist: xxhash>=3.5.0
@@ -47,7 +48,7 @@ Description-Content-Type: text/markdown
47
48
 
48
49
  ![Moonshot Logo](https://github.com/aiverify-foundation/moonshot/raw/main/misc/aiverify-moonshot-logo.png)
49
50
 
50
- **Version 0.6.1**
51
+ **Version 0.6.2**
51
52
 
52
53
  A simple and modular tool to evaluate any LLM application.
53
54
 
@@ -1,5 +1,5 @@
1
1
  moonshot/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
2
- moonshot/__main__.py,sha256=1lpqD3azEnA0wMwIJ0K6cJd5sQFl-v3M8M0ehr-wrAU,11801
2
+ moonshot/__main__.py,sha256=5lD240TY2YfG_p0sQtoHsw9V4DWWI3otpzIYP1hUbC8,12056
3
3
  moonshot/api.py,sha256=wvad-BcKDKEu25c6-YrsBx_uPiLKIBRsbwgThT50Uh0,4877
4
4
  moonshot/integrations/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
5
5
  moonshot/integrations/cli/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
@@ -34,7 +34,7 @@ moonshot/integrations/cli/utils/process_data.py,sha256=QVL5vp2_8ZgGicmCAdeYEHkeb
34
34
  moonshot/integrations/web_api/.env.dev,sha256=0z5_Ut8rF-UqFZtgjkH2qoqORhD5_nSs2w_OeX2SteI,182
35
35
  moonshot/integrations/web_api/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
36
36
  moonshot/integrations/web_api/__main__.py,sha256=MdnLi_ZF-olAAEJwTPU1iGYFYwo-fNWNT2qfchkH3y4,2050
37
- moonshot/integrations/web_api/app.py,sha256=Jr6mYvfjiPKMUWU58QxvYS-bpvkUotd728t6up3ZS-w,3651
37
+ moonshot/integrations/web_api/app.py,sha256=6EkH5LAwEE5iQhGGDl5TIXDv_mmgWlpXCgUM9fH0Uow,3651
38
38
  moonshot/integrations/web_api/container.py,sha256=DVkJG_qm7ItcG6tgMYOqIj07wpKhPWOOfy6-bEv72y4,5915
39
39
  moonshot/integrations/web_api/logging_conf.py,sha256=t3EGRV6tZhV732KXe8_Tiy0fiwVAWxZX5Tt8VTgrrfg,3388
40
40
  moonshot/integrations/web_api/log/.gitkeep,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
@@ -45,7 +45,7 @@ moonshot/integrations/web_api/routes/benchmark_result.py,sha256=WZ_dI8qT4dli9hKP
45
45
  moonshot/integrations/web_api/routes/bookmark.py,sha256=aHUT86Llbzqo1CT3Dy7ciIhxVEzu1YgZk_VkxVeOZ3s,6304
46
46
  moonshot/integrations/web_api/routes/context_strategy.py,sha256=kJTpjrwxfYGyBLY_hAgpHOMZMtjV5Z6vpu7RIdHDylg,4828
47
47
  moonshot/integrations/web_api/routes/cookbook.py,sha256=oddmcdfhgH3qZb4_ThfUk8SBKmHOt51dFlAHubQh2fQ,8648
48
- moonshot/integrations/web_api/routes/dataset.py,sha256=9o8K81xn-fKappgPU4pcRSnLUl2J79MyR5PZijY5ueI,7312
48
+ moonshot/integrations/web_api/routes/dataset.py,sha256=YhX4xNEgI5KwbmQ-SsLc5SpD3-9EZ6-e2AxBh1QJnUc,8238
49
49
  moonshot/integrations/web_api/routes/endpoint.py,sha256=ZFx0WUe3-GGdmBz_hzYiOmjvJHN4PQy_8lCKJMBjxcE,10715
50
50
  moonshot/integrations/web_api/routes/metric.py,sha256=f_HHexxKUfqFE5FkeCwRh8n36H2mREtLnK2pDrw3A-w,2856
51
51
  moonshot/integrations/web_api/routes/prompt_template.py,sha256=M3adeNeWvLQJJlFQ0uZqSXEuNpTcagApnuqWvLiL1mg,4890
@@ -57,7 +57,7 @@ moonshot/integrations/web_api/schemas/benchmark_runner_dto.py,sha256=IIn6KeMcwxT
57
57
  moonshot/integrations/web_api/schemas/bookmark_create_dto.py,sha256=C78vG8UG02N7Cmt6RSuS8e4sX_G-MLCiAWT-cF5BE8s,374
58
58
  moonshot/integrations/web_api/schemas/cookbook_create_dto.py,sha256=wXC0tu1Q8SpSI3Qk0xKPj1vKsOJEYmfPgU4rl6QopUY,826
59
59
  moonshot/integrations/web_api/schemas/cookbook_response_model.py,sha256=COLvaE4Hrz_w-C_HQkB7feztweIr0wkY9h8N6NKNIr8,332
60
- moonshot/integrations/web_api/schemas/dataset_create_dto.py,sha256=GRqIIlQZEpzzEXwAFcbDlxOuKg0JZ399axBjg34LMp8,915
60
+ moonshot/integrations/web_api/schemas/dataset_create_dto.py,sha256=923Etaq2J1-S9-Xvh4tcrMIHpRrJ381tA8tj0M2kj1A,907
61
61
  moonshot/integrations/web_api/schemas/dataset_response_dto.py,sha256=s5x4-UXEWccWhK42E0FPXiHG6VqjuFuph-2t5atEkg4,171
62
62
  moonshot/integrations/web_api/schemas/endpoint_create_dto.py,sha256=WS8AfRybrweoOgZx6K6jiNy1Z6J3IZS1PUNnrRxGKyM,678
63
63
  moonshot/integrations/web_api/schemas/endpoint_response_model.py,sha256=OmmM2uaPSgB2aqPFfkhseKkI5OKCKilXR19gDmwFlLc,321
@@ -80,7 +80,7 @@ moonshot/integrations/web_api/services/benchmarking_service.py,sha256=lJZeNTqxEP
80
80
  moonshot/integrations/web_api/services/bookmark_service.py,sha256=jI9nXs1hjzO0CLG2LKaXSzDApLThkfCvPUkaNNV9A5A,3546
81
81
  moonshot/integrations/web_api/services/context_strategy_service.py,sha256=6YKnnG8JlE_1nlnr4Hq7rgz-sxI6oQglK0STaWPFQxQ,710
82
82
  moonshot/integrations/web_api/services/cookbook_service.py,sha256=37iJZn4ybe9tugBWB99g1SAN1YUtkmaq2mLQWj_HBQo,8736
83
- moonshot/integrations/web_api/services/dataset_service.py,sha256=ZWb3FqyDkA0C9qhlQ3X_zR0ohAlwlLsJi-mgKLvXpnI,2407
83
+ moonshot/integrations/web_api/services/dataset_service.py,sha256=FUXLgU32nghoLWWXBA_4GzeQb8eK31tjbvLu4OJBxoc,2441
84
84
  moonshot/integrations/web_api/services/endpoint_service.py,sha256=N5SXNAh44UNeBpMhA9baL0VZoTx4sHzpy4y7-Ch8O4E,2395
85
85
  moonshot/integrations/web_api/services/metric_service.py,sha256=xWC5Dk8aiU7tuHsxYedTTrEkbA3Ug1pV2nbaBas6cAg,456
86
86
  moonshot/integrations/web_api/services/prompt_template_service.py,sha256=5ds7pKDB2R0_0slVDwsCRIpIVdsgpqhI-3wQqSYcpuE,1226
@@ -103,7 +103,7 @@ moonshot/src/api/api_connector.py,sha256=Q_of-aHPuWkbefMJq4uXctJl89G2Tt6J_HfSuf1
103
103
  moonshot/src/api/api_connector_endpoint.py,sha256=lwfhlWNBJ6QotqffmURtjRmxfzbBlSIAZupeSpMt9VU,5584
104
104
  moonshot/src/api/api_context_strategy.py,sha256=uRIfNjKJ_Wk9nSrvbPRfrdQLpG0K6kH9rl5tmmHui40,2151
105
105
  moonshot/src/api/api_cookbook.py,sha256=V05abHvzElrO7LkSyhOMcAHEfCfIgopd6L0cSSO3Dro,6722
106
- moonshot/src/api/api_dataset.py,sha256=i2KwnZ-6fTm_tyn8cRw8iesrGi7_Nh0-1bFuN7m0TVo,4066
106
+ moonshot/src/api/api_dataset.py,sha256=POpkrmo_vLeUPOkuEShS8eadXdRluzbvOlUQ1DhsNqM,5521
107
107
  moonshot/src/api/api_environment_variables.py,sha256=wRx6rm95ItyL_uKUAYfSjcPZNbRxKl1GGS4PpWcTE1s,712
108
108
  moonshot/src/api/api_metrics.py,sha256=x5DiysTYQsMmcAS2y2XpgvrPobZk7GT2rhO-MaIRun4,1603
109
109
  moonshot/src/api/api_prompt_template.py,sha256=HQUl7-HGcxA620cY0vDqqo7CoY9uONkXIOlolduIgbE,1959
@@ -128,7 +128,7 @@ moonshot/src/cookbooks/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3
128
128
  moonshot/src/cookbooks/cookbook.py,sha256=DdZwRGx5-xTDIKcXtZRpp7Qb9Mm9dNGwXWLQXoQrBBo,10412
129
129
  moonshot/src/cookbooks/cookbook_arguments.py,sha256=SmNG8D5qN2K2dcImDaSBPHsna0Gy60ZR49_eTKEsvVU,1445
130
130
  moonshot/src/datasets/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
131
- moonshot/src/datasets/dataset.py,sha256=-_uhjR7zi50nkLu1WWlPCCWr14VwFUDfhTeeBHOhb70,14236
131
+ moonshot/src/datasets/dataset.py,sha256=sAPPdUJTB_aAFw3lmp2CY9gqbF2Mu8qQ9Soc1syMNGg,14940
132
132
  moonshot/src/datasets/dataset_arguments.py,sha256=rUcxxo2WTcHhLLV-WoixjOfT_Ju7hFCq811_ctjegt8,1751
133
133
  moonshot/src/metrics/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
134
134
  moonshot/src/metrics/metric.py,sha256=llqJYnwtllJRMfNhRRbKWjhzKymY961yR3Jw24COR-Y,7512
@@ -172,9 +172,9 @@ moonshot/src/utils/import_modules.py,sha256=T9zTN59PFnvY2rjyWhSV9KSIAHxWV1pyBemF
172
172
  moonshot/src/utils/log.py,sha256=YNgD7Eh2OT36XlmVBKCGUTAh9TRp4Akfe4kDdvHASgs,2502
173
173
  moonshot/src/utils/pagination.py,sha256=5seymyRoqyENIhKllAatr1T91kMCGFslcvRnJHyMSvc,814
174
174
  moonshot/src/utils/timeit.py,sha256=TvuF0w8KWhp0oZFY0cUU3UY0xlGKjchb0OkfYfgVTlc,866
175
- aiverify_moonshot-0.6.1.dist-info/METADATA,sha256=Um1dy4p7R1ZqYm9X_wnmzsVi2qclr6trbA11ijKYiRs,12419
176
- aiverify_moonshot-0.6.1.dist-info/WHEEL,sha256=qtCwoSJWgHk21S1Kb4ihdzI2rlJ1ZKaIurTj_ngOhyQ,87
177
- aiverify_moonshot-0.6.1.dist-info/licenses/AUTHORS.md,sha256=mmAbe3i3sT8JZHJMBhxp3i1xRehV0g7WB4T_eyIBuBs,59
178
- aiverify_moonshot-0.6.1.dist-info/licenses/LICENSE.md,sha256=53izDRmJZZCjpYGfyLqlxnGQN-aNWBxasuzuMXC5Ias,11347
179
- aiverify_moonshot-0.6.1.dist-info/licenses/NOTICES.md,sha256=vS1zZYAnGjCJdwQ13xv3b2zc30wOS98ZnCKluT-AhHs,123266
180
- aiverify_moonshot-0.6.1.dist-info/RECORD,,
175
+ aiverify_moonshot-0.6.2.dist-info/METADATA,sha256=nijW6md0PES6huZ6YbdfwQkiD9BN6gH80TWXabbamFs,12458
176
+ aiverify_moonshot-0.6.2.dist-info/WHEEL,sha256=qtCwoSJWgHk21S1Kb4ihdzI2rlJ1ZKaIurTj_ngOhyQ,87
177
+ aiverify_moonshot-0.6.2.dist-info/licenses/AUTHORS.md,sha256=mmAbe3i3sT8JZHJMBhxp3i1xRehV0g7WB4T_eyIBuBs,59
178
+ aiverify_moonshot-0.6.2.dist-info/licenses/LICENSE.md,sha256=53izDRmJZZCjpYGfyLqlxnGQN-aNWBxasuzuMXC5Ias,11347
179
+ aiverify_moonshot-0.6.2.dist-info/licenses/NOTICES.md,sha256=vS1zZYAnGjCJdwQ13xv3b2zc30wOS98ZnCKluT-AhHs,123266
180
+ aiverify_moonshot-0.6.2.dist-info/RECORD,,
moonshot/__main__.py CHANGED
@@ -116,6 +116,13 @@ def download_nltk_resources() -> None:
116
116
  raise
117
117
 
118
118
 
119
+ def download_spacy_model() -> None:
120
+ """
121
+ Downloads the en_core_web_lg model using the spacy module (for entity processor module).
122
+ """
123
+ subprocess.run(["python", "-m", "spacy", "download", "en_core_web_lg"])
124
+
125
+
119
126
  def moonshot_data_installation(unattended: bool, overwrite: bool) -> None:
120
127
  """
121
128
  Install Moonshot Data from GitHub.
@@ -175,10 +182,12 @@ def moonshot_data_installation(unattended: bool, overwrite: bool) -> None:
175
182
  if os.path.exists("requirements.txt"):
176
183
  run_subprocess(["pip", "install", "-r", "requirements.txt"], check=True)
177
184
  download_nltk_resources()
185
+ download_spacy_model()
178
186
 
179
187
  # Change back to the base directory
180
188
  os.chdir("..")
181
189
 
190
+
182
191
  def check_node() -> bool:
183
192
  """
184
193
  Check if Node.js is installed on the user's machine.
@@ -71,7 +71,7 @@ def create_app(cfg: providers.Configuration) -> CustomFastAPI:
71
71
  }
72
72
 
73
73
  app: CustomFastAPI = CustomFastAPI(
74
- title="Project Moonshot", version="0.6.1", **app_kwargs
74
+ title="Project Moonshot", version="0.6.2", **app_kwargs
75
75
  )
76
76
 
77
77
  if cfg.cors.enabled():
@@ -1,5 +1,8 @@
1
+ import os
2
+ import tempfile
3
+
1
4
  from dependency_injector.wiring import Provide, inject
2
- from fastapi import APIRouter, Depends, HTTPException
5
+ from fastapi import APIRouter, Depends, File, Form, HTTPException, UploadFile
3
6
 
4
7
  from ..container import Container
5
8
  from ..schemas.dataset_create_dto import CSV_Dataset_DTO, HF_Dataset_DTO
@@ -10,10 +13,14 @@ from ..services.utils.exceptions_handler import ServiceException
10
13
  router = APIRouter(tags=["Datasets"])
11
14
 
12
15
 
13
- @router.post("/api/v1/datasets/csv")
16
+ @router.post("/api/v1/datasets/file")
14
17
  @inject
15
- def convert_dataset(
16
- dataset_data: CSV_Dataset_DTO,
18
+ async def upload_dataset(
19
+ file: UploadFile = File(...),
20
+ name: str = Form(..., min_length=1),
21
+ description: str = Form(default="", min_length=1),
22
+ license: str = Form(default=""),
23
+ reference: str = Form(default=""),
17
24
  dataset_service: DatasetService = Depends(Provide[Container.dataset_service]),
18
25
  ) -> str:
19
26
  """
@@ -32,7 +39,24 @@ def convert_dataset(
32
39
  An error with status code 400 if there is a validation error.
33
40
  An error with status code 500 for any other server-side error.
34
41
  """
42
+
43
+ # Create a temporary file with a secure random name
44
+ with tempfile.NamedTemporaryFile(
45
+ delete=False, suffix=os.path.splitext(file.filename)[1]
46
+ ) as tmp_file:
47
+ content = await file.read()
48
+ tmp_file.write(content)
49
+ temp_file_path = tmp_file.name
50
+
35
51
  try:
52
+ # Create the DTO with the form data including optional fields
53
+ dataset_data = CSV_Dataset_DTO(
54
+ name=name,
55
+ description=description,
56
+ license=license,
57
+ reference=reference,
58
+ file_path=temp_file_path,
59
+ )
36
60
  return dataset_service.convert_dataset(dataset_data)
37
61
  except ServiceException as e:
38
62
  if e.error_code == "FileNotFound":
@@ -47,6 +71,10 @@ def convert_dataset(
47
71
  raise HTTPException(
48
72
  status_code=500, detail=f"Failed to convert dataset: {e.msg}"
49
73
  )
74
+ finally:
75
+ # Clean up the temporary file
76
+ if os.path.exists(temp_file_path):
77
+ os.unlink(temp_file_path)
50
78
 
51
79
 
52
80
  @router.post("/api/v1/datasets/hf")
@@ -8,13 +8,13 @@ from moonshot.src.datasets.dataset_arguments import (
8
8
 
9
9
 
10
10
  class CSV_Dataset_DTO(DatasetPydanticModel):
11
- id: Optional[str] = None # Not a required from user
12
- examples: Optional[Any] = None # Not a required from user
11
+ id: Optional[str] = None # Not required from user
12
+ examples: Optional[Any] = None # Not required from user
13
13
  name: str = Field(..., min_length=1)
14
14
  description: str = Field(default="", min_length=1)
15
15
  license: Optional[str] = ""
16
16
  reference: Optional[str] = ""
17
- csv_file_path: str = Field(..., min_length=1)
17
+ file_path: str = Field(..., min_length=1)
18
18
 
19
19
 
20
20
  class HF_Dataset_DTO(DatasetPydanticModel):
@@ -4,6 +4,7 @@ from ..schemas.dataset_response_dto import DatasetResponseDTO
4
4
  from ..services.base_service import BaseService
5
5
  from ..services.utils.exceptions_handler import exception_handler
6
6
  from .utils.file_manager import copy_file
7
+ import os
7
8
 
8
9
 
9
10
  class DatasetService(BaseService):
@@ -16,7 +17,7 @@ class DatasetService(BaseService):
16
17
  dataset_data (CSV_Dataset_DTO): The data required to convert the dataset.
17
18
 
18
19
  Returns:
19
- str: The path to the newly created dataset.
20
+ str: The filename of the newly created dataset.
20
21
 
21
22
  Raises:
22
23
  Exception: If an error occurs during dataset conversion.
@@ -27,9 +28,9 @@ class DatasetService(BaseService):
27
28
  description=dataset_data.description,
28
29
  reference=dataset_data.reference,
29
30
  license=dataset_data.license,
30
- csv_file_path=dataset_data.csv_file_path,
31
+ file_path=dataset_data.file_path,
31
32
  )
32
- return copy_file(new_ds_path)
33
+ return os.path.splitext(os.path.basename(new_ds_path))[0]
33
34
 
34
35
  @exception_handler
35
36
  def download_dataset(self, dataset_data: HF_Dataset_DTO) -> str:
@@ -1,3 +1,6 @@
1
+ import json
2
+ import os
3
+
1
4
  from pydantic import validate_call
2
5
 
3
6
  from moonshot.src.datasets.dataset import Dataset
@@ -81,10 +84,10 @@ def api_download_dataset(
81
84
 
82
85
 
83
86
  def api_convert_dataset(
84
- name: str, description: str, reference: str, license: str, csv_file_path: str
87
+ name: str, description: str, reference: str, license: str, file_path: str
85
88
  ) -> str:
86
89
  """
87
- Converts a CSV file to a dataset and creates a new dataset with the provided details.
90
+ Converts a CSV or JSON file to a dataset and creates a new dataset with the provided details.
88
91
 
89
92
  This function takes the name, description, reference, and license for a new dataset as input, along with the file
90
93
  path to a CSV file. It then creates a new DatasetArguments object with these details and an empty id. The id is left
@@ -96,18 +99,55 @@ def api_convert_dataset(
96
99
  description (str): A brief description of the new dataset.
97
100
  reference (str): A reference link for the new dataset.
98
101
  license (str): The license of the new dataset.
99
- csv_file_path (str): The file path to the CSV file.
102
+ file_path (str): The file path to the CSV or JSONfile.
100
103
 
101
104
  Returns:
102
105
  str: The ID of the newly created dataset.
103
106
  """
104
- examples = Dataset.convert_data(csv_file_path)
105
- ds_args = DatasetArguments(
106
- id="",
107
- name=name,
108
- description=description,
109
- reference=reference,
110
- license=license,
111
- examples=examples,
112
- )
107
+ ds_args = None
108
+
109
+ # Check if file is in a supported format
110
+ if not (file_path.endswith(".json") or file_path.endswith(".csv")):
111
+ raise ValueError("Unsupported file format. Please provide a JSON or CSV file.")
112
+
113
+ # Check that file is not empty
114
+ if os.path.getsize(file_path) == 0:
115
+ raise ValueError("The uploaded file is empty.")
116
+
117
+ # if file is already in json format
118
+ if file_path.endswith(".json"):
119
+ json_data = json.load(open(file_path))
120
+
121
+ try:
122
+ if "examples" in json_data and json_data["examples"]:
123
+ ds_args = DatasetArguments(
124
+ id="",
125
+ name=json_data.get("name", name),
126
+ description=json_data.get("description", description),
127
+ reference=json_data.get("reference", reference),
128
+ license=json_data.get("license", license),
129
+ examples=iter(json_data["examples"]),
130
+ )
131
+ else:
132
+ raise KeyError(
133
+ "examples is either empty or this key is not in the JSON file. "
134
+ "Please ensure that this field is present."
135
+ )
136
+ except Exception as e:
137
+ raise e
138
+
139
+ # if file is in csv format, convert data
140
+ else:
141
+ try:
142
+ examples = Dataset.convert_data(file_path)
143
+ ds_args = DatasetArguments(
144
+ id="",
145
+ name=name,
146
+ description=description,
147
+ reference=reference,
148
+ license=license,
149
+ examples=examples,
150
+ )
151
+ except Exception as e:
152
+ raise e
113
153
  return Dataset.create(ds_args)
@@ -60,7 +60,6 @@ class Dataset:
60
60
  }
61
61
 
62
62
  examples = ds_args.examples
63
-
64
63
  # Write as JSON output
65
64
  file_path = Storage.create_object_with_iterator(
66
65
  EnvVariables.DATASETS.name,
@@ -91,9 +90,26 @@ class Dataset:
91
90
  Returns:
92
91
  Iterator[dict]: An iterator of dictionaries representing the CSV data.
93
92
  """
93
+ # validate headers
94
+ df_header = pd.read_csv(csv_file_path, nrows=1)
95
+ headers = df_header.columns.tolist()
96
+ required_headers = ["input", "target"]
97
+ if not all(header in headers for header in required_headers):
98
+ raise KeyError(
99
+ f"Required headers not found in the dataset. Required headers are {required_headers}."
100
+ )
101
+
94
102
  df = pd.read_csv(csv_file_path, chunksize=1)
95
- for chunk in df:
96
- yield chunk.to_dict("records")[0]
103
+ # validate dataset
104
+ first_chunk = next(df, None)
105
+ if first_chunk is None or first_chunk.empty:
106
+ raise ValueError("The uploaded file does not contain any data.")
107
+
108
+ # Reset df after performing next(df)
109
+ df = pd.read_csv(csv_file_path, chunksize=1)
110
+
111
+ result = [chunk.to_dict("records")[0] for chunk in df]
112
+ return iter(result)
97
113
 
98
114
  @staticmethod
99
115
  @validate_call