rapidata 2.37.0__py3-none-any.whl → 2.39.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release: this version of rapidata might be problematic.

Files changed (117)
  1. rapidata/__init__.py +3 -4
  2. rapidata/api_client/__init__.py +4 -5
  3. rapidata/api_client/api/benchmark_api.py +289 -3
  4. rapidata/api_client/api/leaderboard_api.py +35 -1
  5. rapidata/api_client/api/participant_api.py +289 -3
  6. rapidata/api_client/api/validation_set_api.py +119 -400
  7. rapidata/api_client/models/__init__.py +4 -5
  8. rapidata/api_client/models/ab_test_selection_a_inner.py +1 -1
  9. rapidata/api_client/models/compare_workflow_model1.py +1 -8
  10. rapidata/api_client/models/conditional_validation_selection.py +4 -9
  11. rapidata/api_client/models/confidence_interval.py +98 -0
  12. rapidata/api_client/models/create_simple_pipeline_model_pipeline_steps_inner.py +8 -22
  13. rapidata/api_client/models/get_standing_by_id_result.py +7 -2
  14. rapidata/api_client/models/get_validation_set_by_id_result.py +4 -2
  15. rapidata/api_client/models/simple_workflow_model1.py +1 -8
  16. rapidata/api_client/models/standing_by_leaderboard.py +10 -4
  17. rapidata/api_client/models/update_benchmark_model.py +87 -0
  18. rapidata/api_client/models/update_participant_model.py +87 -0
  19. rapidata/api_client/models/update_validation_set_model.py +93 -0
  20. rapidata/api_client/models/validation_chance.py +20 -3
  21. rapidata/api_client/models/validation_set_model.py +5 -42
  22. rapidata/api_client_README.md +7 -7
  23. rapidata/rapidata_client/__init__.py +1 -4
  24. rapidata/rapidata_client/api/{rapidata_exception.py → rapidata_api_client.py} +119 -2
  25. rapidata/rapidata_client/benchmark/leaderboard/rapidata_leaderboard.py +88 -46
  26. rapidata/rapidata_client/benchmark/participant/_participant.py +26 -9
  27. rapidata/rapidata_client/benchmark/rapidata_benchmark.py +310 -210
  28. rapidata/rapidata_client/benchmark/rapidata_benchmark_manager.py +134 -75
  29. rapidata/rapidata_client/config/__init__.py +3 -0
  30. rapidata/rapidata_client/config/logger.py +135 -0
  31. rapidata/rapidata_client/config/logging_config.py +58 -0
  32. rapidata/rapidata_client/config/managed_print.py +6 -0
  33. rapidata/rapidata_client/config/order_config.py +14 -0
  34. rapidata/rapidata_client/config/rapidata_config.py +15 -10
  35. rapidata/rapidata_client/config/tracer.py +130 -0
  36. rapidata/rapidata_client/config/upload_config.py +14 -0
  37. rapidata/rapidata_client/datapoints/_datapoint.py +1 -1
  38. rapidata/rapidata_client/datapoints/assets/__init__.py +1 -0
  39. rapidata/rapidata_client/datapoints/assets/_base_asset.py +2 -0
  40. rapidata/rapidata_client/datapoints/assets/_media_asset.py +1 -1
  41. rapidata/rapidata_client/datapoints/assets/_sessions.py +2 -2
  42. rapidata/rapidata_client/datapoints/assets/_text_asset.py +2 -2
  43. rapidata/rapidata_client/datapoints/assets/data_type_enum.py +1 -1
  44. rapidata/rapidata_client/datapoints/metadata/_media_asset_metadata.py +9 -8
  45. rapidata/rapidata_client/datapoints/metadata/_prompt_metadata.py +1 -2
  46. rapidata/rapidata_client/demographic/demographic_manager.py +16 -14
  47. rapidata/rapidata_client/filter/_base_filter.py +11 -5
  48. rapidata/rapidata_client/filter/age_filter.py +9 -3
  49. rapidata/rapidata_client/filter/and_filter.py +20 -5
  50. rapidata/rapidata_client/filter/campaign_filter.py +7 -1
  51. rapidata/rapidata_client/filter/country_filter.py +8 -2
  52. rapidata/rapidata_client/filter/custom_filter.py +9 -3
  53. rapidata/rapidata_client/filter/gender_filter.py +9 -3
  54. rapidata/rapidata_client/filter/language_filter.py +12 -5
  55. rapidata/rapidata_client/filter/models/age_group.py +4 -4
  56. rapidata/rapidata_client/filter/models/gender.py +4 -2
  57. rapidata/rapidata_client/filter/new_user_filter.py +3 -4
  58. rapidata/rapidata_client/filter/not_filter.py +17 -5
  59. rapidata/rapidata_client/filter/or_filter.py +20 -5
  60. rapidata/rapidata_client/filter/rapidata_filters.py +12 -9
  61. rapidata/rapidata_client/filter/response_count_filter.py +6 -0
  62. rapidata/rapidata_client/filter/user_score_filter.py +17 -5
  63. rapidata/rapidata_client/order/_rapidata_dataset.py +45 -17
  64. rapidata/rapidata_client/order/_rapidata_order_builder.py +19 -13
  65. rapidata/rapidata_client/order/rapidata_order.py +60 -48
  66. rapidata/rapidata_client/order/rapidata_order_manager.py +231 -197
  67. rapidata/rapidata_client/order/rapidata_results.py +71 -57
  68. rapidata/rapidata_client/rapidata_client.py +36 -23
  69. rapidata/rapidata_client/referee/__init__.py +1 -1
  70. rapidata/rapidata_client/referee/_base_referee.py +3 -1
  71. rapidata/rapidata_client/referee/_early_stopping_referee.py +2 -2
  72. rapidata/rapidata_client/selection/_base_selection.py +6 -0
  73. rapidata/rapidata_client/selection/ab_test_selection.py +7 -3
  74. rapidata/rapidata_client/selection/capped_selection.py +2 -2
  75. rapidata/rapidata_client/selection/conditional_validation_selection.py +12 -6
  76. rapidata/rapidata_client/selection/demographic_selection.py +9 -6
  77. rapidata/rapidata_client/selection/rapidata_selections.py +11 -8
  78. rapidata/rapidata_client/selection/shuffling_selection.py +5 -5
  79. rapidata/rapidata_client/selection/static_selection.py +5 -10
  80. rapidata/rapidata_client/selection/validation_selection.py +9 -5
  81. rapidata/rapidata_client/settings/_rapidata_setting.py +8 -0
  82. rapidata/rapidata_client/settings/alert_on_fast_response.py +8 -5
  83. rapidata/rapidata_client/settings/allow_neither_both.py +1 -0
  84. rapidata/rapidata_client/settings/custom_setting.py +3 -2
  85. rapidata/rapidata_client/settings/free_text_minimum_characters.py +9 -4
  86. rapidata/rapidata_client/settings/models/translation_behaviour_options.py +3 -2
  87. rapidata/rapidata_client/settings/no_shuffle.py +4 -2
  88. rapidata/rapidata_client/settings/play_video_until_the_end.py +7 -4
  89. rapidata/rapidata_client/settings/rapidata_settings.py +4 -3
  90. rapidata/rapidata_client/settings/translation_behaviour.py +7 -5
  91. rapidata/rapidata_client/validation/rapidata_validation_set.py +23 -17
  92. rapidata/rapidata_client/validation/rapids/box.py +3 -1
  93. rapidata/rapidata_client/validation/rapids/rapids.py +7 -1
  94. rapidata/rapidata_client/validation/rapids/rapids_manager.py +174 -141
  95. rapidata/rapidata_client/validation/validation_set_manager.py +285 -268
  96. rapidata/rapidata_client/workflow/__init__.py +1 -1
  97. rapidata/rapidata_client/workflow/_base_workflow.py +6 -1
  98. rapidata/rapidata_client/workflow/_classify_workflow.py +6 -0
  99. rapidata/rapidata_client/workflow/_compare_workflow.py +6 -0
  100. rapidata/rapidata_client/workflow/_draw_workflow.py +6 -0
  101. rapidata/rapidata_client/workflow/_evaluation_workflow.py +6 -0
  102. rapidata/rapidata_client/workflow/_free_text_workflow.py +6 -0
  103. rapidata/rapidata_client/workflow/_locate_workflow.py +6 -0
  104. rapidata/rapidata_client/workflow/_ranking_workflow.py +12 -0
  105. rapidata/rapidata_client/workflow/_select_words_workflow.py +6 -0
  106. rapidata/rapidata_client/workflow/_timestamp_workflow.py +6 -0
  107. rapidata/service/__init__.py +1 -1
  108. rapidata/service/credential_manager.py +1 -1
  109. rapidata/service/local_file_service.py +9 -8
  110. rapidata/service/openapi_service.py +2 -2
  111. {rapidata-2.37.0.dist-info → rapidata-2.39.0.dist-info}/METADATA +4 -1
  112. {rapidata-2.37.0.dist-info → rapidata-2.39.0.dist-info}/RECORD +114 -107
  113. rapidata/rapidata_client/logging/__init__.py +0 -2
  114. rapidata/rapidata_client/logging/logger.py +0 -122
  115. rapidata/rapidata_client/logging/output_manager.py +0 -20
  116. {rapidata-2.37.0.dist-info → rapidata-2.39.0.dist-info}/LICENSE +0 -0
  117. {rapidata-2.37.0.dist-info → rapidata-2.39.0.dist-info}/WHEEL +0 -0
@@ -1,38 +1,41 @@
 import re
+import urllib.parse
+import webbrowser
+from colorama import Fore
 from typing import Literal, Optional, Sequence
-from rapidata.api_client.models.root_filter import RootFilter
-from rapidata.api_client.models.filter import Filter
-from rapidata.api_client.models.query_model import QueryModel
-from rapidata.api_client.models.page_info import PageInfo
-from rapidata.api_client.models.create_leaderboard_model import CreateLeaderboardModel
+
+from rapidata.api_client.models.and_user_filter_model_filters_inner import (
+    AndUserFilterModelFiltersInner,
+)
 from rapidata.api_client.models.create_benchmark_participant_model import (
     CreateBenchmarkParticipantModel,
 )
+from rapidata.api_client.models.create_leaderboard_model import CreateLeaderboardModel
+from rapidata.api_client.models.filter import Filter
+from rapidata.api_client.models.filter_operator import FilterOperator
+from rapidata.api_client.models.file_asset_model import FileAssetModel
+from rapidata.api_client.models.query_model import QueryModel
+from rapidata.api_client.models.page_info import PageInfo
+from rapidata.api_client.models.root_filter import RootFilter
+from rapidata.api_client.models.source_url_metadata_model import SourceUrlMetadataModel
 from rapidata.api_client.models.submit_prompt_model import SubmitPromptModel
 from rapidata.api_client.models.submit_prompt_model_prompt_asset import (
     SubmitPromptModelPromptAsset,
 )
 from rapidata.api_client.models.url_asset_input import UrlAssetInput
-from rapidata.api_client.models.file_asset_model import FileAssetModel
-from rapidata.api_client.models.source_url_metadata_model import SourceUrlMetadataModel
-from rapidata.api_client.models.and_user_filter_model_filters_inner import (
-    AndUserFilterModelFiltersInner,
-)
-from rapidata.api_client.models.filter_operator import FilterOperator
-
-from rapidata.rapidata_client.benchmark.participant._participant import (
-    BenchmarkParticipant,
-)
-from rapidata.rapidata_client.logging import logger
-from rapidata.service.openapi_service import OpenAPIService
 
+from rapidata.rapidata_client.benchmark._detail_mapper import DetailMapper
 from rapidata.rapidata_client.benchmark.leaderboard.rapidata_leaderboard import (
     RapidataLeaderboard,
 )
+from rapidata.rapidata_client.benchmark.participant._participant import (
+    BenchmarkParticipant,
+)
 from rapidata.rapidata_client.datapoints.assets import MediaAsset
-from rapidata.rapidata_client.benchmark._detail_mapper import DetailMapper
 from rapidata.rapidata_client.filter import RapidataFilter
+from rapidata.rapidata_client.config import logger, managed_print, tracer
 from rapidata.rapidata_client.settings import RapidataSetting
+from rapidata.service.openapi_service import OpenAPIService
 
 
 class RapidataBenchmark:
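
The import hunk above reflects the logging rework in this release: the old rapidata.rapidata_client.logging package (files 113-115) is removed, and the new rapidata.rapidata_client.config package (files 29-36) re-exports logger alongside the new managed_print and tracer helpers. A minimal before/after sketch for downstream imports, based only on the lines shown above:

# 2.37.0 (package removed in 2.39.0):
# from rapidata.rapidata_client.logging import logger

# 2.39.0:
from rapidata.rapidata_client.config import logger, managed_print, tracer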
@@ -56,6 +59,9 @@ class RapidataBenchmark:
         self.__leaderboards: list[RapidataLeaderboard] = []
         self.__identifiers: list[str] = []
         self.__tags: list[list[str]] = []
+        self.__benchmark_page: str = (
+            f"https://app.{self.__openapi_service.environment}/mri/benchmarks/{self.id}"
+        )
 
     def __instantiate_prompts(self) -> None:
        current_page = 1
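
The new __benchmark_page attribute precomputes the dashboard URL that the view() method added at the end of this diff opens. A sketch of the resulting URL; the environment host and benchmark id here are made-up illustrative values:

# Illustrative values only: environment really comes from the OpenAPIService.
environment = "rapidata.ai"  # assumed host suffix
benchmark_id = "0123456789abcdef01234567"  # assumed id
print(f"https://app.{environment}/mri/benchmarks/{benchmark_id}")
# https://app.rapidata.ai/mri/benchmarks/0123456789abcdef01234567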
@@ -99,162 +105,192 @@ class RapidataBenchmark:
 
     @property
     def identifiers(self) -> list[str]:
-        if not self.__identifiers:
-            self.__instantiate_prompts()
+        with tracer.start_as_current_span("RapidataBenchmark.identifiers"):
+            if not self.__identifiers:
+                self.__instantiate_prompts()
 
-        return self.__identifiers
+            return self.__identifiers
 
     @property
     def prompts(self) -> list[str | None]:
         """
         Returns the prompts that are registered for the leaderboard.
         """
-        if not self.__prompts:
-            self.__instantiate_prompts()
+        with tracer.start_as_current_span("RapidataBenchmark.prompts"):
+            if not self.__prompts:
+                self.__instantiate_prompts()
 
-        return self.__prompts
+            return self.__prompts
 
     @property
     def prompt_assets(self) -> list[str | None]:
         """
         Returns the prompt assets that are registered for the benchmark.
         """
-        if not self.__prompt_assets:
-            self.__instantiate_prompts()
+        with tracer.start_as_current_span("RapidataBenchmark.prompt_assets"):
+            if not self.__prompt_assets:
+                self.__instantiate_prompts()
 
-        return self.__prompt_assets
+            return self.__prompt_assets
 
     @property
     def tags(self) -> list[list[str]]:
         """
         Returns the tags that are registered for the benchmark.
         """
-        if not self.__tags:
-            self.__instantiate_prompts()
+        with tracer.start_as_current_span("RapidataBenchmark.tags"):
+            if not self.__tags:
+                self.__instantiate_prompts()
 
-        return self.__tags
+            return self.__tags
 
     @property
     def leaderboards(self) -> list[RapidataLeaderboard]:
         """
         Returns the leaderboards that are registered for the benchmark.
         """
-        if not self.__leaderboards:
-            current_page = 1
-            total_pages = None
-
-            while True:
-                leaderboards_result = (
-                    self.__openapi_service.leaderboard_api.leaderboards_get(
-                        request=QueryModel(
-                            filter=RootFilter(
-                                filters=[
-                                    Filter(
-                                        field="BenchmarkId",
-                                        operator=FilterOperator.EQ,
-                                        value=self.id,
-                                    )
-                                ]
-                            ),
-                            page=PageInfo(index=current_page, size=100),
+        with tracer.start_as_current_span("RapidataBenchmark.leaderboards"):
+            if not self.__leaderboards:
+                current_page = 1
+                total_pages = None
+
+                while True:
+                    leaderboards_result = (
+                        self.__openapi_service.leaderboard_api.leaderboards_get(
+                            request=QueryModel(
+                                filter=RootFilter(
+                                    filters=[
+                                        Filter(
+                                            field="BenchmarkId",
+                                            operator=FilterOperator.EQ,
+                                            value=self.id,
+                                        )
+                                    ]
+                                ),
+                                page=PageInfo(index=current_page, size=100),
+                            )
                         )
                     )
-                )
-
-                if leaderboards_result.total_pages is None:
-                    raise ValueError(
-                        "An error occurred while fetching leaderboards: total_pages is None"
-                    )
 
-                total_pages = leaderboards_result.total_pages
-
-                self.__leaderboards.extend(
-                    [
-                        RapidataLeaderboard(
-                            leaderboard.name,
-                            leaderboard.instruction,
-                            leaderboard.show_prompt,
-                            leaderboard.show_prompt_asset,
-                            leaderboard.is_inversed,
-                            leaderboard.response_budget,
-                            leaderboard.min_responses,
-                            leaderboard.id,
-                            self.__openapi_service,
+                    if leaderboards_result.total_pages is None:
+                        raise ValueError(
+                            "An error occurred while fetching leaderboards: total_pages is None"
                         )
-                        for leaderboard in leaderboards_result.items
-                    ]
-                )
 
-                if current_page >= total_pages:
-                    break
+                    total_pages = leaderboards_result.total_pages
+
+                    self.__leaderboards.extend(
+                        [
+                            RapidataLeaderboard(
+                                leaderboard.name,
+                                leaderboard.instruction,
+                                leaderboard.show_prompt,
+                                leaderboard.show_prompt_asset,
+                                leaderboard.is_inversed,
+                                leaderboard.response_budget,
+                                leaderboard.min_responses,
+                                self.id,
+                                leaderboard.id,
+                                self.__openapi_service,
+                            )
+                            for leaderboard in leaderboards_result.items
+                        ]
+                    )
+
+                    if current_page >= total_pages:
+                        break
 
-                current_page += 1
+                    current_page += 1
 
-        return self.__leaderboards
+            return self.__leaderboards
 
     def add_prompt(
         self,
-        identifier: str,
+        identifier: str | None = None,
         prompt: str | None = None,
-        asset: str | None = None,
+        prompt_asset: str | None = None,
         tags: Optional[list[str]] = None,
     ):
         """
         Adds a prompt to the benchmark.
 
         Args:
-            identifier: The identifier of the prompt/asset/tags that will be used to match up the media.
+            identifier: The identifier of the prompt/asset/tags that will be used to match up the media. If not provided, it will use the prompt, asset or prompt + asset as the identifier.
             prompt: The prompt that will be used to evaluate the model.
-            asset: The asset that will be used to evaluate the model. Provided as a link to the asset.
+            prompt_asset: The prompt asset that will be used to evaluate the model. Provided as a link to the asset.
             tags: The tags can be used to filter the leaderboard results. They will NOT be shown to the users.
         """
-        if tags is None:
-            tags = []
+        with tracer.start_as_current_span("RapidataBenchmark.add_prompt"):
+            if tags is None:
+                tags = []
 
-        if not isinstance(identifier, str):
-            raise ValueError("Identifier must be a string.")
+            if prompt is None and prompt_asset is None:
+                raise ValueError("Prompt or prompt asset must be provided.")
 
-        if prompt is None and asset is None:
-            raise ValueError("Prompt or asset must be provided.")
+            if identifier is None and prompt is None:
+                raise ValueError("Identifier or prompt must be provided.")
 
-        if prompt is not None and not isinstance(prompt, str):
-            raise ValueError("Prompt must be a string.")
+            if identifier and not isinstance(identifier, str):
+                raise ValueError("Identifier must be a string.")
 
-        if asset is not None and not isinstance(asset, str):
-            raise ValueError("Asset must be a string. That is the link to the asset.")
+            if prompt and not isinstance(prompt, str):
+                raise ValueError("Prompt must be a string.")
 
-        if identifier in self.identifiers:
-            raise ValueError("Identifier already exists in the benchmark.")
-
-        if asset is not None and not re.match(r"^https?://", asset):
-            raise ValueError("Asset must be a link to the asset.")
+            if prompt_asset and not isinstance(prompt_asset, str):
+                raise ValueError(
+                    "Asset must be a string. That is the link to the asset."
+                )
 
-        if tags is not None and (
-            not isinstance(tags, list) or not all(isinstance(tag, str) for tag in tags)
-        ):
-            raise ValueError("Tags must be a list of strings.")
+            if identifier is None:
+                assert prompt is not None
+                if prompt in self.prompts:
+                    raise ValueError(
+                        "Prompts must be unique. Otherwise use identifiers."
+                    )
+                identifier = prompt
+
+            if identifier in self.identifiers:
+                raise ValueError("Identifier already exists in the benchmark.")
+
+            if prompt_asset is not None and not re.match(r"^https?://", prompt_asset):
+                raise ValueError("Prompt asset must be a link to the asset.")
+
+            if tags is not None and (
+                not isinstance(tags, list)
+                or not all(isinstance(tag, str) for tag in tags)
+            ):
+                raise ValueError("Tags must be a list of strings.")
+
+            logger.info(
+                "Adding identifier %s with prompt %s, prompt asset %s and tags %s to benchmark %s",
+                identifier,
+                prompt,
+                prompt_asset,
+                tags,
+                self.id,
+            )
 
-        self.__identifiers.append(identifier)
+            self.__identifiers.append(identifier)
 
-        self.__tags.append(tags)
-        self.__prompts.append(prompt)
-        self.__prompt_assets.append(asset)
+            self.__tags.append(tags)
+            self.__prompts.append(prompt)
+            self.__prompt_assets.append(prompt_asset)
 
-        self.__openapi_service.benchmark_api.benchmark_benchmark_id_prompt_post(
-            benchmark_id=self.id,
-            submit_prompt_model=SubmitPromptModel(
-                identifier=identifier,
-                prompt=prompt,
-                promptAsset=(
-                    SubmitPromptModelPromptAsset(
-                        UrlAssetInput(_t="UrlAssetInput", url=asset)
-                    )
-                    if asset is not None
-                    else None
+            self.__openapi_service.benchmark_api.benchmark_benchmark_id_prompt_post(
+                benchmark_id=self.id,
+                submit_prompt_model=SubmitPromptModel(
+                    identifier=identifier,
+                    prompt=prompt,
+                    promptAsset=(
+                        SubmitPromptModelPromptAsset(
+                            UrlAssetInput(_t="UrlAssetInput", url=prompt_asset)
+                        )
+                        if prompt_asset is not None
+                        else None
+                    ),
+                    tags=tags,
                 ),
-                tags=tags,
-            ),
-        )
+            )
 
     def create_leaderboard(
         self,
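
The add_prompt changes above make identifier optional, rename asset to prompt_asset, and fall back to the prompt text as the identifier, which is why duplicate prompts are now rejected. A hedged usage sketch; benchmark stands in for a RapidataBenchmark instance obtained elsewhere:

# 2.37.0: benchmark.add_prompt(identifier="cat-1", prompt="a cat", asset="https://...")
# 2.39.0: identifier may be omitted and defaults to the prompt text;
# the asset keyword is now prompt_asset and must be an http(s) URL.
benchmark.add_prompt(
    prompt="a photorealistic cat on a skateboard",
    prompt_asset="https://example.com/reference.png",  # illustrative URL
    tags=["animals"],  # used to filter leaderboard results, never shown to users
)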
@@ -284,126 +320,190 @@
             filters: The filters that should be applied to the leaderboard. Will determine who can solve answer in the leaderboard. (default: [])
             settings: The settings that should be applied to the leaderboard. Will determine the behavior of the tasks on the leaderboard. (default: [])
         """
-        if not isinstance(min_responses_per_matchup, int):
-            raise ValueError("Min responses per matchup must be an integer")
-
-        if min_responses_per_matchup < 3:
-            raise ValueError("Min responses per matchup must be at least 3")
-
-        leaderboard_result = self.__openapi_service.leaderboard_api.leaderboard_post(
-            create_leaderboard_model=CreateLeaderboardModel(
-                benchmarkId=self.id,
-                name=name,
-                instruction=instruction,
-                showPrompt=show_prompt,
-                showPromptAsset=show_prompt_asset,
-                isInversed=inverse_ranking,
-                minResponses=min_responses_per_matchup,
-                responseBudget=DetailMapper.get_budget(level_of_detail),
-                validationSetId=validation_set_id,
-                filters=(
-                    [
-                        AndUserFilterModelFiltersInner(filter._to_model())
-                        for filter in filters
-                    ]
-                    if filters
-                    else None
-                ),
-                featureFlags=(
-                    [setting._to_feature_flag() for setting in settings]
-                    if settings
-                    else None
-                ),
+        with tracer.start_as_current_span("create_leaderboard"):
+            if not isinstance(min_responses_per_matchup, int):
+                raise ValueError("Min responses per matchup must be an integer")
+
+            if min_responses_per_matchup < 3:
+                raise ValueError("Min responses per matchup must be at least 3")
+
+            logger.info(
+                "Creating leaderboard %s with instruction %s, show_prompt %s, show_prompt_asset %s, inverse_ranking %s, level_of_detail %s, min_responses_per_matchup %s, validation_set_id %s, filters %s, settings %s",
+                name,
+                instruction,
+                show_prompt,
+                show_prompt_asset,
+                inverse_ranking,
+                level_of_detail,
+                min_responses_per_matchup,
+                validation_set_id,
+                filters,
+                settings,
             )
-        )
 
-        assert (
-            leaderboard_result.benchmark_id == self.id
-        ), "The leaderboard was not created for the correct benchmark."
-
-        return RapidataLeaderboard(
-            name,
-            instruction,
-            show_prompt,
-            show_prompt_asset,
-            inverse_ranking,
-            leaderboard_result.response_budget,
-            min_responses_per_matchup,
-            leaderboard_result.id,
-            self.__openapi_service,
-        )
+            leaderboard_result = (
+                self.__openapi_service.leaderboard_api.leaderboard_post(
+                    create_leaderboard_model=CreateLeaderboardModel(
+                        benchmarkId=self.id,
+                        name=name,
+                        instruction=instruction,
+                        showPrompt=show_prompt,
+                        showPromptAsset=show_prompt_asset,
+                        isInversed=inverse_ranking,
+                        minResponses=min_responses_per_matchup,
+                        responseBudget=DetailMapper.get_budget(level_of_detail),
+                        validationSetId=validation_set_id,
+                        filters=(
+                            [
+                                AndUserFilterModelFiltersInner(filter._to_model())
+                                for filter in filters
+                            ]
+                            if filters
+                            else None
+                        ),
+                        featureFlags=(
+                            [setting._to_feature_flag() for setting in settings]
+                            if settings
+                            else None
+                        ),
+                    )
+                )
+            )
+
+            assert (
+                leaderboard_result.benchmark_id == self.id
+            ), "The leaderboard was not created for the correct benchmark."
+
+            logger.info("Leaderboard created with id %s", leaderboard_result.id)
+
+            return RapidataLeaderboard(
+                name,
+                instruction,
+                show_prompt,
+                show_prompt_asset,
+                inverse_ranking,
+                leaderboard_result.response_budget,
+                min_responses_per_matchup,
+                self.id,
+                leaderboard_result.id,
+                self.__openapi_service,
+            )
 
     def evaluate_model(
-        self, name: str, media: list[str], identifiers: list[str]
+        self,
+        name: str,
+        media: list[str],
+        identifiers: list[str] | None = None,
+        prompts: list[str] | None = None,
     ) -> None:
         """
         Evaluates a model on the benchmark across all leaderboards.
 
+        prompts or identifiers must be provided to match the media.
+
         Args:
             name: The name of the model.
             media: The generated images/videos that will be used to evaluate the model.
-            identifiers: The identifiers that correspond to the media. The order of the identifiers must match the order of the media.
+            identifiers: The identifiers that correspond to the media. The order of the identifiers must match the order of the media.\n
                 The identifiers that are used must be registered for the benchmark. To see the registered identifiers, use the identifiers property.
+            prompts: The prompts that correspond to the media. The order of the prompts must match the order of the media.
         """
-        if not media:
-            raise ValueError("Media must be a non-empty list of strings")
+        with tracer.start_as_current_span("evaluate_model"):
+            if not media:
+                raise ValueError("Media must be a non-empty list of strings")
 
-        if len(media) != len(identifiers):
-            raise ValueError("Media and identifiers must have the same length")
+            if not identifiers and not prompts:
+                raise ValueError("Identifiers or prompts must be provided.")
 
-        if not all(identifier in self.identifiers for identifier in identifiers):
-            raise ValueError(
-                "All identifiers must be in the registered identifiers list. To see the registered identifiers, use the identifiers property.\
-                \nTo see the prompts that are associated with the identifiers, use the prompts property."
-            )
+            if identifiers and prompts:
+                raise ValueError(
+                    "Identifiers and prompts cannot be provided at the same time. Use one or the other."
+                )
 
-        # happens before the creation of the participant to ensure all media paths are valid
-        assets: list[MediaAsset] = []
-        for media_path in media:
-            assets.append(MediaAsset(media_path))
+            if not identifiers:
+                assert prompts is not None
+                identifiers = prompts
 
-        participant_result = self.__openapi_service.benchmark_api.benchmark_benchmark_id_participants_post(
-            benchmark_id=self.id,
-            create_benchmark_participant_model=CreateBenchmarkParticipantModel(
-                name=name,
-            ),
-        )
+            if len(media) != len(identifiers):
+                raise ValueError(
+                    "Media and identifiers/prompts must have the same length"
+                )
 
-        logger.info(f"Participant created: {participant_result.participant_id}")
+            if not all(identifier in self.identifiers for identifier in identifiers):
+                raise ValueError(
+                    "All identifiers/prompts must be in the registered identifiers/prompts list. To see the registered identifiers/prompts, use the identifiers/prompts property."
+                )
 
-        participant = BenchmarkParticipant(
-            name, participant_result.participant_id, self.__openapi_service
-        )
+            # happens before the creation of the participant to ensure all media paths are valid
+            assets: list[MediaAsset] = []
+            for media_path in media:
+                assets.append(MediaAsset(media_path))
 
-        successful_uploads, failed_uploads = participant.upload_media(
-            assets,
-            identifiers,
-        )
+            participant_result = self.__openapi_service.benchmark_api.benchmark_benchmark_id_participants_post(
+                benchmark_id=self.id,
+                create_benchmark_participant_model=CreateBenchmarkParticipantModel(
+                    name=name,
+                ),
+            )
 
-        total_uploads = len(assets)
-        success_rate = (
-            (len(successful_uploads) / total_uploads * 100) if total_uploads > 0 else 0
-        )
-        logger.info(
-            f"Upload complete: {len(successful_uploads)} successful, {len(failed_uploads)} failed ({success_rate:.1f}% success rate)"
-        )
+            logger.info(f"Participant created: {participant_result.participant_id}")
 
-        if failed_uploads:
-            logger.error(
-                f"Failed uploads for media: {[asset.path for asset in failed_uploads]}"
-            )
-            logger.warning(
-                "Some uploads failed. The model evaluation may be incomplete."
+            participant = BenchmarkParticipant(
+                name, participant_result.participant_id, self.__openapi_service
             )
 
-        if len(successful_uploads) == 0:
-            raise RuntimeError(
-                "No uploads were successful. The model evaluation will not be completed."
+            with tracer.start_as_current_span("upload_media_for_participant"):
+                logger.info(
+                    f"Uploading {len(assets)} media assets to participant {participant.id}"
+                )
+
+                successful_uploads, failed_uploads = participant.upload_media(
+                    assets,
+                    identifiers,
+                )
+
+                total_uploads = len(assets)
+                success_rate = (
+                    (len(successful_uploads) / total_uploads * 100)
+                    if total_uploads > 0
+                    else 0
+                )
+                logger.info(
+                    f"Upload complete: {len(successful_uploads)} successful, {len(failed_uploads)} failed ({success_rate:.1f}% success rate)"
+                )
+
+                if failed_uploads:
+                    logger.error(
+                        f"Failed uploads for media: {[asset.path for asset in failed_uploads]}"
+                    )
+                    logger.warning(
+                        "Some uploads failed. The model evaluation may be incomplete."
+                    )
+
+                if len(successful_uploads) == 0:
+                    raise RuntimeError(
+                        "No uploads were successful. The model evaluation will not be completed."
+                    )
+
+            self.__openapi_service.participant_api.participants_participant_id_submit_post(
+                participant_id=participant_result.participant_id
             )
 
-        self.__openapi_service.participant_api.participants_participant_id_submit_post(
-            participant_id=participant_result.participant_id
-        )
+    def view(self) -> None:
+        """
+        Views the benchmark.
+        """
+        logger.info("Opening benchmark page in browser...")
+        could_open_browser = webbrowser.open(self.__benchmark_page)
+        if not could_open_browser:
+            encoded_url = urllib.parse.quote(
+                self.__benchmark_page, safe="%/:=&?~#+!$,;'@()*[]"
+            )
+            managed_print(
+                Fore.RED
+                + f"Please open this URL in your browser: '{encoded_url}'"
+                + Fore.RESET
+            )
 
     def __str__(self) -> str:
         return f"RapidataBenchmark(name={self.name}, id={self.id})"