rapidata 2.37.0__py3-none-any.whl → 2.39.0__py3-none-any.whl
This diff shows the content of publicly available package versions as released to one of the supported registries; it is provided for informational purposes only and reflects the changes between the two versions as they appear in their public registry.
Potentially problematic release: this version of rapidata might be problematic.
- rapidata/__init__.py +3 -4
- rapidata/api_client/__init__.py +4 -5
- rapidata/api_client/api/benchmark_api.py +289 -3
- rapidata/api_client/api/leaderboard_api.py +35 -1
- rapidata/api_client/api/participant_api.py +289 -3
- rapidata/api_client/api/validation_set_api.py +119 -400
- rapidata/api_client/models/__init__.py +4 -5
- rapidata/api_client/models/ab_test_selection_a_inner.py +1 -1
- rapidata/api_client/models/compare_workflow_model1.py +1 -8
- rapidata/api_client/models/conditional_validation_selection.py +4 -9
- rapidata/api_client/models/confidence_interval.py +98 -0
- rapidata/api_client/models/create_simple_pipeline_model_pipeline_steps_inner.py +8 -22
- rapidata/api_client/models/get_standing_by_id_result.py +7 -2
- rapidata/api_client/models/get_validation_set_by_id_result.py +4 -2
- rapidata/api_client/models/simple_workflow_model1.py +1 -8
- rapidata/api_client/models/standing_by_leaderboard.py +10 -4
- rapidata/api_client/models/update_benchmark_model.py +87 -0
- rapidata/api_client/models/update_participant_model.py +87 -0
- rapidata/api_client/models/update_validation_set_model.py +93 -0
- rapidata/api_client/models/validation_chance.py +20 -3
- rapidata/api_client/models/validation_set_model.py +5 -42
- rapidata/api_client_README.md +7 -7
- rapidata/rapidata_client/__init__.py +1 -4
- rapidata/rapidata_client/api/{rapidata_exception.py → rapidata_api_client.py} +119 -2
- rapidata/rapidata_client/benchmark/leaderboard/rapidata_leaderboard.py +88 -46
- rapidata/rapidata_client/benchmark/participant/_participant.py +26 -9
- rapidata/rapidata_client/benchmark/rapidata_benchmark.py +310 -210
- rapidata/rapidata_client/benchmark/rapidata_benchmark_manager.py +134 -75
- rapidata/rapidata_client/config/__init__.py +3 -0
- rapidata/rapidata_client/config/logger.py +135 -0
- rapidata/rapidata_client/config/logging_config.py +58 -0
- rapidata/rapidata_client/config/managed_print.py +6 -0
- rapidata/rapidata_client/config/order_config.py +14 -0
- rapidata/rapidata_client/config/rapidata_config.py +15 -10
- rapidata/rapidata_client/config/tracer.py +130 -0
- rapidata/rapidata_client/config/upload_config.py +14 -0
- rapidata/rapidata_client/datapoints/_datapoint.py +1 -1
- rapidata/rapidata_client/datapoints/assets/__init__.py +1 -0
- rapidata/rapidata_client/datapoints/assets/_base_asset.py +2 -0
- rapidata/rapidata_client/datapoints/assets/_media_asset.py +1 -1
- rapidata/rapidata_client/datapoints/assets/_sessions.py +2 -2
- rapidata/rapidata_client/datapoints/assets/_text_asset.py +2 -2
- rapidata/rapidata_client/datapoints/assets/data_type_enum.py +1 -1
- rapidata/rapidata_client/datapoints/metadata/_media_asset_metadata.py +9 -8
- rapidata/rapidata_client/datapoints/metadata/_prompt_metadata.py +1 -2
- rapidata/rapidata_client/demographic/demographic_manager.py +16 -14
- rapidata/rapidata_client/filter/_base_filter.py +11 -5
- rapidata/rapidata_client/filter/age_filter.py +9 -3
- rapidata/rapidata_client/filter/and_filter.py +20 -5
- rapidata/rapidata_client/filter/campaign_filter.py +7 -1
- rapidata/rapidata_client/filter/country_filter.py +8 -2
- rapidata/rapidata_client/filter/custom_filter.py +9 -3
- rapidata/rapidata_client/filter/gender_filter.py +9 -3
- rapidata/rapidata_client/filter/language_filter.py +12 -5
- rapidata/rapidata_client/filter/models/age_group.py +4 -4
- rapidata/rapidata_client/filter/models/gender.py +4 -2
- rapidata/rapidata_client/filter/new_user_filter.py +3 -4
- rapidata/rapidata_client/filter/not_filter.py +17 -5
- rapidata/rapidata_client/filter/or_filter.py +20 -5
- rapidata/rapidata_client/filter/rapidata_filters.py +12 -9
- rapidata/rapidata_client/filter/response_count_filter.py +6 -0
- rapidata/rapidata_client/filter/user_score_filter.py +17 -5
- rapidata/rapidata_client/order/_rapidata_dataset.py +45 -17
- rapidata/rapidata_client/order/_rapidata_order_builder.py +19 -13
- rapidata/rapidata_client/order/rapidata_order.py +60 -48
- rapidata/rapidata_client/order/rapidata_order_manager.py +231 -197
- rapidata/rapidata_client/order/rapidata_results.py +71 -57
- rapidata/rapidata_client/rapidata_client.py +36 -23
- rapidata/rapidata_client/referee/__init__.py +1 -1
- rapidata/rapidata_client/referee/_base_referee.py +3 -1
- rapidata/rapidata_client/referee/_early_stopping_referee.py +2 -2
- rapidata/rapidata_client/selection/_base_selection.py +6 -0
- rapidata/rapidata_client/selection/ab_test_selection.py +7 -3
- rapidata/rapidata_client/selection/capped_selection.py +2 -2
- rapidata/rapidata_client/selection/conditional_validation_selection.py +12 -6
- rapidata/rapidata_client/selection/demographic_selection.py +9 -6
- rapidata/rapidata_client/selection/rapidata_selections.py +11 -8
- rapidata/rapidata_client/selection/shuffling_selection.py +5 -5
- rapidata/rapidata_client/selection/static_selection.py +5 -10
- rapidata/rapidata_client/selection/validation_selection.py +9 -5
- rapidata/rapidata_client/settings/_rapidata_setting.py +8 -0
- rapidata/rapidata_client/settings/alert_on_fast_response.py +8 -5
- rapidata/rapidata_client/settings/allow_neither_both.py +1 -0
- rapidata/rapidata_client/settings/custom_setting.py +3 -2
- rapidata/rapidata_client/settings/free_text_minimum_characters.py +9 -4
- rapidata/rapidata_client/settings/models/translation_behaviour_options.py +3 -2
- rapidata/rapidata_client/settings/no_shuffle.py +4 -2
- rapidata/rapidata_client/settings/play_video_until_the_end.py +7 -4
- rapidata/rapidata_client/settings/rapidata_settings.py +4 -3
- rapidata/rapidata_client/settings/translation_behaviour.py +7 -5
- rapidata/rapidata_client/validation/rapidata_validation_set.py +23 -17
- rapidata/rapidata_client/validation/rapids/box.py +3 -1
- rapidata/rapidata_client/validation/rapids/rapids.py +7 -1
- rapidata/rapidata_client/validation/rapids/rapids_manager.py +174 -141
- rapidata/rapidata_client/validation/validation_set_manager.py +285 -268
- rapidata/rapidata_client/workflow/__init__.py +1 -1
- rapidata/rapidata_client/workflow/_base_workflow.py +6 -1
- rapidata/rapidata_client/workflow/_classify_workflow.py +6 -0
- rapidata/rapidata_client/workflow/_compare_workflow.py +6 -0
- rapidata/rapidata_client/workflow/_draw_workflow.py +6 -0
- rapidata/rapidata_client/workflow/_evaluation_workflow.py +6 -0
- rapidata/rapidata_client/workflow/_free_text_workflow.py +6 -0
- rapidata/rapidata_client/workflow/_locate_workflow.py +6 -0
- rapidata/rapidata_client/workflow/_ranking_workflow.py +12 -0
- rapidata/rapidata_client/workflow/_select_words_workflow.py +6 -0
- rapidata/rapidata_client/workflow/_timestamp_workflow.py +6 -0
- rapidata/service/__init__.py +1 -1
- rapidata/service/credential_manager.py +1 -1
- rapidata/service/local_file_service.py +9 -8
- rapidata/service/openapi_service.py +2 -2
- {rapidata-2.37.0.dist-info → rapidata-2.39.0.dist-info}/METADATA +4 -1
- {rapidata-2.37.0.dist-info → rapidata-2.39.0.dist-info}/RECORD +114 -107
- rapidata/rapidata_client/logging/__init__.py +0 -2
- rapidata/rapidata_client/logging/logger.py +0 -122
- rapidata/rapidata_client/logging/output_manager.py +0 -20
- {rapidata-2.37.0.dist-info → rapidata-2.39.0.dist-info}/LICENSE +0 -0
- {rapidata-2.37.0.dist-info → rapidata-2.39.0.dist-info}/WHEEL +0 -0
rapidata/rapidata_client/benchmark/rapidata_benchmark.py

