rapidata 2.28.4__py3-none-any.whl → 2.29.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of rapidata might be problematic. Click here for more details.
- rapidata/__init__.py +1 -1
- rapidata/api_client/__init__.py +44 -17
- rapidata/api_client/api/__init__.py +1 -0
- rapidata/api_client/api/benchmark_api.py +2766 -0
- rapidata/api_client/api/campaign_api.py +0 -780
- rapidata/api_client/api/coco_api.py +0 -571
- rapidata/api_client/api/customer_rapid_api.py +332 -1
- rapidata/api_client/api/datapoint_api.py +0 -524
- rapidata/api_client/api/dataset_api.py +595 -2276
- rapidata/api_client/api/feedback_api.py +0 -270
- rapidata/api_client/api/identity_api.py +74 -888
- rapidata/api_client/api/leaderboard_api.py +1642 -259
- rapidata/api_client/api/order_api.py +617 -5692
- rapidata/api_client/api/pipeline_api.py +31 -334
- rapidata/api_client/api/validation_set_api.py +469 -3356
- rapidata/api_client/api/workflow_api.py +0 -799
- rapidata/api_client/models/__init__.py +43 -17
- rapidata/api_client/models/add_campaign_model.py +3 -3
- rapidata/api_client/models/add_validation_rapid_model.py +1 -3
- rapidata/api_client/models/add_validation_text_rapid_model.py +1 -3
- rapidata/api_client/models/and_user_filter_model.py +106 -0
- rapidata/api_client/models/and_user_filter_model_filters_inner.py +282 -0
- rapidata/api_client/models/benchmark_query_result.py +94 -0
- rapidata/api_client/models/benchmark_query_result_paged_result.py +105 -0
- rapidata/api_client/models/boost_leaderboard_model.py +89 -0
- rapidata/api_client/models/clone_order_model.py +2 -4
- rapidata/api_client/models/create_benchmark_model.py +87 -0
- rapidata/api_client/models/create_benchmark_participant_model.py +87 -0
- rapidata/api_client/models/create_benchmark_participant_result.py +89 -0
- rapidata/api_client/models/create_benchmark_result.py +87 -0
- rapidata/api_client/models/create_datapoint_result.py +4 -16
- rapidata/api_client/models/create_leaderboard_model.py +18 -2
- rapidata/api_client/models/create_leaderboard_result.py +5 -3
- rapidata/api_client/models/create_order_model.py +3 -3
- rapidata/api_client/models/file_asset_input.py +104 -0
- rapidata/api_client/models/file_asset_input1.py +104 -0
- rapidata/api_client/models/file_asset_input1_file.py +168 -0
- rapidata/api_client/models/file_asset_input2.py +104 -0
- rapidata/api_client/models/file_asset_input_file.py +182 -0
- rapidata/api_client/models/form_file_wrapper.py +120 -0
- rapidata/api_client/models/get_benchmark_by_id_query.py +96 -0
- rapidata/api_client/models/get_benchmark_by_id_query_result.py +94 -0
- rapidata/api_client/models/get_benchmark_by_id_query_result_paged_result.py +105 -0
- rapidata/api_client/models/get_benchmark_by_id_result.py +94 -0
- rapidata/api_client/models/get_participant_by_id_result.py +6 -22
- rapidata/api_client/models/get_standing_by_id_result.py +113 -0
- rapidata/api_client/models/get_validation_rapids_result.py +3 -3
- rapidata/api_client/models/get_workflow_results_result.py +3 -3
- rapidata/api_client/models/local_file_wrapper.py +120 -0
- rapidata/api_client/models/multi_asset_input.py +110 -0
- rapidata/api_client/models/multi_asset_input1.py +110 -0
- rapidata/api_client/models/multi_asset_input1_assets_inner.py +170 -0
- rapidata/api_client/models/multi_asset_input2.py +110 -0
- rapidata/api_client/models/multi_asset_input_assets_inner.py +170 -0
- rapidata/api_client/models/not_user_filter_model.py +3 -3
- rapidata/api_client/models/or_user_filter_model.py +3 -3
- rapidata/api_client/models/participant_by_benchmark.py +102 -0
- rapidata/api_client/models/participant_by_benchmark_paged_result.py +105 -0
- rapidata/api_client/models/participant_by_leaderboard.py +6 -2
- rapidata/api_client/models/participant_status.py +1 -4
- rapidata/api_client/models/pipeline_id_workflow_config_put_request.py +140 -0
- rapidata/api_client/models/potential_validation_rapid.py +103 -0
- rapidata/api_client/models/potential_validation_rapid_paged_result.py +105 -0
- rapidata/api_client/models/potential_validation_rapid_truth.py +280 -0
- rapidata/api_client/models/prompt_asset_metadata_input.py +3 -3
- rapidata/api_client/models/prompt_asset_metadata_input_asset.py +170 -0
- rapidata/api_client/models/prompt_by_benchmark_result.py +92 -0
- rapidata/api_client/models/prompt_by_benchmark_result_paged_result.py +105 -0
- rapidata/api_client/models/prompt_metadata_input.py +5 -3
- rapidata/api_client/models/proxy_file_wrapper.py +114 -0
- rapidata/api_client/models/query_validation_model.py +97 -0
- rapidata/api_client/models/rapid_model.py +3 -3
- rapidata/api_client/models/simple_workflow_config.py +3 -3
- rapidata/api_client/models/simple_workflow_model1.py +3 -3
- rapidata/api_client/models/standing_by_leaderboard.py +113 -0
- rapidata/api_client/models/standing_by_leaderboard_paged_result.py +105 -0
- rapidata/api_client/models/standing_status.py +38 -0
- rapidata/api_client/models/stream_file_wrapper.py +116 -0
- rapidata/api_client/models/submit_coco_model.py +1 -3
- rapidata/api_client/models/submit_prompt_model.py +89 -0
- rapidata/api_client/models/text_asset_input.py +100 -0
- rapidata/api_client/models/transcription_metadata_input.py +5 -3
- rapidata/api_client/models/validation_set_zip_post_request_blueprint.py +252 -0
- rapidata/api_client/models/zip_entry_file_wrapper.py +120 -0
- rapidata/api_client_README.md +67 -76
- rapidata/rapidata_client/benchmark/leaderboard/__init__.py +0 -0
- rapidata/rapidata_client/benchmark/leaderboard/rapidata_leaderboard.py +62 -0
- rapidata/rapidata_client/benchmark/rapidata_benchmark.py +227 -0
- rapidata/rapidata_client/benchmark/rapidata_benchmark_manager.py +83 -0
- rapidata/rapidata_client/filter/not_filter.py +2 -2
- rapidata/rapidata_client/filter/or_filter.py +2 -2
- rapidata/rapidata_client/metadata/__init__.py +1 -0
- rapidata/rapidata_client/metadata/_media_asset_metadata.py +8 -1
- rapidata/rapidata_client/metadata/_prompt_identifier_metadata.py +15 -0
- rapidata/rapidata_client/order/_rapidata_dataset.py +6 -6
- rapidata/rapidata_client/order/_rapidata_order_builder.py +4 -4
- rapidata/rapidata_client/order/rapidata_order.py +1 -1
- rapidata/rapidata_client/rapidata_client.py +3 -3
- rapidata/rapidata_client/validation/rapidata_validation_set.py +1 -1
- rapidata/rapidata_client/validation/rapids/rapids.py +4 -6
- rapidata/service/openapi_service.py +5 -0
- {rapidata-2.28.4.dist-info → rapidata-2.29.0.dist-info}/METADATA +1 -1
- {rapidata-2.28.4.dist-info → rapidata-2.29.0.dist-info}/RECORD +106 -57
- rapidata/rapidata_client/leaderboard/rapidata_leaderboard.py +0 -127
- rapidata/rapidata_client/leaderboard/rapidata_leaderboard_manager.py +0 -92
- /rapidata/rapidata_client/{leaderboard → benchmark}/__init__.py +0 -0
- {rapidata-2.28.4.dist-info → rapidata-2.29.0.dist-info}/LICENSE +0 -0
- {rapidata-2.28.4.dist-info → rapidata-2.29.0.dist-info}/WHEEL +0 -0
|
@@ -0,0 +1,227 @@
|
|
|
1
|
+
from rapidata.api_client.models.root_filter import RootFilter
|
|
2
|
+
from rapidata.api_client.models.filter import Filter
|
|
3
|
+
from rapidata.api_client.models.query_model import QueryModel
|
|
4
|
+
from rapidata.api_client.models.page_info import PageInfo
|
|
5
|
+
from rapidata.api_client.models.create_leaderboard_model import CreateLeaderboardModel
|
|
6
|
+
from rapidata.api_client.models.create_benchmark_participant_model import CreateBenchmarkParticipantModel
|
|
7
|
+
from rapidata.api_client.models.submit_prompt_model import SubmitPromptModel
|
|
8
|
+
|
|
9
|
+
from rapidata.rapidata_client.logging import logger
|
|
10
|
+
from rapidata.service.openapi_service import OpenAPIService
|
|
11
|
+
|
|
12
|
+
from rapidata.rapidata_client.benchmark.leaderboard.rapidata_leaderboard import RapidataLeaderboard
|
|
13
|
+
from rapidata.rapidata_client.metadata import PromptIdentifierMetadata
|
|
14
|
+
from rapidata.rapidata_client.assets import MediaAsset
|
|
15
|
+
from rapidata.rapidata_client.order._rapidata_dataset import RapidataDataset
|
|
16
|
+
|
|
17
|
+
class RapidataBenchmark:
    """
    An instance of a Rapidata benchmark.

    Used to interact with a specific benchmark in the Rapidata system, such as
    retrieving prompts and evaluating models.

    Args:
        name: The name that will be used to identify the benchmark on the overview.
        id: The id of the benchmark.
        openapi_service: The OpenAPI service to use to interact with the Rapidata API.
    """
    def __init__(self, name: str, id: str, openapi_service: OpenAPIService):
        self.name = name
        self.id = id
        self.__openapi_service = openapi_service
        # Lazily populated caches: filled on first access of the matching property.
        self.__prompts: list[str] = []
        self.__leaderboards: list[RapidataLeaderboard] = []
        self.__identifiers: list[str] = []

    def __instantiate_prompts(self) -> None:
        """Fetch every prompt page of the benchmark and fill the local
        `__prompts` and `__identifiers` caches (index-aligned).

        Raises:
            ValueError: If the API response does not report a total page count.
        """
        current_page = 1
        total_pages = None

        while True:
            prompts_result = self.__openapi_service.benchmark_api.benchmark_benchmark_id_prompts_get(
                benchmark_id=self.id,
                request=QueryModel(
                    page=PageInfo(
                        index=current_page,
                        size=100
                    )
                )
            )

            if prompts_result.total_pages is None:
                raise ValueError("An error occurred while fetching prompts: total_pages is None")

            total_pages = prompts_result.total_pages

            self.__prompts.extend([prompt.prompt for prompt in prompts_result.items])
            self.__identifiers.extend([prompt.identifier for prompt in prompts_result.items])

            if current_page >= total_pages:
                break

            current_page += 1

    @property
    def prompts(self) -> list[str]:
        """
        Returns the prompts that are registered for the benchmark.
        """
        if not self.__prompts:
            self.__instantiate_prompts()

        return self.__prompts

    @property
    def identifiers(self) -> list[str]:
        """
        Returns the identifiers that are registered for the benchmark.

        The identifier at index i corresponds to the prompt at index i of `prompts`.
        """
        if not self.__identifiers:
            self.__instantiate_prompts()

        return self.__identifiers

    @property
    def leaderboards(self) -> list[RapidataLeaderboard]:
        """
        Returns the leaderboards that are registered for the benchmark.
        """
        if not self.__leaderboards:
            current_page = 1
            total_pages = None

            while True:
                leaderboards_result = self.__openapi_service.leaderboard_api.leaderboards_get(
                    request=QueryModel(
                        filter=RootFilter(
                            filters=[
                                Filter(field="BenchmarkId", operator="Eq", value=self.id)
                            ]
                        ),
                        page=PageInfo(
                            index=current_page,
                            size=100
                        )
                    )
                )

                if leaderboards_result.total_pages is None:
                    raise ValueError("An error occurred while fetching leaderboards: total_pages is None")

                total_pages = leaderboards_result.total_pages

                self.__leaderboards.extend([
                    RapidataLeaderboard(
                        leaderboard.name,
                        leaderboard.instruction,
                        leaderboard.show_prompt,
                        leaderboard.id,
                        self.__openapi_service
                    ) for leaderboard in leaderboards_result.items])

                if current_page >= total_pages:
                    break

                current_page += 1

        return self.__leaderboards

    def add_prompt(self, identifier: str, prompt: str):
        """
        Adds a prompt to the benchmark.

        Args:
            identifier: A unique identifier for the prompt within this benchmark.
            prompt: The prompt text to register.

        Raises:
            ValueError: If identifier/prompt are not strings, or the identifier
                is already registered on the benchmark.
        """
        if not isinstance(identifier, str) or not isinstance(prompt, str):
            raise ValueError("Identifier and prompt must be strings.")

        # Accessing the property triggers a lazy fetch, so the duplicate check
        # runs against the server-known identifiers.
        if identifier in self.identifiers:
            raise ValueError("Identifier already exists in the benchmark.")

        # Register with the backend BEFORE touching the local caches: if the
        # call raises, the caches stay consistent with the server state.
        self.__openapi_service.benchmark_api.benchmark_benchmark_id_prompt_post(
            benchmark_id=self.id,
            submit_prompt_model=SubmitPromptModel(
                identifier=identifier,
                prompt=prompt,
            )
        )

        self.__identifiers.append(identifier)
        self.__prompts.append(prompt)

    def create_leaderboard(self, name: str, instruction: str, show_prompt: bool) -> RapidataLeaderboard:
        """
        Creates a new leaderboard for the benchmark.

        Args:
            name: The name of the leaderboard.
            instruction: The instruction presented with the leaderboard.
            show_prompt: Whether the prompt is shown alongside the media.
        """
        leaderboard_result = self.__openapi_service.leaderboard_api.leaderboard_post(
            create_leaderboard_model=CreateLeaderboardModel(
                benchmarkId=self.id,
                name=name,
                instruction=instruction,
                showPrompt=show_prompt
            )
        )

        assert leaderboard_result.benchmark_id == self.id, "The leaderboard was not created for the correct benchmark."

        return RapidataLeaderboard(
            name,
            instruction,
            show_prompt,
            leaderboard_result.id,
            self.__openapi_service
        )

    def evaluate_model(self, name: str, media: list[str], identifiers: list[str]) -> None:
        """
        Evaluates a model on the benchmark across all leaderboards.

        Args:
            name: The name of the model.
            media: The generated images/videos that will be used to evaluate the model.
            identifiers: The identifiers that correspond to the media. The order of the identifiers must match the order of the media.
                The identifiers that are used must be registered for the benchmark. To see the registered identifiers, use the identifiers property.
        """
        if not media:
            raise ValueError("Media must be a non-empty list of strings")

        if len(media) != len(identifiers):
            raise ValueError("Media and identifiers must have the same length")

        if not all(identifier in self.identifiers for identifier in identifiers):
            raise ValueError(
                "All identifiers must be in the registered identifiers list. To see the registered identifiers, use the identifiers property."
                "\nTo see the prompts that are associated with the identifiers, use the prompts property."
            )

        # Build the assets first: this happens before the creation of the
        # participant to ensure all media paths are valid.
        assets = []
        prompts_metadata: list[list[PromptIdentifierMetadata]] = []
        for media_path, identifier in zip(media, identifiers):
            assets.append(MediaAsset(media_path))
            prompts_metadata.append([PromptIdentifierMetadata(identifier=identifier)])

        participant_result = self.__openapi_service.benchmark_api.benchmark_benchmark_id_participants_post(
            benchmark_id=self.id,
            create_benchmark_participant_model=CreateBenchmarkParticipantModel(
                name=name,
            )
        )

        dataset = RapidataDataset(participant_result.dataset_id, self.__openapi_service)

        try:
            dataset._add_datapoints(assets, prompts_metadata)
        except Exception as e:
            # Best-effort: only abort if nothing at all was uploaded;
            # otherwise continue the evaluation with the datapoints that made it.
            logger.warning(f"An error occurred while adding datapoints to the dataset: {e}")
            upload_progress = self.__openapi_service.dataset_api.dataset_dataset_id_progress_get(
                dataset_id=dataset.id
            )
            if upload_progress.ready == 0:
                raise RuntimeError("None of the media was uploaded successfully. Please check the media paths and try again.") from e

            logger.warning(f"{upload_progress.failed} datapoints failed to upload. \n{upload_progress.ready} datapoints were uploaded successfully. \nEvaluation will continue with the uploaded datapoints.")

        self.__openapi_service.benchmark_api.benchmark_benchmark_id_participants_participant_id_submit_post(
            benchmark_id=self.id,
            participant_id=participant_result.participant_id
        )

    def __str__(self) -> str:
        return f"RapidataBenchmark(name={self.name}, id={self.id})"

    def __repr__(self) -> str:
        return self.__str__()
