rapidata 2.29.1__py3-none-any.whl → 2.31.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of rapidata might be problematic. (The original page linked to further details here; the link was not preserved.)

Files changed (36):
  1. rapidata/__init__.py +1 -1
  2. rapidata/api_client/__init__.py +5 -0
  3. rapidata/api_client/api/benchmark_api.py +550 -0
  4. rapidata/api_client/api/dataset_api.py +14 -14
  5. rapidata/api_client/api/leaderboard_api.py +562 -0
  6. rapidata/api_client/api/validation_set_api.py +349 -6
  7. rapidata/api_client/models/__init__.py +5 -0
  8. rapidata/api_client/models/file_type.py +1 -0
  9. rapidata/api_client/models/file_type_metadata.py +2 -2
  10. rapidata/api_client/models/file_type_metadata_model.py +2 -2
  11. rapidata/api_client/models/get_standing_by_id_result.py +4 -2
  12. rapidata/api_client/models/participant_by_benchmark.py +2 -2
  13. rapidata/api_client/models/participant_status.py +1 -0
  14. rapidata/api_client/models/prompt_by_benchmark_result.py +19 -3
  15. rapidata/api_client/models/run_status.py +39 -0
  16. rapidata/api_client/models/runs_by_leaderboard_result.py +110 -0
  17. rapidata/api_client/models/runs_by_leaderboard_result_paged_result.py +105 -0
  18. rapidata/api_client/models/standing_by_leaderboard.py +5 -3
  19. rapidata/api_client/models/update_benchmark_name_model.py +87 -0
  20. rapidata/api_client/models/update_leaderboard_name_model.py +87 -0
  21. rapidata/api_client_README.md +10 -0
  22. rapidata/rapidata_client/benchmark/leaderboard/rapidata_leaderboard.py +9 -0
  23. rapidata/rapidata_client/benchmark/rapidata_benchmark.py +66 -12
  24. rapidata/rapidata_client/benchmark/rapidata_benchmark_manager.py +24 -6
  25. rapidata/rapidata_client/filter/__init__.py +1 -0
  26. rapidata/rapidata_client/filter/_base_filter.py +20 -0
  27. rapidata/rapidata_client/filter/and_filter.py +30 -0
  28. rapidata/rapidata_client/filter/rapidata_filters.py +6 -3
  29. rapidata/rapidata_client/order/_rapidata_order_builder.py +13 -9
  30. rapidata/rapidata_client/order/rapidata_order_manager.py +2 -13
  31. rapidata/rapidata_client/validation/rapids/rapids.py +29 -47
  32. rapidata/rapidata_client/validation/validation_set_manager.py +10 -3
  33. {rapidata-2.29.1.dist-info → rapidata-2.31.0.dist-info}/METADATA +1 -1
  34. {rapidata-2.29.1.dist-info → rapidata-2.31.0.dist-info}/RECORD +36 -30
  35. {rapidata-2.29.1.dist-info → rapidata-2.31.0.dist-info}/LICENSE +0 -0
  36. {rapidata-2.29.1.dist-info → rapidata-2.31.0.dist-info}/WHEEL +0 -0
@@ -1,3 +1,4 @@
1
+ import re
1
2
  from rapidata.api_client.models.root_filter import RootFilter
2
3
  from rapidata.api_client.models.filter import Filter
3
4
  from rapidata.api_client.models.query_model import QueryModel
@@ -5,6 +6,10 @@ from rapidata.api_client.models.page_info import PageInfo
5
6
  from rapidata.api_client.models.create_leaderboard_model import CreateLeaderboardModel
6
7
  from rapidata.api_client.models.create_benchmark_participant_model import CreateBenchmarkParticipantModel
7
8
  from rapidata.api_client.models.submit_prompt_model import SubmitPromptModel
9
+ from rapidata.api_client.models.submit_prompt_model_prompt_asset import SubmitPromptModelPromptAsset
10
+ from rapidata.api_client.models.url_asset_input import UrlAssetInput
11
+ from rapidata.api_client.models.file_asset_model import FileAssetModel
12
+ from rapidata.api_client.models.source_url_metadata_model import SourceUrlMetadataModel
8
13
 