@@ -1,38 +1,41 @@
 import re
+import urllib.parse
+import webbrowser
+from colorama import Fore
 from typing import Literal, Optional, Sequence
-
-from rapidata.api_client.models.
-
-
-from rapidata.api_client.models.create_leaderboard_model import CreateLeaderboardModel
+
+from rapidata.api_client.models.and_user_filter_model_filters_inner import (
+    AndUserFilterModelFiltersInner,
+)
 from rapidata.api_client.models.create_benchmark_participant_model import (
     CreateBenchmarkParticipantModel,
 )
+from rapidata.api_client.models.create_leaderboard_model import CreateLeaderboardModel
+from rapidata.api_client.models.filter import Filter
+from rapidata.api_client.models.filter_operator import FilterOperator
+from rapidata.api_client.models.file_asset_model import FileAssetModel
+from rapidata.api_client.models.query_model import QueryModel
+from rapidata.api_client.models.page_info import PageInfo
+from rapidata.api_client.models.root_filter import RootFilter
+from rapidata.api_client.models.source_url_metadata_model import SourceUrlMetadataModel
 from rapidata.api_client.models.submit_prompt_model import SubmitPromptModel
 from rapidata.api_client.models.submit_prompt_model_prompt_asset import (
     SubmitPromptModelPromptAsset,
 )
 from rapidata.api_client.models.url_asset_input import UrlAssetInput