|
|
@@ -0,0 +1,83 @@
|
|
|
1
|
+
from rapidata.rapidata_client.benchmark.rapidata_benchmark import RapidataBenchmark
|
|
2
|
+
from rapidata.api_client.models.create_benchmark_model import CreateBenchmarkModel
|
|
3
|
+
from rapidata.service.openapi_service import OpenAPIService
|
|
4
|
+
from rapidata.api_client.models.query_model import QueryModel
|
|
5
|
+
from rapidata.api_client.models.page_info import PageInfo
|
|
6
|
+
from rapidata.api_client.models.root_filter import RootFilter
|
|
7
|
+
from rapidata.api_client.models.filter import Filter
|
|
8
|
+
from rapidata.api_client.models.sort_criterion import SortCriterion
|
|
9
|
+
|
|
10
|
+
class RapidataBenchmarkManager:
    """
    A manager for benchmarks.

    Used to create and retrieve benchmarks.

    A benchmark is a collection of leaderboards.

    Args:
        openapi_service: The OpenAPIService instance for API interaction.
    """

    def __init__(self, openapi_service: OpenAPIService):
        self.__openapi_service = openapi_service

    def create_new_benchmark(self,
                             name: str,
                             identifiers: list[str],
                             prompts: list[str],
                             ) -> RapidataBenchmark:
        """
        Creates a new benchmark and registers the given identifier/prompt pairs on it.

        Args:
            name: The name of the benchmark.
            identifiers: Unique identifiers, one per prompt, in matching order.
            prompts: The prompts that will be registered for the benchmark.

        Raises:
            ValueError: If any argument is of the wrong type, the lists differ
                in length, or the identifiers are not unique.
        """
        if not isinstance(name, str):
            raise ValueError("Name must be a string.")

        if not (isinstance(prompts, list) and all(isinstance(entry, str) for entry in prompts)):
            raise ValueError("Prompts must be a list of strings.")

        if not (isinstance(identifiers, list) and all(isinstance(entry, str) for entry in identifiers)):
            raise ValueError("Identifiers must be a list of strings.")

        if len(identifiers) != len(prompts):
            raise ValueError("Identifiers and prompts must have the same length.")

        if len(set(identifiers)) != len(identifiers):
            raise ValueError("Identifiers must be unique.")

        creation_result = self.__openapi_service.benchmark_api.benchmark_post(
            create_benchmark_model=CreateBenchmarkModel(
                name=name,
            )
        )

        # Register each prompt on the freshly created benchmark.
        new_benchmark = RapidataBenchmark(name, creation_result.id, self.__openapi_service)
        for prompt_identifier, prompt_text in zip(identifiers, prompts):
            new_benchmark.add_prompt(prompt_identifier, prompt_text)

        return new_benchmark

    def get_benchmark_by_id(self, id: str) -> RapidataBenchmark:
        """
        Returns a benchmark by its ID.
        """
        fetched = self.__openapi_service.benchmark_api.benchmark_benchmark_id_get(
            benchmark_id=id
        )
        return RapidataBenchmark(fetched.name, fetched.id, self.__openapi_service)

    def find_benchmarks(self, name: str = "", amount: int = 10) -> list[RapidataBenchmark]:
        """
        Returns a list of benchmarks by their name.
        """
        # Newest benchmarks first; name match is a substring filter.
        query = QueryModel(
            page=PageInfo(index=1, size=amount),
            filter=RootFilter(filters=[Filter(field="Name", operator="Contains", value=name)]),
            sortCriteria=[SortCriterion(direction="Desc", propertyName="CreatedAt")]
        )
        matches = self.__openapi_service.benchmark_api.benchmarks_get(query)
        return [
            RapidataBenchmark(entry.name, entry.id, self.__openapi_service)
            for entry in matches.items
        ]