9
14
  from rapidata.rapidata_client.logging import logger
10
15
  from rapidata.service.openapi_service import OpenAPIService
@@ -29,7 +34,8 @@ class RapidataBenchmark:
29
34
  self.name = name
30
35
  self.id = id
31
36
  self.__openapi_service = openapi_service
32
- self.__prompts: list[str] = []
37
+ self.__prompts: list[str | None] = []
38
+ self.__prompt_assets: list[str | None] = []
33
39
  self.__leaderboards: list[RapidataLeaderboard] = []
34
40
  self.__identifiers: list[str] = []
35
41
 
@@ -53,8 +59,16 @@ class RapidataBenchmark:
53
59
 
54
60
  total_pages = prompts_result.total_pages
55
61
 
56
- self.__prompts.extend([prompt.prompt for prompt in prompts_result.items])
57
- self.__identifiers.extend([prompt.identifier for prompt in prompts_result.items])
62
+ for prompt in prompts_result.items:
63
+ self.__prompts.append(prompt.prompt)
64
+ self.__identifiers.append(prompt.identifier)
65
+ if prompt.prompt_asset is None:
66
+ self.__prompt_assets.append(None)
67
+ else:
68
+ assert isinstance(prompt.prompt_asset.actual_instance, FileAssetModel)
69
+ source_url = prompt.prompt_asset.actual_instance.metadata["sourceUrl"].actual_instance
70
+ assert isinstance(source_url, SourceUrlMetadataModel)
71
+ self.__prompt_assets.append(source_url.url)
58
72
 
59
73
  if current_page >= total_pages:
60
74
  break
@@ -62,7 +76,14 @@ class RapidataBenchmark:
62
76
  current_page += 1
63
77
 
64
78
  @property
65
- def prompts(self) -> list[str]:
79
+ def identifiers(self) -> list[str]:
80
+ if not self.__identifiers:
81
+ self.__instantiate_prompts()
82
+
83
+ return self.__identifiers
84
+
85
+ @property
86
+ def prompts(self) -> list[str | None]:
66
87
  """
67
88
  Returns the prompts that are registered for the leaderboard.
68
89
  """
@@ -72,11 +93,14 @@ class RapidataBenchmark:
72
93
  return self.__prompts
73
94
 
74
95
  @property
75
- def identifiers(self) -> list[str]:
76
- if not self.__identifiers:
96
+ def prompt_assets(self) -> list[str | None]:
97
+ """
98
+ Returns the prompt assets that are registered for the benchmark.
99
+ """
100
+ if not self.__prompt_assets:
77
101
  self.__instantiate_prompts()
78
102
 
79
- return self.__identifiers
103
+ return self.__prompt_assets
80
104
 
81
105
  @property
82
106
  def leaderboards(self) -> list[RapidataLeaderboard]:
@@ -112,6 +136,7 @@ class RapidataBenchmark:
112
136
  leaderboard.name,
113
137
  leaderboard.instruction,
114
138
  leaderboard.show_prompt,
139
+ leaderboard.show_prompt_asset,
115
140
  leaderboard.is_inversed,
116
141
  leaderboard.min_responses,
117
142
  leaderboard.response_budget,
@@ -126,24 +151,49 @@ class RapidataBenchmark:
126
151
 
127
152
  return self.__leaderboards
128
153
 
129
- def add_prompt(self, identifier: str, prompt: str):
154
+ def add_prompt(self, identifier: str, prompt: str | None = None, asset: str | None = None):
130
155
  """
131
156
  Adds a prompt to the benchmark.
157
+
158
+ Args:
159
+ identifier: The identifier of the prompt/asset that will be used to match up the media.
160
+ prompt: The prompt that will be used to evaluate the model.
161
+ asset: The asset that will be used to evaluate the model. Provided as a link to the asset.
132
162
  """
133
- if not isinstance(identifier, str) or not isinstance(prompt, str):
134
- raise ValueError("Identifier and prompt must be strings.")
163
+ if not isinstance(identifier, str):
164
+ raise ValueError("Identifier must be a string.")
165
+
166
+ if prompt is None and asset is None:
167
+ raise ValueError("Prompt or asset must be provided.")
168
+
169
+ if prompt is not None and not isinstance(prompt, str):
170
+ raise ValueError("Prompt must be a string.")
171
+
172
+ if asset is not None and not isinstance(asset, str):
173
+ raise ValueError("Asset must be a string. That is the link to the asset.")
135
174
 