-from rapidata.api_client.models.file_asset_model import FileAssetModel
-from rapidata.api_client.models.source_url_metadata_model import SourceUrlMetadataModel
-from rapidata.api_client.models.and_user_filter_model_filters_inner import (
-    AndUserFilterModelFiltersInner,
-)
-from rapidata.api_client.models.filter_operator import FilterOperator
-
-from rapidata.rapidata_client.benchmark.participant._participant import (
-    BenchmarkParticipant,
-)
-from rapidata.rapidata_client.logging import logger
-from rapidata.service.openapi_service import OpenAPIService
 
+from rapidata.rapidata_client.benchmark._detail_mapper import DetailMapper
 from rapidata.rapidata_client.benchmark.leaderboard.rapidata_leaderboard import (
     RapidataLeaderboard,
 )
+from rapidata.rapidata_client.benchmark.participant._participant import (
+    BenchmarkParticipant,
+)
 from rapidata.rapidata_client.datapoints.assets import MediaAsset
-from rapidata.rapidata_client.benchmark._detail_mapper import DetailMapper
 from rapidata.rapidata_client.filter import RapidataFilter
+from rapidata.rapidata_client.config import logger, managed_print, tracer
 from rapidata.rapidata_client.settings import RapidataSetting
+from rapidata.service.openapi_service import OpenAPIService
 
 
 class RapidataBenchmark:
@@ -56,6 +59,9 @@ class RapidataBenchmark:
         self.__leaderboards: list[RapidataLeaderboard] = []
         self.__identifiers: list[str] = []
         self.__tags: list[list[str]] = []
+        self.__benchmark_page: str = (
+            f"https://app.{self.__openapi_service.environment}/mri/benchmarks/{self.id}"
+        )
 
     def __instantiate_prompts(self) -> None:
         current_page = 1
@@ -99,162 +105,192 @@ class RapidataBenchmark:

     @property
     def identifiers(self) -> list[str]:
-
-        self.
+        with tracer.start_as_current_span("RapidataBenchmark.identifiers"):
+            if not self.__identifiers:
+                self.__instantiate_prompts()

-
+            return self.__identifiers

     @property
     def prompts(self) -> list[str | None]:
         """
         Returns the prompts that are registered for the leaderboard.
         """
-
-        self.
+        with tracer.start_as_current_span("RapidataBenchmark.prompts"):
+            if not self.__prompts:
+                self.__instantiate_prompts()

-
+            return self.__prompts

     @property
     def prompt_assets(self) -> list[str | None]:
         """
         Returns the prompt assets that are registered for the benchmark.
         """
-
-        self.
+        with tracer.start_as_current_span("RapidataBenchmark.prompt_assets"):
+            if not self.__prompt_assets:
+                self.__instantiate_prompts()

-
+            return self.__prompt_assets

     @property
     def tags(self) -> list[list[str]]:
         """
         Returns the tags that are registered for the benchmark.
         """
-
-        self.
+        with tracer.start_as_current_span("RapidataBenchmark.tags"):
+            if not self.__tags:
+                self.__instantiate_prompts()

-
+            return self.__tags

     @property
     def leaderboards(self) -> list[RapidataLeaderboard]:
         """
         Returns the leaderboards that are registered for the benchmark.
         """
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
+        with tracer.start_as_current_span("RapidataBenchmark.leaderboards"):
+            if not self.__leaderboards:
+                current_page = 1
+                total_pages = None
+
+                while True:
+                    leaderboards_result = (
+                        self.__openapi_service.leaderboard_api.leaderboards_get(
+                            request=QueryModel(
+                                filter=RootFilter(
+                                    filters=[
+                                        Filter(
+                                            field="BenchmarkId",
+                                            operator=FilterOperator.EQ,
+                                            value=self.id,
+                                        )
+                                    ]
+                                ),
+                                page=PageInfo(index=current_page, size=100),
+                            )
                         )
                     )
-        )
-
-        if leaderboards_result.total_pages is None:
-            raise ValueError(
-                "An error occurred while fetching leaderboards: total_pages is None"
-            )

-
-
-
-            [
-                RapidataLeaderboard(
-                    leaderboard.name,
-                    leaderboard.instruction,
-                    leaderboard.show_prompt,
-                    leaderboard.show_prompt_asset,
-                    leaderboard.is_inversed,
-                    leaderboard.response_budget,
-                    leaderboard.min_responses,
-                    leaderboard.id,
-                    self.__openapi_service,
+                    if leaderboards_result.total_pages is None:
+                        raise ValueError(
+                            "An error occurred while fetching leaderboards: total_pages is None"
                         )
-                for leaderboard in leaderboards_result.items
-            ]
-        )

