rapidata-2.29.1-py3-none-any.whl → rapidata-2.31.0-py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of rapidata might be problematic. Click here for more details.
- rapidata/__init__.py +1 -1
- rapidata/api_client/__init__.py +5 -0
- rapidata/api_client/api/benchmark_api.py +550 -0
- rapidata/api_client/api/dataset_api.py +14 -14
- rapidata/api_client/api/leaderboard_api.py +562 -0
- rapidata/api_client/api/validation_set_api.py +349 -6
- rapidata/api_client/models/__init__.py +5 -0
- rapidata/api_client/models/file_type.py +1 -0
- rapidata/api_client/models/file_type_metadata.py +2 -2
- rapidata/api_client/models/file_type_metadata_model.py +2 -2
- rapidata/api_client/models/get_standing_by_id_result.py +4 -2
- rapidata/api_client/models/participant_by_benchmark.py +2 -2
- rapidata/api_client/models/participant_status.py +1 -0
- rapidata/api_client/models/prompt_by_benchmark_result.py +19 -3
- rapidata/api_client/models/run_status.py +39 -0
- rapidata/api_client/models/runs_by_leaderboard_result.py +110 -0
- rapidata/api_client/models/runs_by_leaderboard_result_paged_result.py +105 -0
- rapidata/api_client/models/standing_by_leaderboard.py +5 -3
- rapidata/api_client/models/update_benchmark_name_model.py +87 -0
- rapidata/api_client/models/update_leaderboard_name_model.py +87 -0
- rapidata/api_client_README.md +10 -0
- rapidata/rapidata_client/benchmark/leaderboard/rapidata_leaderboard.py +9 -0
- rapidata/rapidata_client/benchmark/rapidata_benchmark.py +66 -12
- rapidata/rapidata_client/benchmark/rapidata_benchmark_manager.py +24 -6
- rapidata/rapidata_client/filter/__init__.py +1 -0
- rapidata/rapidata_client/filter/_base_filter.py +20 -0
- rapidata/rapidata_client/filter/and_filter.py +30 -0
- rapidata/rapidata_client/filter/rapidata_filters.py +6 -3
- rapidata/rapidata_client/order/_rapidata_order_builder.py +13 -9
- rapidata/rapidata_client/order/rapidata_order_manager.py +2 -13
- rapidata/rapidata_client/validation/rapids/rapids.py +29 -47
- rapidata/rapidata_client/validation/validation_set_manager.py +10 -3
- {rapidata-2.29.1.dist-info → rapidata-2.31.0.dist-info}/METADATA +1 -1
- {rapidata-2.29.1.dist-info → rapidata-2.31.0.dist-info}/RECORD +36 -30
- {rapidata-2.29.1.dist-info → rapidata-2.31.0.dist-info}/LICENSE +0 -0
- {rapidata-2.29.1.dist-info → rapidata-2.31.0.dist-info}/WHEEL +0 -0
|
@@ -1,3 +1,4 @@
|
|
|
1
|
+
import re
|
|
1
2
|
from rapidata.api_client.models.root_filter import RootFilter
|
|
2
3
|
from rapidata.api_client.models.filter import Filter
|
|
3
4
|
from rapidata.api_client.models.query_model import QueryModel
|
|
@@ -5,6 +6,10 @@ from rapidata.api_client.models.page_info import PageInfo
|
|
|
5
6
|
from rapidata.api_client.models.create_leaderboard_model import CreateLeaderboardModel
|
|
6
7
|
from rapidata.api_client.models.create_benchmark_participant_model import CreateBenchmarkParticipantModel
|
|
7
8
|
from rapidata.api_client.models.submit_prompt_model import SubmitPromptModel
|
|
9
|
+
from rapidata.api_client.models.submit_prompt_model_prompt_asset import SubmitPromptModelPromptAsset
|
|
10
|
+
from rapidata.api_client.models.url_asset_input import UrlAssetInput
|
|
11
|
+
from rapidata.api_client.models.file_asset_model import FileAssetModel
|
|
12
|
+
from rapidata.api_client.models.source_url_metadata_model import SourceUrlMetadataModel
|
|
8
13
|
|
|
9
14
|
from rapidata.rapidata_client.logging import logger
|
|
10
15
|
from rapidata.service.openapi_service import OpenAPIService
|
|
@@ -29,7 +34,8 @@ class RapidataBenchmark:
|
|
|
29
34
|
self.name = name
|
|
30
35
|
self.id = id
|
|
31
36
|
self.__openapi_service = openapi_service
|
|
32
|
-
self.__prompts: list[str] = []
|
|
37
|
+
self.__prompts: list[str | None] = []
|
|
38
|
+
self.__prompt_assets: list[str | None] = []
|
|
33
39
|
self.__leaderboards: list[RapidataLeaderboard] = []
|
|
34
40
|
self.__identifiers: list[str] = []
|
|
35
41
|
|
|
@@ -53,8 +59,16 @@ class RapidataBenchmark:
|
|
|
53
59
|
|
|
54
60
|
total_pages = prompts_result.total_pages
|
|
55
61
|
|
|
56
|
-
|
|
57
|
-
|
|
62
|
+
for prompt in prompts_result.items:
|
|
63
|
+
self.__prompts.append(prompt.prompt)
|
|
64
|
+
self.__identifiers.append(prompt.identifier)
|
|
65
|
+
if prompt.prompt_asset is None:
|
|
66
|
+
self.__prompt_assets.append(None)
|
|
67
|
+
else:
|
|
68
|
+
assert isinstance(prompt.prompt_asset.actual_instance, FileAssetModel)
|
|
69
|
+
source_url = prompt.prompt_asset.actual_instance.metadata["sourceUrl"].actual_instance
|
|
70
|
+
assert isinstance(source_url, SourceUrlMetadataModel)
|
|
71
|
+
self.__prompt_assets.append(source_url.url)
|
|
58
72
|
|
|
59
73
|
if current_page >= total_pages:
|
|
60
74
|
break
|
|
@@ -62,7 +76,14 @@ class RapidataBenchmark:
|
|
|
62
76
|
current_page += 1
|
|
63
77
|
|
|
64
78
|
@property
|
|
65
|
-
def
|
|
79
|
+
def identifiers(self) -> list[str]:
|
|
80
|
+
if not self.__identifiers:
|
|
81
|
+
self.__instantiate_prompts()
|
|
82
|
+
|
|
83
|
+
return self.__identifiers
|
|
84
|
+
|
|
85
|
+
@property
|
|
86
|
+
def prompts(self) -> list[str | None]:
|
|
66
87
|
"""
|
|
67
88
|
Returns the prompts that are registered for the leaderboard.
|
|
68
89
|
"""
|
|
@@ -72,11 +93,14 @@ class RapidataBenchmark:
|
|
|
72
93
|
return self.__prompts
|
|
73
94
|
|
|
74
95
|
@property
|
|
75
|
-
def
|
|
76
|
-
|
|
96
|
+
def prompt_assets(self) -> list[str | None]:
|
|
97
|
+
"""
|
|
98
|
+
Returns the prompt assets that are registered for the benchmark.
|
|
99
|
+
"""
|
|
100
|
+
if not self.__prompt_assets:
|
|
77
101
|
self.__instantiate_prompts()
|
|
78
102
|
|
|
79
|
-
return self.
|
|
103
|
+
return self.__prompt_assets
|
|
80
104
|
|
|
81
105
|
@property
|
|
82
106
|
def leaderboards(self) -> list[RapidataLeaderboard]:
|
|
@@ -112,6 +136,7 @@ class RapidataBenchmark:
|
|
|
112
136
|
leaderboard.name,
|
|
113
137
|
leaderboard.instruction,
|
|
114
138
|
leaderboard.show_prompt,
|
|
139
|
+
leaderboard.show_prompt_asset,
|
|
115
140
|
leaderboard.is_inversed,
|
|
116
141
|
leaderboard.min_responses,
|
|
117
142
|
leaderboard.response_budget,
|
|
@@ -126,24 +151,49 @@ class RapidataBenchmark:
|
|
|
126
151
|
|
|
127
152
|
return self.__leaderboards
|
|
128
153
|
|
|
129
|
-
def add_prompt(self, identifier: str, prompt: str):
|
|
154
|
+
def add_prompt(self, identifier: str, prompt: str | None = None, asset: str | None = None):
|
|
130
155
|
"""
|
|
131
156
|
Adds a prompt to the benchmark.
|
|
157
|
+
|
|
158
|
+
Args:
|
|
159
|
+
identifier: The identifier of the prompt/asset that will be used to match up the media.
|
|
160
|
+
prompt: The prompt that will be used to evaluate the model.
|
|
161
|
+
asset: The asset that will be used to evaluate the model. Provided as a link to the asset.
|
|
132
162
|
"""
|
|
133
|
-
if not isinstance(identifier, str)
|
|
134
|
-
raise ValueError("Identifier
|
|
163
|
+
if not isinstance(identifier, str):
|
|
164
|
+
raise ValueError("Identifier must be a string.")
|
|
165
|
+
|
|
166
|
+
if prompt is None and asset is None:
|
|
167
|
+
raise ValueError("Prompt or asset must be provided.")
|
|
168
|
+
|
|
169
|
+
if prompt is not None and not isinstance(prompt, str):
|
|
170
|
+
raise ValueError("Prompt must be a string.")
|
|
171
|
+
|
|
172
|
+
if asset is not None and not isinstance(asset, str):
|
|
173
|
+
raise ValueError("Asset must be a string. That is the link to the asset.")
|
|
135
174
|
|
|
136
175
|
if identifier in self.identifiers:
|
|
137
176
|
raise ValueError("Identifier already exists in the benchmark.")
|
|
138
177
|
|
|
178
|
+
if asset is not None and not re.match(r'^https?://', asset):
|
|
179
|
+
raise ValueError("Asset must be a link to the asset.")
|
|
180
|
+
|
|
139
181
|
self.__identifiers.append(identifier)
|
|
182
|
+
|
|
140
183
|
self.__prompts.append(prompt)
|
|
184
|
+
self.__prompt_assets.append(asset)
|
|
141
185
|
|
|
142
186
|
self.__openapi_service.benchmark_api.benchmark_benchmark_id_prompt_post(
|
|
143
187
|
benchmark_id=self.id,
|
|
144
188
|
submit_prompt_model=SubmitPromptModel(
|
|
145
189
|
identifier=identifier,
|
|
146
190
|
prompt=prompt,
|
|
191
|
+
promptAsset=SubmitPromptModelPromptAsset(
|
|
192
|
+
UrlAssetInput(
|
|
193
|
+
_t="UrlAssetInput",
|
|
194
|
+
url=asset
|
|
195
|
+
)
|
|
196
|
+
) if asset is not None else None
|
|
147
197
|
)
|
|
148
198
|
)
|
|
149
199
|
|
|
@@ -151,7 +201,8 @@ class RapidataBenchmark:
|
|
|
151
201
|
self,
|
|
152
202
|
name: str,
|
|
153
203
|
instruction: str,
|
|
154
|
-
show_prompt: bool,
|
|
204
|
+
show_prompt: bool = False,
|
|
205
|
+
show_prompt_asset: bool = False,
|
|
155
206
|
inverse_ranking: bool = False,
|
|
156
207
|
min_responses: int | None = None,
|
|
157
208
|
response_budget: int | None = None
|
|
@@ -162,7 +213,8 @@ class RapidataBenchmark:
|
|
|
162
213
|
Args:
|
|
163
214
|
name: The name of the leaderboard. (not shown to the users)
|
|
164
215
|
instruction: The instruction decides how the models will be evaluated.
|
|
165
|
-
show_prompt: Whether to show the prompt to the users.
|
|
216
|
+
show_prompt: Whether to show the prompt to the users. (default: False)
|
|
217
|
+
show_prompt_asset: Whether to show the prompt asset to the users. (only works if the prompt asset is a URL) (default: False)
|
|
166
218
|
inverse_ranking: Whether to inverse the ranking of the leaderboard. (if the question is inversed, e.g. "Which video is worse?")
|
|
167
219
|
min_responses: The minimum amount of responses that get collected per comparison. if None, it will be defaulted.
|
|
168
220
|
response_budget: The total amount of responses that get collected per new model evaluation. if None, it will be defaulted. Values below 2000 are not recommended.
|
|
@@ -177,6 +229,7 @@ class RapidataBenchmark:
|
|
|
177
229
|
name=name,
|
|
178
230
|
instruction=instruction,
|
|
179
231
|
showPrompt=show_prompt,
|
|
232
|
+
showPromptAsset=show_prompt_asset,
|
|
180
233
|
isInversed=inverse_ranking,
|
|
181
234
|
minResponses=min_responses,
|
|
182
235
|
responseBudget=response_budget
|
|
@@ -189,6 +242,7 @@ class RapidataBenchmark:
|
|
|
189
242
|
name,
|
|
190
243
|
instruction,
|
|
191
244
|
show_prompt,
|
|
245
|
+
show_prompt_asset,
|
|
192
246
|
inverse_ranking,
|
|
193
247
|
leaderboard_result.min_responses,
|
|
194
248
|
leaderboard_result.response_budget,
|
|
@@ -1,3 +1,4 @@
|
|
|
1
|
+
from typing import Optional
|
|
1
2
|
from rapidata.rapidata_client.benchmark.rapidata_benchmark import RapidataBenchmark
|
|
2
3
|
from rapidata.api_client.models.create_benchmark_model import CreateBenchmarkModel
|
|
3
4
|
from rapidata.service.openapi_service import OpenAPIService
|
|
@@ -24,27 +25,40 @@ class RapidataBenchmarkManager:
|
|
|
24
25
|
def create_new_benchmark(self,
|
|
25
26
|
name: str,
|
|
26
27
|
identifiers: list[str],
|
|
27
|
-
prompts: list[str],
|
|
28
|
+
prompts: Optional[list[str]] = None,
|
|
29
|
+
prompt_assets: Optional[list[str]] = None,
|
|
28
30
|
) -> RapidataBenchmark:
|
|
29
31
|
"""
|
|
30
|
-
Creates a new benchmark with the given name, prompts, and
|
|
32
|
+
Creates a new benchmark with the given name, identifiers, prompts, and media assets.
|
|
33
|
+
|
|
34
|
+
prompts or prompt_assets must be provided.
|
|
31
35
|
|
|
32
36
|
Args:
|
|
33
37
|
name: The name of the benchmark.
|
|
34
38
|
prompts: The prompts that will be registered for the benchmark.
|
|
39
|
+
prompt_assets: The prompt assets that will be registered for the benchmark.
|
|
35
40
|
"""
|
|
36
41
|
if not isinstance(name, str):
|
|
37
42
|
raise ValueError("Name must be a string.")
|
|
38
43
|
|
|
39
|
-
if not isinstance(prompts, list) or not all(isinstance(prompt, str) for prompt in prompts):
|
|
44
|
+
if prompts and (not isinstance(prompts, list) or not all(isinstance(prompt, str) for prompt in prompts)):
|
|
40
45
|
raise ValueError("Prompts must be a list of strings.")
|
|
41
46
|
|
|
47
|
+
if prompt_assets and (not isinstance(prompt_assets, list) or not all(isinstance(asset, str) for asset in prompt_assets)):
|
|
48
|
+
raise ValueError("Media assets must be a list of strings.")
|
|
49
|
+
|
|
42
50
|
if not isinstance(identifiers, list) or not all(isinstance(identifier, str) for identifier in identifiers):
|
|
43
51
|
raise ValueError("Identifiers must be a list of strings.")
|
|
44
52
|
|
|
45
|
-
if len(identifiers) != len(prompts):
|
|
53
|
+
if prompts and len(identifiers) != len(prompts):
|
|
46
54
|
raise ValueError("Identifiers and prompts must have the same length.")
|
|
47
55
|
|
|
56
|
+
if prompt_assets and len(identifiers) != len(prompt_assets):
|
|
57
|
+
raise ValueError("Identifiers and media assets must have the same length.")
|
|
58
|
+
|
|
59
|
+
if not prompts and not prompt_assets:
|
|
60
|
+
raise ValueError("At least one of prompts or media assets must be provided.")
|
|
61
|
+
|
|
48
62
|
if len(set(identifiers)) != len(identifiers):
|
|
49
63
|
raise ValueError("Identifiers must be unique.")
|
|
50
64
|
|
|
@@ -55,8 +69,12 @@ class RapidataBenchmarkManager:
|
|
|
55
69
|
)
|
|
56
70
|
|
|
57
71
|
benchmark = RapidataBenchmark(name, benchmark_result.id, self.__openapi_service)
|
|
58
|
-
|
|
59
|
-
|
|
72
|
+
|
|
73
|
+
prompts_list = prompts if prompts is not None else [None] * len(identifiers)
|
|
74
|
+
media_assets_list = prompt_assets if prompt_assets is not None else [None] * len(identifiers)
|
|
75
|
+
|
|
76
|
+
for identifier, prompt, asset in zip(identifiers, prompts_list, media_assets_list):
|
|
77
|
+
benchmark.add_prompt(identifier, prompt, asset)
|
|
60
78
|
|
|
61
79
|
return benchmark
|
|
62
80
|
|
|
@@ -8,5 +8,6 @@ from .user_score_filter import UserScoreFilter
|
|
|
8
8
|
from .custom_filter import CustomFilter
|
|
9
9
|
from .not_filter import NotFilter
|
|
10
10
|
from .or_filter import OrFilter
|
|
11
|
+
from .and_filter import AndFilter
|
|
11
12
|
from .response_count_filter import ResponseCountFilter
|
|
12
13
|
from .new_user_filter import NewUserFilter
|
|
@@ -29,6 +29,26 @@ class RapidataFilter:
|
|
|
29
29
|
else:
|
|
30
30
|
return OrFilter([self, other])
|
|
31
31
|
|
|
32
|
+
def __and__(self, other):
|
|
33
|
+
"""Enable the & operator to create AndFilter combinations."""
|
|
34
|
+
if not isinstance(other, RapidataFilter):
|
|
35
|
+
return NotImplemented
|
|
36
|
+
|
|
37
|
+
from rapidata.rapidata_client.filter.and_filter import AndFilter
|
|
38
|
+
|
|
39
|
+
# If self is already an AndFilter, extend its filters list
|
|
40
|
+
if isinstance(self, AndFilter):
|
|
41
|
+
if isinstance(other, AndFilter):
|
|
42
|
+
return AndFilter(self.filters + other.filters)
|
|
43
|
+
else:
|
|
44
|
+
return AndFilter(self.filters + [other])
|
|
45
|
+
# If other is an AndFilter, prepend self to its filters
|
|
46
|
+
elif isinstance(other, AndFilter):
|
|
47
|
+
return AndFilter([self] + other.filters)
|
|
48
|
+
# Neither is an AndFilter, create a new one
|
|
49
|
+
else:
|
|
50
|
+
return AndFilter([self, other])
|
|
51
|
+
|
|
32
52
|
def __invert__(self):
|
|
33
53
|
"""Enable the ~ operator to create NotFilter negations."""
|
|
34
54
|
from rapidata.rapidata_client.filter.not_filter import NotFilter
|
|
@@ -0,0 +1,30 @@
|
|
|
1
|
+
from typing import Any
|
|
2
|
+
from rapidata.rapidata_client.filter._base_filter import RapidataFilter
|
|
3
|
+
from rapidata.api_client.models.and_user_filter_model import AndUserFilterModel
|
|
4
|
+
from rapidata.api_client.models.and_user_filter_model_filters_inner import AndUserFilterModelFiltersInner
|
|
5
|
+
|
|
6
|
+
|
|
7
|
+
class AndFilter(RapidataFilter):
|
|
8
|
+
"""A filter that combines multiple filters with a logical AND operation.
|
|
9
|
+
This class implements a logical AND operation on a list of filters, where the condition is met if all of the filters' conditions are met.
|
|
10
|
+
|
|
11
|
+
Args:
|
|
12
|
+
filters (list[RapidataFilter]): A list of filters to be combined with AND.
|
|
13
|
+
|
|
14
|
+
Example:
|
|
15
|
+
```python
|
|
16
|
+
from rapidata import AndFilter, LanguageFilter, CountryFilter
|
|
17
|
+
|
|
18
|
+
AndFilter([LanguageFilter(["en"]), CountryFilter(["US"])])
|
|
19
|
+
```
|
|
20
|
+
|
|
21
|
+
This will match users who have their phone set to English AND are located in the United States.
|
|
22
|
+
"""
|
|
23
|
+
def __init__(self, filters: list[RapidataFilter]):
|
|
24
|
+
if not all(isinstance(filter, RapidataFilter) for filter in filters):
|
|
25
|
+
raise ValueError("Filters must be a RapidataFilter object")
|
|
26
|
+
|
|
27
|
+
self.filters = filters
|
|
28
|
+
|
|
29
|
+
def _to_model(self):
|
|
30
|
+
return AndUserFilterModel(_t="AndFilter", filters=[AndUserFilterModelFiltersInner(filter._to_model()) for filter in self.filters])
|
|
@@ -5,7 +5,8 @@ from rapidata.rapidata_client.filter import (
|
|
|
5
5
|
LanguageFilter,
|
|
6
6
|
UserScoreFilter,
|
|
7
7
|
NotFilter,
|
|
8
|
-
OrFilter
|
|
8
|
+
OrFilter,
|
|
9
|
+
AndFilter)
|
|
9
10
|
|
|
10
11
|
class RapidataFilters:
|
|
11
12
|
"""RapidataFilters Classes
|
|
@@ -25,6 +26,7 @@ class RapidataFilters:
|
|
|
25
26
|
language (LanguageFilter): Filters for users with a specific language.
|
|
26
27
|
not_filter (NotFilter): Inverts the filter.
|
|
27
28
|
or_filter (OrFilter): Combines multiple filters with a logical OR operation.
|
|
29
|
+
and_filter (AndFilter): Combines multiple filters with a logical AND operation.
|
|
28
30
|
|
|
29
31
|
Example:
|
|
30
32
|
```python
|
|
@@ -40,10 +42,10 @@ class RapidataFilters:
|
|
|
40
42
|
|
|
41
43
|
```python
|
|
42
44
|
from rapidata import AgeFilter, LanguageFilter, CountryFilter
|
|
43
|
-
filters=[~AgeFilter([AgeGroup.UNDER_18]), CountryFilter(["US"]) | LanguageFilter(["en"])]
|
|
45
|
+
filters=[~AgeFilter([AgeGroup.UNDER_18]), CountryFilter(["US"]) | (CountryFilter(["CA"]) & LanguageFilter(["en"]))]
|
|
44
46
|
```
|
|
45
47
|
|
|
46
|
-
This would return users who are not under 18 years old and are from the US or whose phones are set to English.
|
|
48
|
+
This would return users who are not under 18 years old and are from the US or who are from Canada and whose phones are set to English.
|
|
47
49
|
"""
|
|
48
50
|
user_score = UserScoreFilter
|
|
49
51
|
age = AgeFilter
|
|
@@ -52,3 +54,4 @@ class RapidataFilters:
|
|
|
52
54
|
language = LanguageFilter
|
|
53
55
|
not_filter = NotFilter
|
|
54
56
|
or_filter = OrFilter
|
|
57
|
+
and_filter = AndFilter
|
|
@@ -58,7 +58,7 @@ class RapidataOrderBuilder:
|
|
|
58
58
|
self.__settings: Sequence[RapidataSetting] | None = None
|
|
59
59
|
self.__user_filters: list[RapidataFilter] = []
|
|
60
60
|
self.__selections: list[RapidataSelection] = []
|
|
61
|
-
self.__priority: int =
|
|
61
|
+
self.__priority: int | None = None
|
|
62
62
|
self.__assets: Sequence[BaseAsset] = []
|
|
63
63
|
|
|
64
64
|
def _to_model(self) -> CreateOrderModel:
|
|
@@ -93,10 +93,14 @@ class RapidataOrderBuilder:
|
|
|
93
93
|
if self.__settings is not None
|
|
94
94
|
else None
|
|
95
95
|
),
|
|
96
|
-
selections=
|
|
97
|
-
|
|
98
|
-
|
|
99
|
-
|
|
96
|
+
selections=(
|
|
97
|
+
[
|
|
98
|
+
AbTestSelectionAInner(selection._to_model())
|
|
99
|
+
for selection in self.__selections
|
|
100
|
+
]
|
|
101
|
+
if self.__selections
|
|
102
|
+
else None
|
|
103
|
+
),
|
|
100
104
|
priority=self.__priority,
|
|
101
105
|
)
|
|
102
106
|
|
|
@@ -276,7 +280,7 @@ class RapidataOrderBuilder:
|
|
|
276
280
|
self.__user_filters = filters
|
|
277
281
|
return self
|
|
278
282
|
|
|
279
|
-
def _validation_set_id(self, validation_set_id: str) -> "RapidataOrderBuilder":
|
|
283
|
+
def _validation_set_id(self, validation_set_id: str | None = None) -> "RapidataOrderBuilder":
|
|
280
284
|
"""
|
|
281
285
|
Set the validation set ID for the order.
|
|
282
286
|
|
|
@@ -286,7 +290,7 @@ class RapidataOrderBuilder:
|
|
|
286
290
|
Returns:
|
|
287
291
|
RapidataOrderBuilder: The updated RapidataOrderBuilder instance.
|
|
288
292
|
"""
|
|
289
|
-
if not isinstance(validation_set_id, str):
|
|
293
|
+
if validation_set_id is not None and not isinstance(validation_set_id, str):
|
|
290
294
|
raise TypeError("Validation set ID must be of type str.")
|
|
291
295
|
|
|
292
296
|
self.__validation_set_id = validation_set_id
|
|
@@ -329,7 +333,7 @@ class RapidataOrderBuilder:
|
|
|
329
333
|
self.__selections = selections # type: ignore
|
|
330
334
|
return self
|
|
331
335
|
|
|
332
|
-
def _priority(self, priority: int) -> "RapidataOrderBuilder":
|
|
336
|
+
def _priority(self, priority: int | None = None) -> "RapidataOrderBuilder":
|
|
333
337
|
"""
|
|
334
338
|
Set the priority for the order.
|
|
335
339
|
|
|
@@ -339,7 +343,7 @@ class RapidataOrderBuilder:
|
|
|
339
343
|
Returns:
|
|
340
344
|
RapidataOrderBuilder: The updated RapidataOrderBuilder instance.
|
|
341
345
|
"""
|
|
342
|
-
if not isinstance(priority, int):
|
|
346
|
+
if priority is not None and not isinstance(priority, int):
|
|
343
347
|
raise TypeError("Priority must be of type int.")
|
|
344
348
|
|
|
345
349
|
self.__priority = priority
|
|
@@ -53,13 +53,8 @@ class RapidataOrderManager:
|
|
|
53
53
|
self.filters = RapidataFilters
|
|
54
54
|
self.settings = RapidataSettings
|
|
55
55
|
self.selections = RapidataSelections
|
|
56
|
-
self.__priority =
|
|
56
|
+
self.__priority: int | None = None
|
|
57
57
|
logger.debug("RapidataOrderManager initialized")
|
|
58
|
-
|
|
59
|
-
def __get_selections(self, validation_set_id: str | None, labeling_amount=3) -> Sequence[RapidataSelection]:
|
|
60
|
-
if validation_set_id:
|
|
61
|
-
return [ValidationSelection(validation_set_id=validation_set_id), LabelingSelection(amount=labeling_amount-1)]
|
|
62
|
-
return [LabelingSelection(amount=labeling_amount)]
|
|
63
58
|
|
|
64
59
|
def _create_general_order(self,
|
|
65
60
|
name: str,
|
|
@@ -75,7 +70,6 @@ class RapidataOrderManager:
|
|
|
75
70
|
sentences: list[str] | None = None,
|
|
76
71
|
selections: Sequence[RapidataSelection] = [],
|
|
77
72
|
private_notes: list[str] | None = None,
|
|
78
|
-
default_labeling_amount: int = 3
|
|
79
73
|
) -> RapidataOrder:
|
|
80
74
|
|
|
81
75
|
if not assets:
|
|
@@ -108,9 +102,6 @@ class RapidataOrderManager:
|
|
|
108
102
|
|
|
109
103
|
if selections and validation_set_id:
|
|
110
104
|
logger.warning("Warning: Both selections and validation_set_id provided. Ignoring validation_set_id.")
|
|
111
|
-
|
|
112
|
-
if not selections:
|
|
113
|
-
selections = self.__get_selections(validation_set_id, labeling_amount=default_labeling_amount)
|
|
114
105
|
|
|
115
106
|
prompts_metadata = [PromptMetadata(prompt=prompt) for prompt in contexts] if contexts else None
|
|
116
107
|
sentence_metadata = [SelectWordsMetadata(select_words=sentence) for sentence in sentences] if sentences else None
|
|
@@ -135,6 +126,7 @@ class RapidataOrderManager:
|
|
|
135
126
|
._filters(filters)
|
|
136
127
|
._selections(selections)
|
|
137
128
|
._settings(settings)
|
|
129
|
+
._validation_set_id(validation_set_id if not selections else None)
|
|
138
130
|
._priority(self.__priority)
|
|
139
131
|
._create()
|
|
140
132
|
)
|
|
@@ -398,7 +390,6 @@ class RapidataOrderManager:
|
|
|
398
390
|
filters=filters,
|
|
399
391
|
selections=selections,
|
|
400
392
|
settings=settings,
|
|
401
|
-
default_labeling_amount=1,
|
|
402
393
|
private_notes=private_notes
|
|
403
394
|
)
|
|
404
395
|
|
|
@@ -451,7 +442,6 @@ class RapidataOrderManager:
|
|
|
451
442
|
selections=selections,
|
|
452
443
|
settings=settings,
|
|
453
444
|
sentences=sentences,
|
|
454
|
-
default_labeling_amount=2,
|
|
455
445
|
private_notes=private_notes
|
|
456
446
|
)
|
|
457
447
|
|
|
@@ -623,7 +613,6 @@ class RapidataOrderManager:
|
|
|
623
613
|
filters=filters,
|
|
624
614
|
selections=selections,
|
|
625
615
|
settings=settings,
|
|
626
|
-
default_labeling_amount=2,
|
|
627
616
|
private_notes=private_notes
|
|
628
617
|
)
|
|
629
618
|
|
|
@@ -1,14 +1,9 @@
|
|
|
1
|
-
from pydantic import StrictBytes, StrictStr
|
|
2
1
|
from rapidata.rapidata_client.assets import MediaAsset, TextAsset, MultiAsset
|
|
3
2
|
from rapidata.rapidata_client.metadata import Metadata
|
|
4
|
-
from typing import Sequence
|
|
5
|
-
from typing import Any
|
|
3
|
+
from typing import Sequence, Any, cast
|
|
6
4
|
from rapidata.api_client.models.add_validation_rapid_model import (
|
|
7
5
|
AddValidationRapidModel,
|
|
8
6
|
)
|
|
9
|
-
from rapidata.api_client.models.add_validation_text_rapid_model import (
|
|
10
|
-
AddValidationTextRapidModel,
|
|
11
|
-
)
|
|
12
7
|
from rapidata.api_client.models.add_validation_rapid_model_payload import (
|
|
13
8
|
AddValidationRapidModelPayload,
|
|
14
9
|
)
|
|
@@ -32,38 +27,52 @@ class Rapid():
|
|
|
32
27
|
logger.debug(f"Created Rapid with asset: {self.asset}, metadata: {self.metadata}, payload: {self.payload}, truth: {self.truth}, randomCorrectProbability: {self.randomCorrectProbability}, explanation: {self.explanation}")
|
|
33
28
|
|
|
34
29
|
def _add_to_validation_set(self, validationSetId: str, openapi_service: OpenAPIService) -> None:
|
|
35
|
-
|
|
36
|
-
|
|
30
|
+
model = self.__to_model()
|
|
31
|
+
assets = self.__convert_to_assets()
|
|
32
|
+
if isinstance(assets[0], TextAsset):
|
|
33
|
+
assert all(isinstance(asset, TextAsset) for asset in assets)
|
|
34
|
+
texts = cast(list[TextAsset], assets)
|
|
35
|
+
openapi_service.validation_api.validation_set_validation_set_id_rapid_post(
|
|
37
36
|
validation_set_id=validationSetId,
|
|
38
|
-
|
|
37
|
+
model=model,
|
|
38
|
+
texts=[asset.text for asset in texts]
|
|
39
39
|
)
|
|
40
40
|
|
|
41
|
-
elif isinstance(
|
|
42
|
-
|
|
43
|
-
|
|
41
|
+
elif isinstance(assets[0], MediaAsset):
|
|
42
|
+
assert all(isinstance(asset, MediaAsset) for asset in assets)
|
|
43
|
+
files = cast(list[MediaAsset], assets)
|
|
44
|
+
openapi_service.validation_api.validation_set_validation_set_id_rapid_post(
|
|
44
45
|
validation_set_id=validationSetId,
|
|
45
|
-
model=model
|
|
46
|
+
model=model,
|
|
47
|
+
files=[asset.to_file() for asset in files],
|
|
48
|
+
urls=[asset.path for asset in files if not asset.is_local()]
|
|
46
49
|
)
|
|
47
50
|
|
|
48
51
|
else:
|
|
49
52
|
raise TypeError("The asset must be a MediaAsset, TextAsset, or MultiAsset")
|
|
50
|
-
|
|
51
|
-
|
|
52
|
-
|
|
53
|
+
|
|
54
|
+
|
|
55
|
+
def __convert_to_assets(self) -> list[MediaAsset | TextAsset]:
|
|
56
|
+
assets: list[MediaAsset | TextAsset] = []
|
|
53
57
|
if isinstance(self.asset, MultiAsset):
|
|
54
58
|
for asset in self.asset.assets:
|
|
55
59
|
if isinstance(asset, MediaAsset):
|
|
56
60
|
assets.append(asset)
|
|
61
|
+
elif isinstance(asset, TextAsset):
|
|
62
|
+
assets.append(asset)
|
|
57
63
|
else:
|
|
58
|
-
raise TypeError("The asset is a multiasset, but not all assets are MediaAssets")
|
|
64
|
+
raise TypeError("The asset is a multiasset, but not all assets are MediaAssets or TextAssets")
|
|
59
65
|
|
|
60
66
|
if isinstance(self.asset, TextAsset):
|
|
61
|
-
|
|
67
|
+
assets = [self.asset]
|
|
62
68
|
|
|
63
69
|
if isinstance(self.asset, MediaAsset):
|
|
64
70
|
assets = [self.asset]
|
|
65
71
|
|
|
66
|
-
return
|
|
72
|
+
return assets
|
|
73
|
+
|
|
74
|
+
def __to_model(self) -> AddValidationRapidModel:
|
|
75
|
+
return AddValidationRapidModel(
|
|
67
76
|
payload=AddValidationRapidModelPayload(self.payload),
|
|
68
77
|
truth=AddValidationRapidModelTruth(self.truth),
|
|
69
78
|
metadata=[
|
|
@@ -72,31 +81,4 @@ class Rapid():
|
|
|
72
81
|
],
|
|
73
82
|
randomCorrectProbability=self.randomCorrectProbability,
|
|
74
83
|
explanation=self.explanation
|
|
75
|
-
)
|
|
76
|
-
|
|
77
|
-
def __to_text_model(self) -> AddValidationTextRapidModel:
|
|
78
|
-
texts: list[str] = []
|
|
79
|
-
if isinstance(self.asset, MultiAsset):
|
|
80
|
-
for asset in self.asset.assets:
|
|
81
|
-
if isinstance(asset, TextAsset):
|
|
82
|
-
texts.append(asset.text)
|
|
83
|
-
else:
|
|
84
|
-
raise TypeError("The asset is a multiasset, but not all assets are TextAssets")
|
|
85
|
-
|
|
86
|
-
if isinstance(self.asset, MediaAsset):
|
|
87
|
-
raise TypeError("The asset must contain Text")
|
|
88
|
-
|
|
89
|
-
if isinstance(self.asset, TextAsset):
|
|
90
|
-
texts = [self.asset.text]
|
|
91
|
-
|
|
92
|
-
return AddValidationTextRapidModel(
|
|
93
|
-
payload=AddValidationRapidModelPayload(self.payload),
|
|
94
|
-
truth=AddValidationRapidModelTruth(self.truth),
|
|
95
|
-
metadata=[
|
|
96
|
-
DatasetDatasetIdDatapointsPostRequestMetadataInner(meta.to_model())
|
|
97
|
-
for meta in self.metadata
|
|
98
|
-
],
|
|
99
|
-
randomCorrectProbability=self.randomCorrectProbability,
|
|
100
|
-
texts=texts,
|
|
101
|
-
explanation=self.explanation
|
|
102
|
-
)
|
|
84
|
+
)
|
|
@@ -11,7 +11,6 @@ from rapidata.api_client.models.page_info import PageInfo
|
|
|
11
11
|
from rapidata.api_client.models.root_filter import RootFilter
|
|
12
12
|
from rapidata.api_client.models.filter import Filter
|
|
13
13
|
from rapidata.api_client.models.sort_criterion import SortCriterion
|
|
14
|
-
from urllib3._collections import HTTPHeaderDict # type: ignore[import]
|
|
15
14
|
|
|
16
15
|
from rapidata.rapidata_client.validation.rapids.box import Box
|
|
17
16
|
|
|
@@ -527,9 +526,17 @@ class ValidationSetManager:
|
|
|
527
526
|
)
|
|
528
527
|
|
|
529
528
|
logger.debug("Adding rapids to validation set")
|
|
529
|
+
failed_rapids = []
|
|
530
530
|
for rapid in tqdm(rapids, desc="Uploading validation tasks", disable=RapidataOutputManager.silent_mode):
|
|
531
|
-
|
|
532
|
-
|
|
531
|
+
try:
|
|
532
|
+
validation_set.add_rapid(rapid)
|
|
533
|
+
except Exception:
|
|
534
|
+
failed_rapids.append(rapid.asset)
|
|
535
|
+
|
|
536
|
+
if failed_rapids:
|
|
537
|
+
logger.error(f"Failed to add {len(failed_rapids)} datapoints to validation set: {failed_rapids}")
|
|
538
|
+
raise RuntimeError(f"Failed to add {len(failed_rapids)} datapoints to validation set: {failed_rapids}")
|
|
539
|
+
|
|
533
540
|
managed_print()
|
|
534
541
|
managed_print(f"Validation set '{name}' created with ID {validation_set_id}\n",
|
|
535
542
|
f"Now viewable under: https://app.{self.__openapi_service.environment}/validation-set/detail/{validation_set_id}",
|