136
175
  if identifier in self.identifiers:
137
176
  raise ValueError("Identifier already exists in the benchmark.")
138
177
 
178
+ if asset is not None and not re.match(r'^https?://', asset):
179
+ raise ValueError("Asset must be a link to the asset.")
180
+
139
181
  self.__identifiers.append(identifier)
182
+
140
183
  self.__prompts.append(prompt)
184
+ self.__prompt_assets.append(asset)
141
185
 
142
186
  self.__openapi_service.benchmark_api.benchmark_benchmark_id_prompt_post(
143
187
  benchmark_id=self.id,
144
188
  submit_prompt_model=SubmitPromptModel(
145
189
  identifier=identifier,
146
190
  prompt=prompt,
191
+ promptAsset=SubmitPromptModelPromptAsset(
192
+ UrlAssetInput(
193
+ _t="UrlAssetInput",
194
+ url=asset
195
+ )
196
+ ) if asset is not None else None
147
197
  )
148
198
  )
149
199
 
@@ -151,7 +201,8 @@ class RapidataBenchmark:
151
201
  self,
152
202
  name: str,
153
203
  instruction: str,
154
- show_prompt: bool,
204
+ show_prompt: bool = False,
205
+ show_prompt_asset: bool = False,
155
206
  inverse_ranking: bool = False,
156
207
  min_responses: int | None = None,
157
208
  response_budget: int | None = None
@@ -162,7 +213,8 @@ class RapidataBenchmark:
162
213
  Args:
163
214
  name: The name of the leaderboard. (not shown to the users)
164
215
  instruction: The instruction decides how the models will be evaluated.
165
- show_prompt: Whether to show the prompt to the users.
216
+ show_prompt: Whether to show the prompt to the users. (default: False)
217
+ show_prompt_asset: Whether to show the prompt asset to the users. (only works if the prompt asset is a URL) (default: False)
166
218
  inverse_ranking: Whether to inverse the ranking of the leaderboard. (if the question is inversed, e.g. "Which video is worse?")
167
219
  min_responses: The minimum amount of responses that get collected per comparison. if None, it will be defaulted.
168
220
  response_budget: The total amount of responses that get collected per new model evaluation. if None, it will be defaulted. Values below 2000 are not recommended.
@@ -177,6 +229,7 @@ class RapidataBenchmark:
177
229
  name=name,
178
230
  instruction=instruction,
179
231
  showPrompt=show_prompt,
232
+ showPromptAsset=show_prompt_asset,
180
233
  isInversed=inverse_ranking,
181
234
  minResponses=min_responses,
182
235
  responseBudget=response_budget
@@ -189,6 +242,7 @@ class RapidataBenchmark:
189
242
  name,
190
243
  instruction,
191
244
  show_prompt,
245
+ show_prompt_asset,
192
246
  inverse_ranking,
193
247
  leaderboard_result.min_responses,
194
248
  leaderboard_result.response_budget,
@@ -1,3 +1,4 @@
1
+ from typing import Optional
1
2
  from rapidata.rapidata_client.benchmark.rapidata_benchmark import RapidataBenchmark
2
3
  from rapidata.api_client.models.create_benchmark_model import CreateBenchmarkModel
3
4
  from rapidata.service.openapi_service import OpenAPIService
@@ -24,27 +25,40 @@ class RapidataBenchmarkManager:
24
25
  def create_new_benchmark(self,
25
26
  name: str,
26
27
  identifiers: list[str],
27
- prompts: list[str],
28
+ prompts: Optional[list[str]] = None,
29
+ prompt_assets: Optional[list[str]] = None,
28
30
  ) -> RapidataBenchmark:
29
31
  """
30
- Creates a new benchmark with the given name, prompts, and leaderboards.
32
+ Creates a new benchmark with the given name, identifiers, prompts, and media assets.
33
+
34
+ prompts or prompt_assets must be provided.
31
35
 
32
36
  Args:
33
37
  name: The name of the benchmark.
34
38
  prompts: The prompts that will be registered for the benchmark.
39
+ prompt_assets: The prompt assets that will be registered for the benchmark.
35
40
  """
36
41
  if not isinstance(name, str):
37
42
  raise ValueError("Name must be a string.")
38
43
 
39
- if not isinstance(prompts, list) or not all(isinstance(prompt, str) for prompt in prompts):
44
+ if prompts and (not isinstance(prompts, list) or not all(isinstance(prompt, str) for prompt in prompts)):
40
45
  raise ValueError("Prompts must be a list of strings.")
41
46
 
47
+ if prompt_assets and (not isinstance(prompt_assets, list) or not all(isinstance(asset, str) for asset in prompt_assets)):
48
+ raise ValueError("Media assets must be a list of strings.")
49
+
42
50
  if not isinstance(identifiers, list) or not all(isinstance(identifier, str) for identifier in identifiers):
43
51
  raise ValueError("Identifiers must be a list of strings.")
44
52
 
45
- if len(identifiers) != len(prompts):
53
+ if prompts and len(identifiers) != len(prompts):
46
54
  raise ValueError("Identifiers and prompts must have the same length.")
47
55
 
56
+ if prompt_assets and len(identifiers) != len(prompt_assets):
57
+ raise ValueError("Identifiers and media assets must have the same length.")
58
+
59
+ if not prompts and not prompt_assets:
60
+ raise ValueError("At least one of prompts or media assets must be provided.")
61
+
48
62
  if len(set(identifiers)) != len(identifiers):
49
63
  raise ValueError("Identifiers must be unique.")
50
64
 
@@ -55,8 +69,12 @@ class RapidataBenchmarkManager:
55
69
  )
56
70
 
57
71
  benchmark = RapidataBenchmark(name, benchmark_result.id, self.__openapi_service)
58
- for identifier, prompt in zip(identifiers, prompts):
59
- benchmark.add_prompt(identifier, prompt)
72
+
73
+ prompts_list = prompts if prompts is not None else [None] * len(identifiers)
74
+ media_assets_list = prompt_assets if prompt_assets is not None else [None] * len(identifiers)
75
+
76
+ for identifier, prompt, asset in zip(identifiers, prompts_list, media_assets_list):
77
+ benchmark.add_prompt(identifier, prompt, asset)
60
78
 
61
79
  return benchmark
62
80
 
@@ -8,5 +8,6 @@ from .user_score_filter import UserScoreFilter
8
8
  from .custom_filter import CustomFilter
9
9
  from .not_filter import NotFilter
10
10
  from .or_filter import OrFilter
11
+ from .and_filter import AndFilter
11
12
  from .response_count_filter import ResponseCountFilter
12
13
  from .new_user_filter import NewUserFilter
@@ -29,6 +29,26 @@ class RapidataFilter:
29
29
  else:
30
30
  return OrFilter([self, other])
31
31
 
32
+ def __and__(self, other):
33
+ """Enable the & operator to create AndFilter combinations."""
34
+ if not isinstance(other, RapidataFilter):
35
+ return NotImplemented
36
+
37
+ from rapidata.rapidata_client.filter.and_filter import AndFilter
38
+
39
+ # If self is already an AndFilter, extend its filters list
40
+ if isinstance(self, AndFilter):
41
+ if isinstance(other, AndFilter):
42
+ return AndFilter(self.filters + other.filters)
43
+ else:
44
+ return AndFilter(self.filters + [other])
45
+ # If other is an AndFilter, prepend self to its filters
46
+ elif isinstance(other, AndFilter):
47
+ return AndFilter([self] + other.filters)
48
+ # Neither is an AndFilter, create a new one
49
+ else:
50
+ return AndFilter([self, other])
51
+
32
52
  def __invert__(self):
33
53
  """Enable the ~ operator to create NotFilter negations."""
34
54
  from rapidata.rapidata_client.filter.not_filter import NotFilter