-
-
+                    total_pages = leaderboards_result.total_pages
+
+                    self.__leaderboards.extend(
+                        [
+                            RapidataLeaderboard(
+                                leaderboard.name,
+                                leaderboard.instruction,
+                                leaderboard.show_prompt,
+                                leaderboard.show_prompt_asset,
+                                leaderboard.is_inversed,
+                                leaderboard.response_budget,
+                                leaderboard.min_responses,
+                                self.id,
+                                leaderboard.id,
+                                self.__openapi_service,
+                            )
+                            for leaderboard in leaderboards_result.items
+                        ]
+                    )
+
+                    if current_page >= total_pages:
+                        break

-
+                    current_page += 1

-
+            return self.__leaderboards

     def add_prompt(
         self,
-        identifier: str,
+        identifier: str | None = None,
         prompt: str | None = None,
-
+        prompt_asset: str | None = None,
         tags: Optional[list[str]] = None,
     ):
         """
         Adds a prompt to the benchmark.

         Args:
-            identifier: The identifier of the prompt/asset/tags that will be used to match up the media.
+            identifier: The identifier of the prompt/asset/tags that will be used to match up the media. If not provided, it will use the prompt, asset or prompt + asset as the identifier.
             prompt: The prompt that will be used to evaluate the model.
-
+            prompt_asset: The prompt asset that will be used to evaluate the model. Provided as a link to the asset.
             tags: The tags can be used to filter the leaderboard results. They will NOT be shown to the users.
         """
-
-        tags
+        with tracer.start_as_current_span("RapidataBenchmark.add_prompt"):
+            if tags is None:
+                tags = []

-
-
+            if prompt is None and prompt_asset is None:
+                raise ValueError("Prompt or prompt asset must be provided.")

-
-
+            if identifier is None and prompt is None:
+                raise ValueError("Identifier or prompt must be provided.")

-
-
+            if identifier and not isinstance(identifier, str):
+                raise ValueError("Identifier must be a string.")

-
-
+            if prompt and not isinstance(prompt, str):
+                raise ValueError("Prompt must be a string.")

-
-
-
-
-            raise ValueError("Asset must be a link to the asset.")
+            if prompt_asset and not isinstance(prompt_asset, str):
+                raise ValueError(
+                    "Asset must be a string. That is the link to the asset."
+                )

-
-
-
-
+            if identifier is None:
+                assert prompt is not None
+                if prompt in self.prompts:
+                    raise ValueError(
+                        "Prompts must be unique. Otherwise use identifiers."
+                    )
+                identifier = prompt
+
+            if identifier in self.identifiers:
+                raise ValueError("Identifier already exists in the benchmark.")
+
+            if prompt_asset is not None and not re.match(r"^https?://", prompt_asset):
+                raise ValueError("Prompt asset must be a link to the asset.")
+
+            if tags is not None and (
+                not isinstance(tags, list)
+                or not all(isinstance(tag, str) for tag in tags)
+            ):
+                raise ValueError("Tags must be a list of strings.")
+
+            logger.info(
+                "Adding identifier %s with prompt %s, prompt asset %s and tags %s to benchmark %s",
+                identifier,
+                prompt,
+                prompt_asset,
+                tags,
+                self.id,
+            )

-
+            self.__identifiers.append(identifier)

-
-
-
+            self.__tags.append(tags)
+            self.__prompts.append(prompt)
+            self.__prompt_assets.append(prompt_asset)