|
|
@@ -1,7 +1,7 @@
|
|
|
1
1
|
from typing import Any
|
|
2
2
|
from rapidata.rapidata_client.filter._base_filter import RapidataFilter
|
|
3
3
|
from rapidata.api_client.models.not_user_filter_model import NotUserFilterModel
|
|
4
|
-
from rapidata.api_client.models.
|
|
4
|
+
from rapidata.api_client.models.and_user_filter_model_filters_inner import AndUserFilterModelFiltersInner
|
|
5
5
|
|
|
6
6
|
|
|
7
7
|
class NotFilter(RapidataFilter):
|
|
@@ -27,4 +27,4 @@ class NotFilter(RapidataFilter):
|
|
|
27
27
|
self.filter = filter
|
|
28
28
|
|
|
29
29
|
def _to_model(self):
|
|
30
|
-
return NotUserFilterModel(_t="NotFilter", filter=
|
|
30
|
+
return NotUserFilterModel(_t="NotFilter", filter=AndUserFilterModelFiltersInner(self.filter._to_model()))
|
|
@@ -1,7 +1,7 @@
|
|
|
1
1
|
from typing import Any
|
|
2
2
|
from rapidata.rapidata_client.filter._base_filter import RapidataFilter
|
|
3
3
|
from rapidata.api_client.models.or_user_filter_model import OrUserFilterModel
|
|
4
|
-
from rapidata.api_client.models.
|
|
4
|
+
from rapidata.api_client.models.and_user_filter_model_filters_inner import AndUserFilterModelFiltersInner
|
|
5
5
|
|
|
6
6
|
|
|
7
7
|
class OrFilter(RapidataFilter):
|
|
@@ -27,4 +27,4 @@ class OrFilter(RapidataFilter):
|
|
|
27
27
|
self.filters = filters
|
|
28
28
|
|
|
29
29
|
def _to_model(self):
|
|
30
|
-
return OrUserFilterModel(_t="OrFilter", filters=[
|
|
30
|
+
return OrUserFilterModel(_t="OrFilter", filters=[AndUserFilterModelFiltersInner(filter._to_model()) for filter in self.filters])
|
|
@@ -4,3 +4,4 @@ from ._public_text_metadata import PublicTextMetadata
|
|
|
4
4
|
from ._prompt_metadata import PromptMetadata
|
|
5
5
|
from ._select_words_metadata import SelectWordsMetadata
|
|
6
6
|
from ._media_asset_metadata import MediaAssetMetadata
|
|
7
|
+
from ._prompt_identifier_metadata import PromptIdentifierMetadata
|
|
@@ -1,6 +1,7 @@
|
|
|
1
1
|
from rapidata.api_client.models.prompt_asset_metadata_input import PromptAssetMetadataInput
|
|
2
2
|
from rapidata.api_client.models.url_asset_input import UrlAssetInput
|
|
3
3
|
from rapidata.rapidata_client.metadata._base_metadata import Metadata
|
|
4
|
+
from rapidata.api_client.models.prompt_asset_metadata_input_asset import PromptAssetMetadataInputAsset
|
|
4
5
|
|
|
5
6
|
|
|
6
7
|
class MediaAssetMetadata(Metadata):
|
|
@@ -11,5 +12,11 @@ class MediaAssetMetadata(Metadata):
|
|
|
11
12
|
|
|
12
13
|
def to_model(self):
|
|
13
14
|
return PromptAssetMetadataInput(
|
|
14
|
-
_t="PromptAssetMetadataInput",
|
|
15
|
+
_t="PromptAssetMetadataInput",
|
|
16
|
+
asset=PromptAssetMetadataInputAsset(
|
|
17
|
+
actual_instance=UrlAssetInput(
|
|
18
|
+
_t="UrlAssetInput",
|
|
19
|
+
url=self._url
|
|
20
|
+
)
|
|
21
|
+
)
|
|
15
22
|
)
|
|
@@ -0,0 +1,15 @@
|
|
|
1
|
+
from rapidata.rapidata_client.metadata._base_metadata import Metadata
|
|
2
|
+
from rapidata.api_client.models.private_text_metadata_input import (
|
|
3
|
+
PrivateTextMetadataInput,
|
|
4
|
+
)
|
|
5
|
+
|
|
6
|
+
|
|
7
|
+
class PromptIdentifierMetadata(Metadata):
    """Metadata that tags a datapoint with the prompt identifier it belongs to.

    Args:
        identifier: The prompt identifier registered on the benchmark.
    """

    def __init__(self, identifier: str):
        super().__init__()
        self._identifier = identifier

    def to_model(self):
        """Build the API model carrying the identifier as private text metadata.

        NOTE(review): the fixed "prompt-id" key appears to be how the backend
        recognises this metadata — confirm against the server contract.
        """
        return PrivateTextMetadataInput(
            _t="PrivateTextMetadataInput",
            identifier="prompt-id",
            text=self._identifier,
        )
|
|
@@ -22,7 +22,7 @@ def chunk_list(lst: list, chunk_size: int) -> Generator:
|
|
|
22
22
|
class RapidataDataset:
|
|
23
23
|
|
|
24
24
|
def __init__(self, dataset_id: str, openapi_service: OpenAPIService):
|
|
25
|
-
self.
|
|
25
|
+
self.id = dataset_id
|
|
26
26
|
self.openapi_service = openapi_service
|
|
27
27
|
self.local_file_service = LocalFileService()
|
|
28
28
|
|
|
@@ -96,7 +96,7 @@ class RapidataDataset:
|
|
|
96
96
|
metadata=metadata,
|
|
97
97
|
)
|
|
98
98
|
|
|
99
|
-
self.openapi_service.dataset_api.dataset_dataset_id_datapoints_texts_post(dataset_id=self.
|
|
99
|
+
self.openapi_service.dataset_api.dataset_dataset_id_datapoints_texts_post(dataset_id=self.id, create_datapoint_from_text_sources_model=model)
|
|
100
100
|
|
|
101
101
|
total_uploads = len(text_assets)
|
|
102
102
|
with ThreadPoolExecutor(max_workers=max_workers) as executor:
|
|
@@ -161,7 +161,7 @@ class RapidataDataset:
|
|
|
161
161
|
for attempt in range(max_retries):
|
|
162
162
|
try:
|
|
163
163
|
self.openapi_service.dataset_api.dataset_dataset_id_datapoints_post(
|
|
164
|
-
dataset_id=self.
|
|
164
|
+
dataset_id=self.id,
|
|
165
165
|
file=local_paths,
|
|
166
166
|
url=urls,
|
|
167
167
|
metadata=metadata,
|
|
@@ -222,7 +222,7 @@ class RapidataDataset:
|
|
|
222
222
|
|
|
223
223
|
while not stop_event.is_set() or not all_uploads_complete.is_set():
|
|
224
224
|
try:
|
|
225
|
-
current_progress = self.openapi_service.dataset_api.dataset_dataset_id_progress_get(self.
|
|
225
|
+
current_progress = self.openapi_service.dataset_api.dataset_dataset_id_progress_get(self.id)
|
|
226
226
|
|
|
227
227
|
# Calculate items completed since our initialization
|
|
228
228
|
completed_ready = current_progress.ready
|
|
@@ -365,7 +365,7 @@ class RapidataDataset:
|
|
|
365
365
|
"""
|
|
366
366
|
try:
|
|
367
367
|
# Get final progress
|
|
368
|
-
final_progress = self.openapi_service.dataset_api.dataset_dataset_id_progress_get(self.
|
|
368
|
+
final_progress = self.openapi_service.dataset_api.dataset_dataset_id_progress_get(self.id)
|
|
369
369
|
total_ready = final_progress.ready
|
|
370
370
|
total_failed = final_progress.failed
|
|
371
371
|
|
|
@@ -373,7 +373,7 @@ class RapidataDataset:
|
|
|
373
373
|
if total_ready + total_failed < total_uploads:
|
|
374
374
|
# Try one more time after a longer wait
|
|
375
375
|
time.sleep(5 * progress_poll_interval)
|
|
376
|
-
final_progress = self.openapi_service.dataset_api.dataset_dataset_id_progress_get(self.
|
|
376
|
+
final_progress = self.openapi_service.dataset_api.dataset_dataset_id_progress_get(self.id)
|
|
377
377
|
total_ready = final_progress.ready
|
|
378
378
|
total_failed = final_progress.failed
|
|
379
379
|
|
|
@@ -5,8 +5,8 @@ from rapidata.api_client.models.create_order_model import CreateOrderModel
|
|
|
5
5
|
from rapidata.api_client.models.create_order_model_referee import (
|
|
6
6
|
CreateOrderModelReferee,
|
|
7
7
|
)
|
|
8
|
-
from rapidata.api_client.models.
|
|
9
|
-
|
|
8
|
+
from rapidata.api_client.models.and_user_filter_model_filters_inner import (
|
|
9
|
+
AndUserFilterModelFiltersInner,
|
|
10
10
|
)
|
|
11
11
|
from rapidata.api_client.models.create_order_model_workflow import (
|
|
12
12
|
CreateOrderModelWorkflow,
|
|
@@ -83,7 +83,7 @@ class RapidataOrderBuilder:
|
|
|
83
83
|
orderName=self._name,
|
|
84
84
|
workflow=CreateOrderModelWorkflow(self.__workflow._to_model()),
|
|
85
85
|
userFilters=[
|
|
86
|
-
|
|
86
|
+
AndUserFilterModelFiltersInner(user_filter._to_model())
|
|
87
87
|
for user_filter in self.__user_filters
|
|
88
88
|
],
|
|
89
89
|
referee=CreateOrderModelReferee(self.__referee._to_model()),
|
|
@@ -136,7 +136,7 @@ class RapidataOrderBuilder:
|
|
|
136
136
|
else None
|
|
137
137
|
)
|
|
138
138
|
if self.__dataset:
|
|
139
|
-
logger.debug(f"Dataset created with ID: {self.__dataset.
|
|
139
|
+
logger.debug(f"Dataset created with ID: {self.__dataset.id}")
|
|
140
140
|
else:
|
|
141
141
|
logger.warning("No dataset created for this order.")
|
|
142
142
|
|
|
@@ -80,7 +80,7 @@ class RapidataOrder:
|
|
|
80
80
|
def unpause(self) -> None:
|
|
81
81
|
"""Unpauses/resumes the order."""
|
|
82
82
|
logger.info(f"Unpausing order '{self}'")
|
|
83
|
-
self.__openapi_service.order_api.
|
|
83
|
+
self.__openapi_service.order_api.order_order_id_resume_post(self.id)
|
|
84
84
|
logger.debug(f"Order '{self}' has been unpaused.")
|
|
85
85
|
managed_print(f"Order '{self}' has been unpaused.")
|
|
86
86
|
|
|
@@ -5,7 +5,7 @@ from rapidata import __version__
|
|
|
5
5
|
from rapidata.service.openapi_service import OpenAPIService
|
|
6
6
|
|
|
7
7
|
from rapidata.rapidata_client.order.rapidata_order_manager import RapidataOrderManager
|
|
8
|
-
from rapidata.rapidata_client.
|
|
8
|
+
from rapidata.rapidata_client.benchmark.rapidata_benchmark_manager import RapidataBenchmarkManager
|
|
9
9
|
|
|
10
10
|
from rapidata.rapidata_client.validation.validation_set_manager import (
|
|
11
11
|
ValidationSetManager,
|
|
@@ -67,8 +67,8 @@ class RapidataClient:
|
|
|
67
67
|
logger.debug("Initializing DemographicManager")
|
|
68
68
|
self._demographic = DemographicManager(openapi_service=self._openapi_service)
|
|
69
69
|
|
|
70
|
-
logger.debug("Initializing
|
|
71
|
-
self.mri =
|
|
70
|
+
logger.debug("Initializing RapidataBenchmarkManager")
|
|
71
|
+
self.mri = RapidataBenchmarkManager(openapi_service=self._openapi_service)
|
|
72
72
|
|
|
73
73
|
def reset_credentials(self):
|
|
74
74
|
"""Reset the credentials saved in the configuration file for the current environment."""
|
|
@@ -37,7 +37,7 @@ class RapidataValidationSet:
|
|
|
37
37
|
dimensions (list[str]): The new dimensions of the validation set.
|
|
38
38
|
"""
|
|
39
39
|
logger.debug(f"Updating dimensions for validation set {self.id} to {dimensions}")
|
|
40
|
-
self.__openapi_service.validation_api.
|
|
40
|
+
self.__openapi_service.validation_api.validation_set_validation_set_id_dimensions_put(self.id, UpdateDimensionsModel(dimensions=dimensions))
|
|
41
41
|
return self
|
|
42
42
|
|
|
43
43
|
def __str__(self):
|
|
@@ -35,11 +35,11 @@ class Rapid():
|
|
|
35
35
|
if isinstance(self.asset, TextAsset) or (isinstance(self.asset, MultiAsset) and isinstance(self.asset.assets[0], TextAsset)):
|
|
36
36
|
openapi_service.validation_api.validation_set_validation_set_id_rapid_texts_post(
|
|
37
37
|
validation_set_id=validationSetId,
|
|
38
|
-
add_validation_text_rapid_model=self.__to_text_model(
|
|
38
|
+
add_validation_text_rapid_model=self.__to_text_model()
|
|
39
39
|
)
|
|
40
40
|
|
|
41
41
|
elif isinstance(self.asset, MediaAsset) or (isinstance(self.asset, MultiAsset) and isinstance(self.asset.assets[0], MediaAsset)):
|
|
42
|
-
model = self.__to_media_model(
|
|
42
|
+
model = self.__to_media_model()
|
|
43
43
|
openapi_service.validation_api.validation_set_validation_set_id_rapid_files_post(
|
|
44
44
|
validation_set_id=validationSetId,
|
|
45
45
|
model=model[0], files=model[1]
|
|
@@ -48,7 +48,7 @@ class Rapid():
|
|
|
48
48
|
else:
|
|
49
49
|
raise TypeError("The asset must be a MediaAsset, TextAsset, or MultiAsset")
|
|
50
50
|
|
|
51
|
-
def __to_media_model(self
|
|
51
|
+
def __to_media_model(self) -> tuple[AddValidationRapidModel, list[StrictStr | tuple[StrictStr, StrictBytes] | StrictBytes]]:
|
|
52
52
|
assets: list[MediaAsset] = []
|
|
53
53
|
if isinstance(self.asset, MultiAsset):
|
|
54
54
|
for asset in self.asset.assets:
|
|
@@ -64,7 +64,6 @@ class Rapid():
|
|
|
64
64
|
assets = [self.asset]
|
|
65
65
|
|
|
66
66
|
return (AddValidationRapidModel(
|
|
67
|
-
validationSetId=validationSetId, # will be removed in the future
|
|
68
67
|
payload=AddValidationRapidModelPayload(self.payload),
|
|
69
68
|
truth=AddValidationRapidModelTruth(self.truth),
|
|
70
69
|
metadata=[
|
|
@@ -75,7 +74,7 @@ class Rapid():
|
|
|
75
74
|
explanation=self.explanation
|
|
76
75
|
), [asset.to_file() for asset in assets])
|
|
77
76
|
|
|
78
|
-
def __to_text_model(self
|
|
77
|
+
def __to_text_model(self) -> AddValidationTextRapidModel:
|
|
79
78
|
texts: list[str] = []
|
|
80
79
|
if isinstance(self.asset, MultiAsset):
|
|
81
80
|
for asset in self.asset.assets:
|
|
@@ -91,7 +90,6 @@ class Rapid():
|
|
|
91
90
|
texts = [self.asset.text]
|
|
92
91
|
|
|
93
92
|
return AddValidationTextRapidModel(
|
|
94
|
-
validationSetId=validationSetId, # will be removed in the future
|
|
95
93
|
payload=AddValidationRapidModelPayload(self.payload),
|
|
96
94
|
truth=AddValidationRapidModelTruth(self.truth),
|
|
97
95
|
metadata=[
|
|
@@ -3,6 +3,7 @@ from importlib.metadata import version, PackageNotFoundError
|
|
|
3
3
|
|
|
4
4
|
from rapidata.api_client.api.campaign_api import CampaignApi
|
|
5
5
|
from rapidata.api_client.api.dataset_api import DatasetApi
|
|
6
|
+
from rapidata.api_client.api.benchmark_api import BenchmarkApi
|
|
6
7
|
from rapidata.api_client.api.order_api import OrderApi
|
|
7
8
|
from rapidata.api_client.api.pipeline_api import PipelineApi
|
|
8
9
|
from rapidata.api_client.api.rapid_api import RapidApi
|
|
@@ -112,6 +113,10 @@ class OpenAPIService:
|
|
|
112
113
|
@property
|
|
113
114
|
def leaderboard_api(self) -> LeaderboardApi:
|
|
114
115
|
return LeaderboardApi(self.api_client)
|
|
116
|
+
|
|
117
|
+
@property
|
|
118
|
+
def benchmark_api(self) -> BenchmarkApi:
|
|
119
|
+
return BenchmarkApi(self.api_client)
|
|
115
120
|
|
|
116
121
|
def _get_rapidata_package_version(self):
|
|
117
122
|
"""
|