@@ -0,0 +1,30 @@
1
+ from typing import Any
2
+ from rapidata.rapidata_client.filter._base_filter import RapidataFilter
3
+ from rapidata.api_client.models.and_user_filter_model import AndUserFilterModel
4
+ from rapidata.api_client.models.and_user_filter_model_filters_inner import AndUserFilterModelFiltersInner
5
+
6
+
7
+ class AndFilter(RapidataFilter):
8
+ """A filter that combines multiple filters with a logical AND operation.
9
+ This class implements a logical AND operation on a list of filters, where the condition is met if all of the filters' conditions are met.
10
+
11
+ Args:
12
+ filters (list[RapidataFilter]): A list of filters to be combined with AND.
13
+
14
+ Example:
15
+ ```python
16
+ from rapidata import AndFilter, LanguageFilter, CountryFilter
17
+
18
+ AndFilter([LanguageFilter(["en"]), CountryFilter(["US"])])
19
+ ```
20
+
21
+ This will match users who have their phone set to English AND are located in the United States.
22
+ """
23
+ def __init__(self, filters: list[RapidataFilter]):
24
+ if not all(isinstance(filter, RapidataFilter) for filter in filters):
25
+ raise ValueError("Filters must be a RapidataFilter object")
26
+
27
+ self.filters = filters
28
+
29
+ def _to_model(self):
30
+ return AndUserFilterModel(_t="AndFilter", filters=[AndUserFilterModelFiltersInner(filter._to_model()) for filter in self.filters])
@@ -5,7 +5,8 @@ from rapidata.rapidata_client.filter import (
5
5
  LanguageFilter,
6
6
  UserScoreFilter,
7
7
  NotFilter,
8
- OrFilter)
8
+ OrFilter,
9
+ AndFilter)
9
10
 
10
11
  class RapidataFilters:
11
12
  """RapidataFilters Classes
@@ -25,6 +26,7 @@ class RapidataFilters:
25
26
  language (LanguageFilter): Filters for users with a specific language.
26
27
  not_filter (NotFilter): Inverts the filter.
27
28
  or_filter (OrFilter): Combines multiple filters with a logical OR operation.
29
+ and_filter (AndFilter): Combines multiple filters with a logical AND operation.
28
30
 
29
31
  Example:
30
32
  ```python
@@ -40,10 +42,10 @@ class RapidataFilters:
40
42
 
41
43
  ```python
42
44
  from rapidata import AgeFilter, LanguageFilter, CountryFilter
43
- filters=[~AgeFilter([AgeGroup.UNDER_18]), CountryFilter(["US"]) | LanguageFilter(["en"])]
45
+ filters=[~AgeFilter([AgeGroup.UNDER_18]), CountryFilter(["US"]) | (CountryFilter(["CA"]) & LanguageFilter(["en"]))]
44
46
  ```
45
47
 
46
- This would return users who are not under 18 years old and are from the US or whose phones are set to English.
48
+ This would return users who are not under 18 years old and are from the US or who are from Canada and whose phones are set to English.
47
49
  """
48
50
  user_score = UserScoreFilter
49
51
  age = AgeFilter
@@ -52,3 +54,4 @@ class RapidataFilters:
52
54
  language = LanguageFilter
53
55
  not_filter = NotFilter
54
56
  or_filter = OrFilter
57
+ and_filter = AndFilter
@@ -58,7 +58,7 @@ class RapidataOrderBuilder:
58
58
  self.__settings: Sequence[RapidataSetting] | None = None
59
59
  self.__user_filters: list[RapidataFilter] = []
60
60
  self.__selections: list[RapidataSelection] = []
61
- self.__priority: int = 50
61
+ self.__priority: int | None = None
62
62
  self.__assets: Sequence[BaseAsset] = []
63
63
 
64
64
  def _to_model(self) -> CreateOrderModel:
@@ -93,10 +93,14 @@ class RapidataOrderBuilder:
93
93
  if self.__settings is not None
94
94
  else None
95
95
  ),
96
- selections=[
97
- AbTestSelectionAInner(selection._to_model())
98
- for selection in self.__selections
99
- ],
96
+ selections=(
97
+ [
98
+ AbTestSelectionAInner(selection._to_model())
99
+ for selection in self.__selections
100
+ ]
101
+ if self.__selections
102
+ else None
103
+ ),
100
104
  priority=self.__priority,