-
-
-
-
-
-
-
-
-
-
-
+            self.__openapi_service.benchmark_api.benchmark_benchmark_id_prompt_post(
+                benchmark_id=self.id,
+                submit_prompt_model=SubmitPromptModel(
+                    identifier=identifier,
+                    prompt=prompt,
+                    promptAsset=(
+                        SubmitPromptModelPromptAsset(
+                            UrlAssetInput(_t="UrlAssetInput", url=prompt_asset)
+                        )
+                        if prompt_asset is not None
+                        else None
+                    ),
+                    tags=tags,
                 ),
-
-            ),
-        )
+            )

     def create_leaderboard(
         self,
@@ -284,126 +320,190 @@ class RapidataBenchmark:
             filters: The filters that should be applied to the leaderboard. Will determine who can solve answer in the leaderboard. (default: [])
             settings: The settings that should be applied to the leaderboard. Will determine the behavior of the tasks on the leaderboard. (default: [])
         """
-
-
-
-
-
-
-
-
-
-            name
-            instruction
-
-
-
-
-
-
-            filters
-
-                AndUserFilterModelFiltersInner(filter._to_model())
-                for filter in filters
-            ]
-            if filters
-            else None
-        ),
-        featureFlags=(
-            [setting._to_feature_flag() for setting in settings]
-            if settings
-            else None
-        ),
+        with tracer.start_as_current_span("create_leaderboard"):
+            if not isinstance(min_responses_per_matchup, int):
+                raise ValueError("Min responses per matchup must be an integer")
+
+            if min_responses_per_matchup < 3:
+                raise ValueError("Min responses per matchup must be at least 3")
+
+            logger.info(
+                "Creating leaderboard %s with instruction %s, show_prompt %s, show_prompt_asset %s, inverse_ranking %s, level_of_detail %s, min_responses_per_matchup %s, validation_set_id %s, filters %s, settings %s",
+                name,
+                instruction,
+                show_prompt,
+                show_prompt_asset,
+                inverse_ranking,
+                level_of_detail,
+                min_responses_per_matchup,
+                validation_set_id,
+                filters,
+                settings,
             )
-        )

-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
+            leaderboard_result = (
+                self.__openapi_service.leaderboard_api.leaderboard_post(
+                    create_leaderboard_model=CreateLeaderboardModel(
+                        benchmarkId=self.id,
+                        name=name,
+                        instruction=instruction,
+                        showPrompt=show_prompt,
+                        showPromptAsset=show_prompt_asset,
+                        isInversed=inverse_ranking,
+                        minResponses=min_responses_per_matchup,
+                        responseBudget=DetailMapper.get_budget(level_of_detail),
+                        validationSetId=validation_set_id,
+                        filters=(
+                            [
+                                AndUserFilterModelFiltersInner(filter._to_model())
+                                for filter in filters
+                            ]
+                            if filters
+                            else None
+                        ),
+                        featureFlags=(
+                            [setting._to_feature_flag() for setting in settings]
+                            if settings
+                            else None
+                        ),
+                    )
+                )
+            )
+
+            assert (
+                leaderboard_result.benchmark_id == self.id
+            ), "The leaderboard was not created for the correct benchmark."
+
+            logger.info("Leaderboard created with id %s", leaderboard_result.id)
+
+            return RapidataLeaderboard(
+                name,
+                instruction,
+                show_prompt,
+                show_prompt_asset,
+                inverse_ranking,
+                leaderboard_result.response_budget,
+                min_responses_per_matchup,
+                self.id,
+                leaderboard_result.id,
+                self.__openapi_service,
+            )

     def evaluate_model(
-        self,
+        self,
+        name: str,
+        media: list[str],
+        identifiers: list[str] | None = None,
+        prompts: list[str] | None = None,
     ) -> None:
         """
         Evaluates a model on the benchmark across all leaderboards.

+        prompts or identifiers must be provided to match the media.
+
         Args:
             name: The name of the model.
             media: The generated images/videos that will be used to evaluate the model.
-            identifiers: The identifiers that correspond to the media. The order of the identifiers must match the order of the media
+            identifiers: The identifiers that correspond to the media. The order of the identifiers must match the order of the media.\n
                 The identifiers that are used must be registered for the benchmark. To see the registered identifiers, use the identifiers property.
+            prompts: The prompts that correspond to the media. The order of the prompts must match the order of the media.
         """
-
-
+        with tracer.start_as_current_span("evaluate_model"):
+            if not media:
+                raise ValueError("Media must be a non-empty list of strings")

-
-
+            if not identifiers and not prompts:
+                raise ValueError("Identifiers or prompts must be provided.")

-
-
-
-
-        )
+            if identifiers and prompts:
+                raise ValueError(
+                    "Identifiers and prompts cannot be provided at the same time. Use one or the other."
+                )

-
-
-
-            assets.append(MediaAsset(media_path))
+            if not identifiers:
+                assert prompts is not None
+                identifiers = prompts

-
-
-
-
-            ),
-        )
+            if len(media) != len(identifiers):
+                raise ValueError(
+                    "Media and identifiers/prompts must have the same length"
+                )

-
+            if not all(identifier in self.identifiers for identifier in identifiers):
+                raise ValueError(
+                    "All identifiers/prompts must be in the registered identifiers/prompts list. To see the registered identifiers/prompts, use the identifiers/prompts property."
+                )

-
-
-
+            # happens before the creation of the participant to ensure all media paths are valid
+            assets: list[MediaAsset] = []
+            for media_path in media:
+                assets.append(MediaAsset(media_path))

-
-
-
-
+            participant_result = self.__openapi_service.benchmark_api.benchmark_benchmark_id_participants_post(
+                benchmark_id=self.id,
+                create_benchmark_participant_model=CreateBenchmarkParticipantModel(
+                    name=name,
+                ),
+            )

-
-        success_rate = (
-            (len(successful_uploads) / total_uploads * 100) if total_uploads > 0 else 0
-        )
-        logger.info(
-            f"Upload complete: {len(successful_uploads)} successful, {len(failed_uploads)} failed ({success_rate:.1f}% success rate)"
-        )
+            logger.info(f"Participant created: {participant_result.participant_id}")

-
-
-            f"Failed uploads for media: {[asset.path for asset in failed_uploads]}"
-        )
-        logger.warning(
-            "Some uploads failed. The model evaluation may be incomplete."
+            participant = BenchmarkParticipant(
+                name, participant_result.participant_id, self.__openapi_service
             )

-
-
-
+            with tracer.start_as_current_span("upload_media_for_participant"):
+                logger.info(
+                    f"Uploading {len(assets)} media assets to participant {participant.id}"
+                )
+
+                successful_uploads, failed_uploads = participant.upload_media(
+                    assets,
+                    identifiers,
+                )
+
+                total_uploads = len(assets)
+                success_rate = (
+                    (len(successful_uploads) / total_uploads * 100)
+                    if total_uploads > 0
+                    else 0
+                )
+                logger.info(
+                    f"Upload complete: {len(successful_uploads)} successful, {len(failed_uploads)} failed ({success_rate:.1f}% success rate)"
+                )
+
+                if failed_uploads:
+                    logger.error(
+                        f"Failed uploads for media: {[asset.path for asset in failed_uploads]}"
+                    )
+                    logger.warning(
+                        "Some uploads failed. The model evaluation may be incomplete."
+                    )
+
+                if len(successful_uploads) == 0:
+                    raise RuntimeError(
+                        "No uploads were successful. The model evaluation will not be completed."
+                    )
+
+            self.__openapi_service.participant_api.participants_participant_id_submit_post(
+                participant_id=participant_result.participant_id
             )

-
-
-
+    def view(self) -> None:
+        """
+        Views the benchmark.
+        """
+        logger.info("Opening benchmark page in browser...")
+        could_open_browser = webbrowser.open(self.__benchmark_page)
+        if not could_open_browser:
+            encoded_url = urllib.parse.quote(
+                self.__benchmark_page, safe="%/:=&?~#+!$,;'@()*[]"
+            )
+            managed_print(
+                Fore.RED
+                + f"Please open this URL in your browser: '{encoded_url}'"
+                + Fore.RESET
+            )

     def __str__(self) -> str:
         return f"RapidataBenchmark(name={self.name}, id={self.id})"