rapidata 2.33.1-py3-none-any.whl → 2.34.0-py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Files changed (30)
  1. rapidata/__init__.py +2 -2
  2. rapidata/api_client/__init__.py +4 -0
  3. rapidata/api_client/api/__init__.py +1 -0
  4. rapidata/api_client/api/benchmark_api.py +6 -5
  5. rapidata/api_client/api/leaderboard_api.py +29 -296
  6. rapidata/api_client/api/prompt_api.py +320 -0
  7. rapidata/api_client/api/validation_set_api.py +3 -3
  8. rapidata/api_client/models/__init__.py +3 -0
  9. rapidata/api_client/models/conditional_validation_selection.py +4 -2
  10. rapidata/api_client/models/create_leaderboard_model.py +9 -2
  11. rapidata/api_client/models/get_standing_by_id_result.py +4 -15
  12. rapidata/api_client/models/prompt_by_benchmark_result.py +3 -1
  13. rapidata/api_client/models/standing_by_leaderboard.py +1 -1
  14. rapidata/api_client/models/standings_by_leaderboard_result.py +95 -0
  15. rapidata/api_client/models/tags_by_benchmark_result.py +87 -0
  16. rapidata/api_client/models/update_prompt_tags_model.py +87 -0
  17. rapidata/api_client_README.md +5 -2
  18. rapidata/rapidata_client/__init__.py +1 -1
  19. rapidata/rapidata_client/benchmark/leaderboard/rapidata_leaderboard.py +12 -9
  20. rapidata/rapidata_client/benchmark/participant/__init__.py +0 -0
  21. rapidata/rapidata_client/benchmark/participant/_participant.py +102 -0
  22. rapidata/rapidata_client/benchmark/rapidata_benchmark.py +50 -27
  23. rapidata/rapidata_client/benchmark/rapidata_benchmark_manager.py +14 -8
  24. rapidata/rapidata_client/selection/__init__.py +1 -1
  25. rapidata/rapidata_client/selection/effort_selection.py +9 -2
  26. rapidata/service/openapi_service.py +5 -0
  27. {rapidata-2.33.1.dist-info → rapidata-2.34.0.dist-info}/METADATA +1 -1
  28. {rapidata-2.33.1.dist-info → rapidata-2.34.0.dist-info}/RECORD +30 -24
  29. {rapidata-2.33.1.dist-info → rapidata-2.34.0.dist-info}/LICENSE +0 -0
  30. {rapidata-2.33.1.dist-info → rapidata-2.34.0.dist-info}/WHEEL +0 -0
rapidata/api_client/models/tags_by_benchmark_result.py
@@ -0,0 +1,87 @@
+ # coding: utf-8
+
+ """
+     Rapidata.Dataset
+
+     No description provided (generated by Openapi Generator https://github.com/openapitools/openapi-generator)
+
+     The version of the OpenAPI document: v1
+     Generated by OpenAPI Generator (https://openapi-generator.tech)
+
+     Do not edit the class manually.
+ """ # noqa: E501
+
+
+ from __future__ import annotations
+ import pprint
+ import re # noqa: F401
+ import json
+
+ from pydantic import BaseModel, ConfigDict, StrictStr
+ from typing import Any, ClassVar, Dict, List
+ from typing import Optional, Set
+ from typing_extensions import Self
+
+ class TagsByBenchmarkResult(BaseModel):
+     """
+     TagsByBenchmarkResult
+     """ # noqa: E501
+     tags: List[StrictStr]
+     __properties: ClassVar[List[str]] = ["tags"]
+
+     model_config = ConfigDict(
+         populate_by_name=True,
+         validate_assignment=True,
+         protected_namespaces=(),
+     )
+
+
+     def to_str(self) -> str:
+         """Returns the string representation of the model using alias"""
+         return pprint.pformat(self.model_dump(by_alias=True))
+
+     def to_json(self) -> str:
+         """Returns the JSON representation of the model using alias"""
+         # TODO: pydantic v2: use .model_dump_json(by_alias=True, exclude_unset=True) instead
+         return json.dumps(self.to_dict())
+
+     @classmethod
+     def from_json(cls, json_str: str) -> Optional[Self]:
+         """Create an instance of TagsByBenchmarkResult from a JSON string"""
+         return cls.from_dict(json.loads(json_str))
+
+     def to_dict(self) -> Dict[str, Any]:
+         """Return the dictionary representation of the model using alias.
+
+         This has the following differences from calling pydantic's
+         `self.model_dump(by_alias=True)`:
+
+         * `None` is only added to the output dict for nullable fields that
+           were set at model initialization. Other fields with value `None`
+           are ignored.
+         """
+         excluded_fields: Set[str] = set([
+         ])
+
+         _dict = self.model_dump(
+             by_alias=True,
+             exclude=excluded_fields,
+             exclude_none=True,
+         )
+         return _dict
+
+     @classmethod
+     def from_dict(cls, obj: Optional[Dict[str, Any]]) -> Optional[Self]:
+         """Create an instance of TagsByBenchmarkResult from a dict"""
+         if obj is None:
+             return None
+
+         if not isinstance(obj, dict):
+             return cls.model_validate(obj)
+
+         _obj = cls.model_validate({
+             "tags": obj.get("tags")
+         })
+         return _obj
+
+
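The new `TagsByBenchmarkResult` model behaves like the other generated pydantic models: it round-trips between JSON, dicts, and Python objects. A minimal sketch (the tag values are invented for illustration):

```python
from rapidata.api_client.models.tags_by_benchmark_result import TagsByBenchmarkResult

# Round-trip a payload through the generated model.
result = TagsByBenchmarkResult.from_json('{"tags": ["animals", "landscapes"]}')
assert result is not None
print(result.tags)       # ['animals', 'landscapes']
print(result.to_json())  # {"tags": ["animals", "landscapes"]}

# from_dict mirrors from_json for already-parsed payloads.
same = TagsByBenchmarkResult.from_dict({"tags": ["animals", "landscapes"]})
```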
rapidata/api_client/models/update_prompt_tags_model.py
@@ -0,0 +1,87 @@
+ # coding: utf-8
+
+ """
+     Rapidata.Dataset
+
+     No description provided (generated by Openapi Generator https://github.com/openapitools/openapi-generator)
+
+     The version of the OpenAPI document: v1
+     Generated by OpenAPI Generator (https://openapi-generator.tech)
+
+     Do not edit the class manually.
+ """ # noqa: E501
+
+
+ from __future__ import annotations
+ import pprint
+ import re # noqa: F401
+ import json
+
+ from pydantic import BaseModel, ConfigDict, Field, StrictStr
+ from typing import Any, ClassVar, Dict, List
+ from typing import Optional, Set
+ from typing_extensions import Self
+
+ class UpdatePromptTagsModel(BaseModel):
+     """
+     The model for updating prompt tags.
+     """ # noqa: E501
+     tags: List[StrictStr] = Field(description="The list of tags to be associated with the prompt.")
+     __properties: ClassVar[List[str]] = ["tags"]
+
+     model_config = ConfigDict(
+         populate_by_name=True,
+         validate_assignment=True,
+         protected_namespaces=(),
+     )
+
+
+     def to_str(self) -> str:
+         """Returns the string representation of the model using alias"""
+         return pprint.pformat(self.model_dump(by_alias=True))
+
+     def to_json(self) -> str:
+         """Returns the JSON representation of the model using alias"""
+         # TODO: pydantic v2: use .model_dump_json(by_alias=True, exclude_unset=True) instead
+         return json.dumps(self.to_dict())
+
+     @classmethod
+     def from_json(cls, json_str: str) -> Optional[Self]:
+         """Create an instance of UpdatePromptTagsModel from a JSON string"""
+         return cls.from_dict(json.loads(json_str))
+
+     def to_dict(self) -> Dict[str, Any]:
+         """Return the dictionary representation of the model using alias.
+
+         This has the following differences from calling pydantic's
+         `self.model_dump(by_alias=True)`:
+
+         * `None` is only added to the output dict for nullable fields that
+           were set at model initialization. Other fields with value `None`
+           are ignored.
+         """
+         excluded_fields: Set[str] = set([
+         ])
+
+         _dict = self.model_dump(
+             by_alias=True,
+             exclude=excluded_fields,
+             exclude_none=True,
+         )
+         return _dict
+
+     @classmethod
+     def from_dict(cls, obj: Optional[Dict[str, Any]]) -> Optional[Self]:
+         """Create an instance of UpdatePromptTagsModel from a dict"""
+         if obj is None:
+             return None
+
+         if not isinstance(obj, dict):
+             return cls.model_validate(obj)
+
+         _obj = cls.model_validate({
+             "tags": obj.get("tags")
+         })
+         return _obj
+
+
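Because the `tags` field is typed `List[StrictStr]`, pydantic rejects non-string entries at construction time. A small sketch of the validation behavior (values invented):

```python
from pydantic import ValidationError

from rapidata.api_client.models.update_prompt_tags_model import UpdatePromptTagsModel

model = UpdatePromptTagsModel(tags=["photorealistic", "v2"])
print(model.to_json())  # {"tags": ["photorealistic", "v2"]}

try:
    UpdatePromptTagsModel(tags=["ok", 42])  # StrictStr rejects the integer
except ValidationError as exc:
    print(exc.error_count(), "validation error")  # 1 validation error
```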
rapidata/api_client_README.md
@@ -135,7 +135,6 @@ Class | Method | HTTP request | Description
  *LeaderboardApi* | [**leaderboard_leaderboard_id_participants_post**](rapidata/api_client/docs/LeaderboardApi.md#leaderboard_leaderboard_id_participants_post) | **POST** /leaderboard/{leaderboardId}/participants | Creates a participant in a leaderboard.
  *LeaderboardApi* | [**leaderboard_leaderboard_id_prompts_get**](rapidata/api_client/docs/LeaderboardApi.md#leaderboard_leaderboard_id_prompts_get) | **GET** /leaderboard/{leaderboardId}/prompts | returns the paged prompts of a leaderboard by its ID.
  *LeaderboardApi* | [**leaderboard_leaderboard_id_prompts_post**](rapidata/api_client/docs/LeaderboardApi.md#leaderboard_leaderboard_id_prompts_post) | **POST** /leaderboard/{leaderboardId}/prompts | adds a new prompt to a leaderboard.
- *LeaderboardApi* | [**leaderboard_leaderboard_id_refresh_post**](rapidata/api_client/docs/LeaderboardApi.md#leaderboard_leaderboard_id_refresh_post) | **POST** /leaderboard/{leaderboardId}/refresh | This will force an update to all standings of a leaderboard. this could happen if the recorded matches and scores are out of sync
  *LeaderboardApi* | [**leaderboard_leaderboard_id_runs_get**](rapidata/api_client/docs/LeaderboardApi.md#leaderboard_leaderboard_id_runs_get) | **GET** /leaderboard/{leaderboardId}/runs | Gets the runs related to a leaderboard
  *LeaderboardApi* | [**leaderboard_leaderboard_id_standings_get**](rapidata/api_client/docs/LeaderboardApi.md#leaderboard_leaderboard_id_standings_get) | **GET** /leaderboard/{leaderboardId}/standings | queries all the participants connected to leaderboard by its ID.
  *LeaderboardApi* | [**leaderboard_post**](rapidata/api_client/docs/LeaderboardApi.md#leaderboard_post) | **POST** /leaderboard | Creates a new leaderboard with the specified name and criteria.
@@ -178,6 +177,7 @@ Class | Method | HTTP request | Description
  *PipelineApi* | [**pipeline_pipeline_id_get**](rapidata/api_client/docs/PipelineApi.md#pipeline_pipeline_id_get) | **GET** /pipeline/{pipelineId} | Gets a pipeline by its id.
  *PipelineApi* | [**pipeline_pipeline_id_preliminary_download_post**](rapidata/api_client/docs/PipelineApi.md#pipeline_pipeline_id_preliminary_download_post) | **POST** /pipeline/{pipelineId}/preliminary-download | Initiates a preliminary download of the pipeline.
  *PipelineApi* | [**pipeline_preliminary_download_preliminary_download_id_get**](rapidata/api_client/docs/PipelineApi.md#pipeline_preliminary_download_preliminary_download_id_get) | **GET** /pipeline/preliminary-download/{preliminaryDownloadId} | Gets the preliminary download.
+ *PromptApi* | [**benchmark_prompt_prompt_id_tags_put**](rapidata/api_client/docs/PromptApi.md#benchmark_prompt_prompt_id_tags_put) | **PUT** /benchmark-prompt/{promptId}/tags | Updates the tags associated with a prompt.
  *RapidataIdentityAPIApi* | [**root_get**](rapidata/api_client/docs/RapidataIdentityAPIApi.md#root_get) | **GET** / |
  *SimpleWorkflowApi* | [**workflow_simple_workflow_id_results_get**](rapidata/api_client/docs/SimpleWorkflowApi.md#workflow_simple_workflow_id_results_get) | **GET** /workflow/simple/{workflowId}/results | Get the result overview for a simple workflow.
  *UserInfoApi* | [**connect_userinfo_get**](rapidata/api_client/docs/UserInfoApi.md#connect_userinfo_get) | **GET** /connect/userinfo | Retrieves information about the authenticated user.
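The new `PromptApi` exposes a single operation in this release. A hedged sketch of calling it through the generated client; the keyword argument names follow the usual openapi-generator conventions but are assumptions here, as are the host and prompt id, and authentication is omitted:

```python
from rapidata.api_client import ApiClient, Configuration
from rapidata.api_client.api.prompt_api import PromptApi
from rapidata.api_client.models.update_prompt_tags_model import UpdatePromptTagsModel

# Host and prompt id are placeholders.
client = ApiClient(Configuration(host="https://api.example.com"))
prompt_api = PromptApi(client)
prompt_api.benchmark_prompt_prompt_id_tags_put(
    prompt_id="<prompt-id>",
    update_prompt_tags_model=UpdatePromptTagsModel(tags=["animals", "v2"]),
)
```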
@@ -197,7 +197,7 @@ Class | Method | HTTP request | Description
  *ValidationSetApi* | [**validation_set_validation_set_id_rapid_texts_post**](rapidata/api_client/docs/ValidationSetApi.md#validation_set_validation_set_id_rapid_texts_post) | **POST** /validation-set/{validationSetId}/rapid/texts | Adds a new validation rapid to the specified validation set using text sources to create the assets.
  *ValidationSetApi* | [**validation_set_validation_set_id_rapids_get**](rapidata/api_client/docs/ValidationSetApi.md#validation_set_validation_set_id_rapids_get) | **GET** /validation-set/{validationSetId}/rapids | Queries the validation rapids for a specific validation set.
  *ValidationSetApi* | [**validation_set_validation_set_id_shouldalert_patch**](rapidata/api_client/docs/ValidationSetApi.md#validation_set_validation_set_id_shouldalert_patch) | **PATCH** /validation-set/{validationSetId}/shouldalert | Updates the dimensions of all rapids within a validation set.
- *ValidationSetApi* | [**validation_set_validation_set_id_shouldalert_put**](rapidata/api_client/docs/ValidationSetApi.md#validation_set_validation_set_id_shouldalert_put) | **PUT** /validation-set/{validationSetId}/shouldalert | Updates the dimensions of all rapids within a validation set.
+ *ValidationSetApi* | [**validation_set_validation_set_id_shouldalert_put**](rapidata/api_client/docs/ValidationSetApi.md#validation_set_validation_set_id_shouldalert_put) | **PUT** /validation-set/{validationSetId}/shouldalert | Updates the shouldAlert property of all rapids within a validation set.
  *ValidationSetApi* | [**validation_set_zip_compare_post**](rapidata/api_client/docs/ValidationSetApi.md#validation_set_zip_compare_post) | **POST** /validation-set/zip/compare | Imports a compare validation set from a zip file.
  *ValidationSetApi* | [**validation_set_zip_post**](rapidata/api_client/docs/ValidationSetApi.md#validation_set_zip_post) | **POST** /validation-set/zip | Imports a validation set from a zip file.
  *ValidationSetApi* | [**validation_sets_available_get**](rapidata/api_client/docs/ValidationSetApi.md#validation_sets_available_get) | **GET** /validation-sets/available | Gets the available validation sets for the current user.
@@ -529,6 +529,7 @@ Class | Method | HTTP request | Description
  - [StandingByLeaderboard](rapidata/api_client/docs/StandingByLeaderboard.md)
  - [StandingByLeaderboardPagedResult](rapidata/api_client/docs/StandingByLeaderboardPagedResult.md)
  - [StandingStatus](rapidata/api_client/docs/StandingStatus.md)
+ - [StandingsByLeaderboardResult](rapidata/api_client/docs/StandingsByLeaderboardResult.md)
  - [StaticSelection](rapidata/api_client/docs/StaticSelection.md)
  - [StickyState](rapidata/api_client/docs/StickyState.md)
  - [StreamFileWrapper](rapidata/api_client/docs/StreamFileWrapper.md)
@@ -540,6 +541,7 @@ Class | Method | HTTP request | Description
  - [SubmitParticipantResult](rapidata/api_client/docs/SubmitParticipantResult.md)
  - [SubmitPromptModel](rapidata/api_client/docs/SubmitPromptModel.md)
  - [SubmitPromptModelPromptAsset](rapidata/api_client/docs/SubmitPromptModelPromptAsset.md)
+ - [TagsByBenchmarkResult](rapidata/api_client/docs/TagsByBenchmarkResult.md)
  - [TextAsset](rapidata/api_client/docs/TextAsset.md)
  - [TextAssetInput](rapidata/api_client/docs/TextAssetInput.md)
  - [TextAssetModel](rapidata/api_client/docs/TextAssetModel.md)
@@ -563,6 +565,7 @@ Class | Method | HTTP request | Description
  - [UpdateLeaderboardNameModel](rapidata/api_client/docs/UpdateLeaderboardNameModel.md)
  - [UpdateOrderNameModel](rapidata/api_client/docs/UpdateOrderNameModel.md)
  - [UpdateParticipantNameModel](rapidata/api_client/docs/UpdateParticipantNameModel.md)
+ - [UpdatePromptTagsModel](rapidata/api_client/docs/UpdatePromptTagsModel.md)
  - [UpdateShouldAlertModel](rapidata/api_client/docs/UpdateShouldAlertModel.md)
  - [UpdateValidationRapidModel](rapidata/api_client/docs/UpdateValidationRapidModel.md)
  - [UpdateValidationRapidModelTruth](rapidata/api_client/docs/UpdateValidationRapidModelTruth.md)
rapidata/rapidata_client/__init__.py
@@ -7,7 +7,7 @@ from .selection import (
      CappedSelection,
      ShufflingSelection,
      RetrievalMode,
-     EffortEstimationSelection,
+     EffortSelection,
  )
  from .datapoints import Datapoint
  from .datapoints.metadata import (
rapidata/rapidata_client/benchmark/leaderboard/rapidata_leaderboard.py
@@ -1,8 +1,5 @@
  import pandas as pd
-
- from rapidata.api_client.models.query_model import QueryModel
- from rapidata.api_client.models.page_info import PageInfo
- from rapidata.api_client.models.sort_criterion import SortCriterion
+ from typing import Optional

  from rapidata.service.openapi_service import OpenAPIService

@@ -89,16 +86,22 @@ class RapidataLeaderboard:
          """
          return self.__name

-     def get_standings(self) -> pd.DataFrame:
+     def get_standings(self, tags: Optional[list[str]] = None) -> pd.DataFrame:
          """
          Returns the standings of the leaderboard.
+
+         Args:
+             tags: The matchups with these tags should be used to create the standings.
+                 If tags are None, all matchups will be considered.
+                 If tags are empty, no matchups will be considered.
+
+         Returns:
+             A pandas DataFrame containing the standings of the leaderboard.
          """
+
          participants = self.__openapi_service.leaderboard_api.leaderboard_leaderboard_id_standings_get(
              leaderboard_id=self.id,
-             request=QueryModel(
-                 page=PageInfo(index=1, size=1000),
-                 sortCriteria=[SortCriterion(direction="Desc", propertyName="Score")]
-             )
+             tags=tags
          )

          standings = []
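For SDK users, the visible change is that `get_standings` now filters server-side by matchup tags instead of paging and sorting client-side. A sketch, assuming `leaderboard` is a `RapidataLeaderboard` obtained from an existing benchmark (tag values invented):

```python
# Assumes `leaderboard` is a RapidataLeaderboard instance.
all_standings = leaderboard.get_standings()                 # None: all matchups
animals_only = leaderboard.get_standings(tags=["animals"])  # only matchups tagged "animals"
nothing = leaderboard.get_standings(tags=[])                # empty list: no matchups

print(animals_only.head())
```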
rapidata/rapidata_client/benchmark/participant/_participant.py
@@ -0,0 +1,102 @@
+ from concurrent.futures import ThreadPoolExecutor, as_completed
+ from tqdm import tqdm
+
+ from rapidata.rapidata_client.datapoints.assets import MediaAsset
+ from rapidata.rapidata_client.logging import logger
+ from rapidata.rapidata_client.logging.output_manager import RapidataOutputManager
+ from rapidata.api_client.models.create_sample_model import CreateSampleModel
+ from rapidata.service.openapi_service import OpenAPIService
+
+
+ class BenchmarkParticipant:
+     def __init__(self, name: str, id: str, openapi_service: OpenAPIService):
+         self.name = name
+         self.id = id
+         self.__openapi_service = openapi_service
+
+     def _process_single_sample_upload(
+         self,
+         asset: MediaAsset,
+         identifier: str,
+     ) -> tuple[MediaAsset | None, MediaAsset | None]:
+         """
+         Process single sample upload with retry logic and error tracking.
+
+         Args:
+             asset: MediaAsset to upload
+             identifier: Identifier for the sample
+
+         Returns:
+             tuple[MediaAsset | None, MediaAsset | None]: (successful_asset, failed_asset)
+         """
+         if asset.is_local():
+             files = [asset.to_file()]
+             urls = []
+         else:
+             files = []
+             urls = [asset.path]
+
+         last_exception = None
+         try:
+             self.__openapi_service.participant_api.participant_participant_id_sample_post(
+                 participant_id=self.id,
+                 model=CreateSampleModel(
+                     identifier=identifier
+                 ),
+                 files=files,
+                 urls=urls
+             )
+
+             return asset, None
+
+         except Exception as e:
+             last_exception = e
+
+         logger.error(f"Upload failed for {identifier}. Error: {str(last_exception)}")
+         return None, asset
+
+     def upload_media(
+         self,
+         assets: list[MediaAsset],
+         identifiers: list[str],
+         max_workers: int = 10,
+     ) -> tuple[list[MediaAsset], list[MediaAsset]]:
+         """
+         Upload samples concurrently with proper error handling and progress tracking.
+
+         Args:
+             assets: List of MediaAsset objects to upload
+             identifiers: List of identifiers matching the assets
+             max_workers: Maximum number of concurrent upload workers
+
+         Returns:
+             tuple[list[str], list[str]]: Lists of successful and failed identifiers
+         """
+         successful_uploads: list[MediaAsset] = []
+         failed_uploads: list[MediaAsset] = []
+         total_uploads = len(assets)
+
+         with ThreadPoolExecutor(max_workers=max_workers) as executor:
+             futures = [
+                 executor.submit(
+                     self._process_single_sample_upload,
+                     asset,
+                     identifier,
+                 )
+                 for asset, identifier in zip(assets, identifiers)
+             ]
+
+             with tqdm(total=total_uploads, desc="Uploading media", disable=RapidataOutputManager.silent_mode) as pbar:
+                 for future in as_completed(futures):
+                     try:
+                         successful_id, failed_id = future.result()
+                         if successful_id:
+                             successful_uploads.append(successful_id)
+                         if failed_id:
+                             failed_uploads.append(failed_id)
+                     except Exception as e:
+                         logger.error(f"Future execution failed: {str(e)}")
+
+                     pbar.update(1)
+
+         return successful_uploads, failed_uploads
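`upload_media` is a standard ThreadPoolExecutor/as_completed fan-out with a tqdm progress bar. A self-contained sketch of the same pattern, with a dummy function standing in for the sample-upload API call:

```python
from concurrent.futures import ThreadPoolExecutor, as_completed
from tqdm import tqdm

def fake_upload(identifier: str) -> str:
    # Stand-in for the participant sample-upload call; always "succeeds".
    return identifier

identifiers = [f"sample-{i}" for i in range(25)]
succeeded: list[str] = []
failed: list[str] = []

with ThreadPoolExecutor(max_workers=10) as executor:
    # Map each future back to its identifier so failures can be attributed.
    futures = {executor.submit(fake_upload, ident): ident for ident in identifiers}
    with tqdm(total=len(futures), desc="Uploading media") as pbar:
        for future in as_completed(futures):
            try:
                succeeded.append(future.result())
            except Exception:
                failed.append(futures[future])
            pbar.update(1)

print(f"{len(succeeded)} succeeded, {len(failed)} failed")
```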
rapidata/rapidata_client/benchmark/rapidata_benchmark.py
@@ -1,4 +1,5 @@
  import re
+ from typing import Optional
  from rapidata.api_client.models.root_filter import RootFilter
  from rapidata.api_client.models.filter import Filter
  from rapidata.api_client.models.query_model import QueryModel
@@ -11,14 +12,13 @@ from rapidata.api_client.models.url_asset_input import UrlAssetInput
  from rapidata.api_client.models.file_asset_model import FileAssetModel
  from rapidata.api_client.models.source_url_metadata_model import SourceUrlMetadataModel

+
+ from rapidata.rapidata_client.benchmark.participant._participant import BenchmarkParticipant
  from rapidata.rapidata_client.logging import logger
  from rapidata.service.openapi_service import OpenAPIService

  from rapidata.rapidata_client.benchmark.leaderboard.rapidata_leaderboard import RapidataLeaderboard
- from rapidata.rapidata_client.datapoints.metadata import PromptIdentifierMetadata
  from rapidata.rapidata_client.datapoints.assets import MediaAsset
- from rapidata.rapidata_client.order._rapidata_dataset import RapidataDataset
- from rapidata.rapidata_client.datapoints.datapoint import Datapoint

  class RapidataBenchmark:
      """
@@ -39,7 +39,8 @@ class RapidataBenchmark:
          self.__prompt_assets: list[str | None] = []
          self.__leaderboards: list[RapidataLeaderboard] = []
          self.__identifiers: list[str] = []
-
+         self.__tags: list[list[str]] = []
+
      def __instantiate_prompts(self) -> None:
          current_page = 1
          total_pages = None
@@ -70,7 +71,8 @@ class RapidataBenchmark:
                  source_url = prompt.prompt_asset.actual_instance.metadata["sourceUrl"].actual_instance
                  assert isinstance(source_url, SourceUrlMetadataModel)
                  self.__prompt_assets.append(source_url.url)
-
+
+             self.__tags.append(prompt.tags)
              if current_page >= total_pages:
                  break

@@ -104,6 +106,15 @@ class RapidataBenchmark:
          return self.__prompt_assets

      @property
+     def tags(self) -> list[list[str]]:
+         """
+         Returns the tags that are registered for the benchmark.
+         """
+         if not self.__tags:
+             self.__instantiate_prompts()
+
+         return self.__tags
+
      def leaderboards(self) -> list[RapidataLeaderboard]:
          """
          Returns the leaderboards that are registered for the benchmark.
@@ -152,7 +163,7 @@ class RapidataBenchmark:

          return self.__leaderboards

-     def add_prompt(self, identifier: str, prompt: str | None = None, asset: str | None = None):
+     def add_prompt(self, identifier: str, prompt: str | None = None, asset: str | None = None, tags: Optional[list[str]] = None):
          """
          Adds a prompt to the benchmark.

@@ -160,7 +171,11 @@ class RapidataBenchmark:
              identifier: The identifier of the prompt/asset that will be used to match up the media.
              prompt: The prompt that will be used to evaluate the model.
              asset: The asset that will be used to evaluate the model. Provided as a link to the asset.
+             tags: The tags can be used to filter the leaderboard results. They will NOT be shown to the users.
          """
+         if tags is None:
+             tags = []
+
          if not isinstance(identifier, str):
              raise ValueError("Identifier must be a string.")

@@ -179,8 +194,12 @@ class RapidataBenchmark:
          if asset is not None and not re.match(r'^https?://', asset):
              raise ValueError("Asset must be a link to the asset.")

+         if tags is not None and (not isinstance(tags, list) or not all(isinstance(tag, str) for tag in tags)):
+             raise ValueError("Tags must be a list of strings.")
+
          self.__identifiers.append(identifier)

+         self.__tags.append(tags)
          self.__prompts.append(prompt)
          self.__prompt_assets.append(asset)

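A sketch of the extended `add_prompt` call; the identifier, prompt, and tag values are invented, and `benchmark` is assumed to be an existing `RapidataBenchmark`:

```python
# Assumes `benchmark` is an existing RapidataBenchmark instance.
benchmark.add_prompt(
    identifier="sunset-01",
    prompt="A sunset over a calm ocean",
    tags=["landscapes", "v2"],  # used only for filtering standings, never shown to users
)
print(benchmark.tags)  # one tag list per prompt, e.g. [..., ["landscapes", "v2"]]
```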
@@ -194,7 +213,8 @@ class RapidataBenchmark:
                          _t="UrlAssetInput",
                          url=asset
                      )
-                 ) if asset is not None else None
+                 ) if asset is not None else None,
+                 tags=tags
              )
          )

@@ -250,7 +270,7 @@ class RapidataBenchmark:
                  leaderboard_result.id,
                  self.__openapi_service
              )
-
+
      def evaluate_model(self, name: str, media: list[str], identifiers: list[str]) -> None:
          """
          Evaluates a model on the benchmark across all leaderboards.
@@ -272,11 +292,9 @@ class RapidataBenchmark:
              \nTo see the prompts that are associated with the identifiers, use the prompts property.")

          # happens before the creation of the participant to ensure all media paths are valid
-         assets = []
-         prompts_metadata: list[list[PromptIdentifierMetadata]] = []
-         for media_path, identifier in zip(media, identifiers):
+         assets: list[MediaAsset] = []
+         for media_path in media:
              assets.append(MediaAsset(media_path))
-             prompts_metadata.append([PromptIdentifierMetadata(identifier=identifier)])

          participant_result = self.__openapi_service.benchmark_api.benchmark_benchmark_id_participants_post(
              benchmark_id=self.id,
@@ -285,22 +303,27 @@ class RapidataBenchmark:
              )
          )

-         dataset = RapidataDataset(participant_result.dataset_id, self.__openapi_service)
-
-         try:
-             dataset.add_datapoints([Datapoint(asset=asset, metadata=metadata) for asset, metadata in zip(assets, prompts_metadata)])
-         except Exception as e:
-             logger.warning(f"An error occurred while adding datapoints to the dataset: {e}")
-             upload_progress = self.__openapi_service.dataset_api.dataset_dataset_id_progress_get(
-                 dataset_id=dataset.id
-             )
-             if upload_progress.ready == 0:
-                 raise RuntimeError("None of the media was uploaded successfully. Please check the media paths and try again.")
-
-             logger.warning(f"{upload_progress.failed} datapoints failed to upload. \n{upload_progress.ready} datapoints were uploaded successfully. \nEvaluation will continue with the uploaded datapoints.")
+         logger.info(f"Participant created: {participant_result.participant_id}")

-         self.__openapi_service.benchmark_api.benchmark_benchmark_id_participants_participant_id_submit_post(
-             benchmark_id=self.id,
+         participant = BenchmarkParticipant(name, participant_result.participant_id, self.__openapi_service)
+
+         successful_uploads, failed_uploads = participant.upload_media(
+             assets,
+             identifiers,
+         )
+
+         total_uploads = len(assets)
+         success_rate = (len(successful_uploads) / total_uploads * 100) if total_uploads > 0 else 0
+         logger.info(f"Upload complete: {len(successful_uploads)} successful, {len(failed_uploads)} failed ({success_rate:.1f}% success rate)")
+
+         if failed_uploads:
+             logger.error(f"Failed uploads for media: {[asset.path for asset in failed_uploads]}")
+             logger.warning("Some uploads failed. The model evaluation may be incomplete.")
+
+         if len(successful_uploads) == 0:
+             raise RuntimeError("No uploads were successful. The model evaluation will not be completed.")
+
+         self.__openapi_service.participant_api.participants_participant_id_submit_post(
              participant_id=participant_result.participant_id
          )

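The `evaluate_model` signature is unchanged; only the upload path behind it moved from `RapidataDataset` to `BenchmarkParticipant.upload_media`. A sketch with invented names and paths:

```python
# Assumes `benchmark` is an existing RapidataBenchmark whose prompts use these identifiers.
benchmark.evaluate_model(
    name="my-model-v2",
    media=["out/sunset-01.png", "out/cat-01.png"],
    identifiers=["sunset-01", "cat-01"],  # must match the benchmark's identifiers
)
```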
rapidata/rapidata_client/benchmark/rapidata_benchmark_manager.py
@@ -25,8 +25,9 @@ class RapidataBenchmarkManager:
      def create_new_benchmark(self,
                               name: str,
                               identifiers: list[str],
-                              prompts: Optional[list[str]] = None,
-                              prompt_assets: Optional[list[str]] = None,
+                              prompts: Optional[list[str | None]] = None,
+                              prompt_assets: Optional[list[str | None]] = None,
+                              tags: Optional[list[list[str] | None]] = None,
                               ) -> RapidataBenchmark:
          """
          Creates a new benchmark with the given name, identifiers, prompts, and media assets.
@@ -37,15 +38,16 @@ class RapidataBenchmarkManager:
              name: The name of the benchmark.
              prompts: The prompts that will be registered for the benchmark.
              prompt_assets: The prompt assets that will be registered for the benchmark.
+             tags: The tags that will be associated with the prompts to use for filtering the leaderboard results. They will NOT be shown to the users.
          """
          if not isinstance(name, str):
              raise ValueError("Name must be a string.")

-         if prompts and (not isinstance(prompts, list) or not all(isinstance(prompt, str) for prompt in prompts)):
-             raise ValueError("Prompts must be a list of strings.")
+         if prompts and (not isinstance(prompts, list) or not all(isinstance(prompt, str) or prompt is None for prompt in prompts)):
+             raise ValueError("Prompts must be a list of strings or None.")

-         if prompt_assets and (not isinstance(prompt_assets, list) or not all(isinstance(asset, str) for asset in prompt_assets)):
-             raise ValueError("Media assets must be a list of strings.")
+         if prompt_assets and (not isinstance(prompt_assets, list) or not all(isinstance(asset, str) or asset is None for asset in prompt_assets)):
+             raise ValueError("Media assets must be a list of strings or None.")

          if not isinstance(identifiers, list) or not all(isinstance(identifier, str) for identifier in identifiers):
              raise ValueError("Identifiers must be a list of strings.")
@@ -61,6 +63,9 @@ class RapidataBenchmarkManager:

          if len(set(identifiers)) != len(identifiers):
              raise ValueError("Identifiers must be unique.")
+
+         if tags and len(identifiers) != len(tags):
+             raise ValueError("Identifiers and tags must have the same length.")

          benchmark_result = self.__openapi_service.benchmark_api.benchmark_post(
              create_benchmark_model=CreateBenchmarkModel(
@@ -72,9 +77,10 @@ class RapidataBenchmarkManager:

          prompts_list = prompts if prompts is not None else [None] * len(identifiers)
          media_assets_list = prompt_assets if prompt_assets is not None else [None] * len(identifiers)
+         tags_list = tags if tags is not None else [None] * len(identifiers)

-         for identifier, prompt, asset in zip(identifiers, prompts_list, media_assets_list):
-             benchmark.add_prompt(identifier, prompt, asset)
+         for identifier, prompt, asset, tag in zip(identifiers, prompts_list, media_assets_list, tags_list):
+             benchmark.add_prompt(identifier, prompt, asset, tag)

          return benchmark

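A sketch of the extended `create_new_benchmark` call; the manager instance and all values are invented. When `tags` is provided, it must carry one tag list per identifier:

```python
# Assumes `manager` is an already-constructed RapidataBenchmarkManager.
benchmark = manager.create_new_benchmark(
    name="t2i-eval",
    identifiers=["sunset-01", "cat-01"],
    prompts=["A sunset over a calm ocean", "A cat sitting on a windowsill"],
    tags=[["landscapes"], ["animals"]],  # len(tags) must equal len(identifiers)
)
```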
rapidata/rapidata_client/selection/__init__.py
@@ -8,4 +8,4 @@ from .shuffling_selection import ShufflingSelection
  from .ab_test_selection import AbTestSelection
  from .static_selection import StaticSelection
  from .retrieval_modes import RetrievalMode
- from .effort_selection import EffortEstimationSelection
+ from .effort_selection import EffortSelection
rapidata/rapidata_client/selection/effort_selection.py
@@ -3,9 +3,16 @@ from rapidata.api_client.models.effort_capped_selection import EffortCappedSelec
  from rapidata.rapidata_client.selection.retrieval_modes import RetrievalMode


- class EffortEstimationSelection(RapidataSelection):
-
+ class EffortSelection(RapidataSelection):
+     """
+     With this selection you can define the effort budget you have for a task.
+     As an example, you have a task that takes 10 seconds to complete. The effort budget would be 10.

+     Args:
+         effort_budget (int): The effort budget for the task.
+         retrieval_mode (RetrievalMode): The retrieval mode for the task.
+         max_iterations (int | None): The maximum number of iterations for the task.
+     """
      def __init__(self, effort_budget: int, retrieval_mode: RetrievalMode = RetrievalMode.Shuffled, max_iterations: int | None = None):
          self.effort_budget = effort_budget
          self.retrieval_mode = retrieval_mode
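The rename from `EffortEstimationSelection` to `EffortSelection` is breaking for code importing the old name. Constructing the renamed selection is otherwise unchanged; a sketch:

```python
from rapidata.rapidata_client.selection import EffortSelection, RetrievalMode

# Budget of 10 for a task estimated at roughly 10 seconds of effort.
selection = EffortSelection(
    effort_budget=10,
    retrieval_mode=RetrievalMode.Shuffled,
    max_iterations=5,
)
```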
rapidata/service/openapi_service.py
@@ -10,6 +10,7 @@ from rapidata.api_client.api.rapid_api import RapidApi
  from rapidata.api_client.api.leaderboard_api import LeaderboardApi
  from rapidata.api_client.api.validation_set_api import ValidationSetApi
  from rapidata.api_client.api.workflow_api import WorkflowApi
+ from rapidata.api_client.api.participant_api import ParticipantApi
  from rapidata.api_client.configuration import Configuration
  from rapidata.service.credential_manager import CredentialManager
  from rapidata.rapidata_client.api.rapidata_exception import RapidataApiClient
@@ -117,6 +118,10 @@ class OpenAPIService:
      @property
      def benchmark_api(self) -> BenchmarkApi:
          return BenchmarkApi(self.api_client)
+
+     @property
+     def participant_api(self) -> ParticipantApi:
+         return ParticipantApi(self.api_client)

      def _get_rapidata_package_version(self):
          """
{rapidata-2.33.1.dist-info → rapidata-2.34.0.dist-info}/METADATA
@@ -1,6 +1,6 @@
  Metadata-Version: 2.3
  Name: rapidata
- Version: 2.33.1
+ Version: 2.34.0
  Summary: Rapidata package containing the Rapidata Python Client to interact with the Rapidata Web API in an easy way.
  License: Apache-2.0
  Author: Rapidata AG