101
105
  )
102
106
 
@@ -276,7 +280,7 @@ class RapidataOrderBuilder:
276
280
  self.__user_filters = filters
277
281
  return self
278
282
 
279
- def _validation_set_id(self, validation_set_id: str) -> "RapidataOrderBuilder":
283
+ def _validation_set_id(self, validation_set_id: str | None = None) -> "RapidataOrderBuilder":
280
284
  """
281
285
  Set the validation set ID for the order.
282
286
 
@@ -286,7 +290,7 @@ class RapidataOrderBuilder:
286
290
  Returns:
287
291
  RapidataOrderBuilder: The updated RapidataOrderBuilder instance.
288
292
  """
289
- if not isinstance(validation_set_id, str):
293
+ if validation_set_id is not None and not isinstance(validation_set_id, str):
290
294
  raise TypeError("Validation set ID must be of type str.")
291
295
 
292
296
  self.__validation_set_id = validation_set_id
@@ -329,7 +333,7 @@ class RapidataOrderBuilder:
329
333
  self.__selections = selections # type: ignore
330
334
  return self
331
335
 
332
- def _priority(self, priority: int) -> "RapidataOrderBuilder":
336
+ def _priority(self, priority: int | None = None) -> "RapidataOrderBuilder":
333
337
  """
334
338
  Set the priority for the order.
335
339
 
@@ -339,7 +343,7 @@ class RapidataOrderBuilder:
339
343
  Returns:
340
344
  RapidataOrderBuilder: The updated RapidataOrderBuilder instance.
341
345
  """
342
- if not isinstance(priority, int):
346
+ if priority is not None and not isinstance(priority, int):
343
347
  raise TypeError("Priority must be of type int.")
344
348
 
345
349
  self.__priority = priority
@@ -53,13 +53,8 @@ class RapidataOrderManager:
53
53
  self.filters = RapidataFilters
54
54
  self.settings = RapidataSettings
55
55
  self.selections = RapidataSelections
56
- self.__priority = 50
56
+ self.__priority: int | None = None
57
57
  logger.debug("RapidataOrderManager initialized")
58
-
59
- def __get_selections(self, validation_set_id: str | None, labeling_amount=3) -> Sequence[RapidataSelection]:
60
- if validation_set_id:
61
- return [ValidationSelection(validation_set_id=validation_set_id), LabelingSelection(amount=labeling_amount-1)]
62
- return [LabelingSelection(amount=labeling_amount)]
63
58
 
64
59
  def _create_general_order(self,
65
60
  name: str,
@@ -75,7 +70,6 @@ class RapidataOrderManager:
75
70
  sentences: list[str] | None = None,
76
71
  selections: Sequence[RapidataSelection] = [],
77
72
  private_notes: list[str] | None = None,
78
- default_labeling_amount: int = 3
79
73
  ) -> RapidataOrder:
80
74
 
81
75
  if not assets:
@@ -108,9 +102,6 @@ class RapidataOrderManager:
108
102
 
109
103
  if selections and validation_set_id:
110
104
  logger.warning("Warning: Both selections and validation_set_id provided. Ignoring validation_set_id.")
111
-
112
- if not selections:
113
- selections = self.__get_selections(validation_set_id, labeling_amount=default_labeling_amount)
114
105
 
115
106
  prompts_metadata = [PromptMetadata(prompt=prompt) for prompt in contexts] if contexts else None
116
107
  sentence_metadata = [SelectWordsMetadata(select_words=sentence) for sentence in sentences] if sentences else None
@@ -135,6 +126,7 @@ class RapidataOrderManager:
135
126
  ._filters(filters)
136
127
  ._selections(selections)
137
128
  ._settings(settings)
129
+ ._validation_set_id(validation_set_id if not selections else None)
138
130
  ._priority(self.__priority)
139
131
  ._create()
140
132
  )
@@ -398,7 +390,6 @@ class RapidataOrderManager:
398
390
  filters=filters,
399
391
  selections=selections,
400
392
  settings=settings,
401
- default_labeling_amount=1,
402
393
  private_notes=private_notes
403
394
  )
404
395
 
@@ -451,7 +442,6 @@ class RapidataOrderManager:
451
442
  selections=selections,
452
443
  settings=settings,
453
444
  sentences=sentences,
454
- default_labeling_amount=2,
455
445
  private_notes=private_notes
456
446
  )
457
447
 
@@ -623,7 +613,6 @@ class RapidataOrderManager:
623
613
  filters=filters,
624
614
  selections=selections,
625
615
  settings=settings,
626
- default_labeling_amount=2,
627
616
  private_notes=private_notes
628
617
  )
629
618
 
@@ -1,14 +1,9 @@
1
- from pydantic import StrictBytes, StrictStr
2
1
  from rapidata.rapidata_client.assets import MediaAsset, TextAsset, MultiAsset
3
2
  from rapidata.rapidata_client.metadata import Metadata
4
- from typing import Sequence
5
- from typing import Any
3
+ from typing import Sequence, Any, cast
6
4
  from rapidata.api_client.models.add_validation_rapid_model import (
7
5
  AddValidationRapidModel,
8
6
  )
9
- from rapidata.api_client.models.add_validation_text_rapid_model import (
10
- AddValidationTextRapidModel,
11
- )
12
7
  from rapidata.api_client.models.add_validation_rapid_model_payload import (
13
8
  AddValidationRapidModelPayload,
14
9
  )
@@ -32,38 +27,52 @@ class Rapid():
32
27
  logger.debug(f"Created Rapid with asset: {self.asset}, metadata: {self.metadata}, payload: {self.payload}, truth: {self.truth}, randomCorrectProbability: {self.randomCorrectProbability}, explanation: {self.explanation}")
33
28
 
34
29
  def _add_to_validation_set(self, validationSetId: str, openapi_service: OpenAPIService) -> None:
35
- if isinstance(self.asset, TextAsset) or (isinstance(self.asset, MultiAsset) and isinstance(self.asset.assets[0], TextAsset)):
36
- openapi_service.validation_api.validation_set_validation_set_id_rapid_texts_post(
30
+ model = self.__to_model()
31
+ assets = self.__convert_to_assets()
32
+ if isinstance(assets[0], TextAsset):
33
+ assert all(isinstance(asset, TextAsset) for asset in assets)
34
+ texts = cast(list[TextAsset], assets)
35
+ openapi_service.validation_api.validation_set_validation_set_id_rapid_post(
37
36
  validation_set_id=validationSetId,
38
- add_validation_text_rapid_model=self.__to_text_model()
37
+ model=model,
38
+ texts=[asset.text for asset in texts]
39
39
  )
40
40
 
41
- elif isinstance(self.asset, MediaAsset) or (isinstance(self.asset, MultiAsset) and isinstance(self.asset.assets[0], MediaAsset)):
42
- model = self.__to_media_model()
43
- openapi_service.validation_api.validation_set_validation_set_id_rapid_files_post(
41
+ elif isinstance(assets[0], MediaAsset):
42
+ assert all(isinstance(asset, MediaAsset) for asset in assets)
43
+ files = cast(list[MediaAsset], assets)
44
+ openapi_service.validation_api.validation_set_validation_set_id_rapid_post(
44
45
  validation_set_id=validationSetId,
45
- model=model[0], files=model[1]
46
+ model=model,
47
+ files=[asset.to_file() for asset in files],
48
+ urls=[asset.path for asset in files if not asset.is_local()]
46
49
  )
47
50
 
48
51
  else:
49
52
  raise TypeError("The asset must be a MediaAsset, TextAsset, or MultiAsset")
50
-
51
- def __to_media_model(self) -> tuple[AddValidationRapidModel, list[StrictStr | tuple[StrictStr, StrictBytes] | StrictBytes]]:
52
- assets: list[MediaAsset] = []
53
+
54
+
55
+ def __convert_to_assets(self) -> list[MediaAsset | TextAsset]:
56
+ assets: list[MediaAsset | TextAsset] = []
53
57
  if isinstance(self.asset, MultiAsset):
54
58
  for asset in self.asset.assets:
55
59
  if isinstance(asset, MediaAsset):
56
60
  assets.append(asset)
61
+ elif isinstance(asset, TextAsset):
62
+ assets.append(asset)
57
63
  else:
58
- raise TypeError("The asset is a multiasset, but not all assets are MediaAssets")
64
+ raise TypeError("The asset is a multiasset, but not all assets are MediaAssets or TextAssets")
59
65
 
60
66
  if isinstance(self.asset, TextAsset):
61
- raise TypeError("The asset must contain Media")
67
+ assets = [self.asset]
62
68
 
63
69
  if isinstance(self.asset, MediaAsset):
64
70
  assets = [self.asset]
65
71
 
66
- return (AddValidationRapidModel(
72
+ return assets
73
+
74
+ def __to_model(self) -> AddValidationRapidModel:
75
+ return AddValidationRapidModel(
67
76
  payload=AddValidationRapidModelPayload(self.payload),
68
77
  truth=AddValidationRapidModelTruth(self.truth),
69
78
  metadata=[
@@ -72,31 +81,4 @@ class Rapid():
72
81
  ],
73
82
  randomCorrectProbability=self.randomCorrectProbability,
74
83
  explanation=self.explanation
75
- ), [asset.to_file() for asset in assets])
76
-
77
- def __to_text_model(self) -> AddValidationTextRapidModel:
78
- texts: list[str] = []
79
- if isinstance(self.asset, MultiAsset):
80
- for asset in self.asset.assets:
81
- if isinstance(asset, TextAsset):
82
- texts.append(asset.text)
83
- else:
84
- raise TypeError("The asset is a multiasset, but not all assets are TextAssets")
85
-
86
- if isinstance(self.asset, MediaAsset):
87
- raise TypeError("The asset must contain Text")
88
-
89
- if isinstance(self.asset, TextAsset):
90
- texts = [self.asset.text]
91
-
92
- return AddValidationTextRapidModel(
93
- payload=AddValidationRapidModelPayload(self.payload),
94
- truth=AddValidationRapidModelTruth(self.truth),
95
- metadata=[
96
- DatasetDatasetIdDatapointsPostRequestMetadataInner(meta.to_model())
97
- for meta in self.metadata
98
- ],
99
- randomCorrectProbability=self.randomCorrectProbability,
100
- texts=texts,
101
- explanation=self.explanation
102
- )
84
+ )
@@ -11,7 +11,6 @@ from rapidata.api_client.models.page_info import PageInfo
11
11
  from rapidata.api_client.models.root_filter import RootFilter
12
12
  from rapidata.api_client.models.filter import Filter
13
13
  from rapidata.api_client.models.sort_criterion import SortCriterion
14
- from urllib3._collections import HTTPHeaderDict # type: ignore[import]
15
14
 
16
15
  from rapidata.rapidata_client.validation.rapids.box import Box
17
16
 
@@ -527,9 +526,17 @@ class ValidationSetManager:
527
526
  )
528
527
 
529
528
  logger.debug("Adding rapids to validation set")
529
+ failed_rapids = []
530
530
  for rapid in tqdm(rapids, desc="Uploading validation tasks", disable=RapidataOutputManager.silent_mode):
531
- validation_set.add_rapid(rapid)
532
-
531
+ try:
532
+ validation_set.add_rapid(rapid)
533
+ except Exception:
534
+ failed_rapids.append(rapid.asset)
535
+
536
+ if failed_rapids:
537
+ logger.error(f"Failed to add {len(failed_rapids)} datapoints to validation set: {failed_rapids}")
538
+ raise RuntimeError(f"Failed to add {len(failed_rapids)} datapoints to validation set: {failed_rapids}")
539
+
533
540
  managed_print()
534
541
  managed_print(f"Validation set '{name}' created with ID {validation_set_id}\n",
535
542
  f"Now viewable under: https://app.{self.__openapi_service.environment}/validation-set/detail/{validation_set_id}",
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.3
2
2
  Name: rapidata
3
- Version: 2.29.1
3
+ Version: 2.31.0
4
4
  Summary: Rapidata package containing the Rapidata Python Client to interact with the Rapidata Web API in an easy way.
5
5
  License: Apache-2.0
6
6
  Author: Rapidata AG