rapidata 2.41.2__py3-none-any.whl → 2.42.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of rapidata might be problematic. Click here for more details.
- rapidata/__init__.py +1 -5
- rapidata/api_client/__init__.py +14 -14
- rapidata/api_client/api/__init__.py +1 -0
- rapidata/api_client/api/asset_api.py +851 -0
- rapidata/api_client/api/benchmark_api.py +298 -0
- rapidata/api_client/api/customer_rapid_api.py +29 -43
- rapidata/api_client/api/dataset_api.py +163 -1143
- rapidata/api_client/api/participant_api.py +28 -74
- rapidata/api_client/api/validation_set_api.py +283 -0
- rapidata/api_client/models/__init__.py +13 -14
- rapidata/api_client/models/add_validation_rapid_model.py +3 -3
- rapidata/api_client/models/add_validation_rapid_new_model.py +152 -0
- rapidata/api_client/models/add_validation_rapid_new_model_asset.py +182 -0
- rapidata/api_client/models/compare_workflow_model.py +3 -3
- rapidata/api_client/models/create_datapoint_from_files_model.py +3 -3
- rapidata/api_client/models/create_datapoint_from_text_sources_model.py +3 -3
- rapidata/api_client/models/create_datapoint_from_urls_model.py +3 -3
- rapidata/api_client/models/create_datapoint_model.py +108 -0
- rapidata/api_client/models/create_datapoint_model_asset.py +182 -0
- rapidata/api_client/models/create_demographic_rapid_model.py +13 -2
- rapidata/api_client/models/create_demographic_rapid_model_asset.py +188 -0
- rapidata/api_client/models/create_demographic_rapid_model_new.py +119 -0
- rapidata/api_client/models/create_sample_model.py +8 -2
- rapidata/api_client/models/create_sample_model_asset.py +182 -0
- rapidata/api_client/models/create_sample_model_obsolete.py +87 -0
- rapidata/api_client/models/file_asset_input_file.py +8 -22
- rapidata/api_client/models/fork_benchmark_result.py +87 -0
- rapidata/api_client/models/form_file_wrapper.py +17 -2
- rapidata/api_client/models/get_asset_metadata_result.py +100 -0
- rapidata/api_client/models/multi_asset_input_assets_inner.py +10 -24
- rapidata/api_client/models/prompt_asset_metadata_input.py +3 -3
- rapidata/api_client/models/proxy_file_wrapper.py +17 -2
- rapidata/api_client/models/stream_file_wrapper.py +25 -3
- rapidata/api_client/models/submit_prompt_model.py +3 -3
- rapidata/api_client/models/text_metadata.py +6 -1
- rapidata/api_client/models/text_metadata_model.py +7 -2
- rapidata/api_client/models/upload_file_from_url_result.py +87 -0
- rapidata/api_client/models/upload_file_result.py +87 -0
- rapidata/api_client/models/zip_entry_file_wrapper.py +33 -2
- rapidata/api_client_README.md +28 -25
- rapidata/rapidata_client/__init__.py +0 -1
- rapidata/rapidata_client/benchmark/participant/_participant.py +24 -22
- rapidata/rapidata_client/benchmark/rapidata_benchmark.py +89 -102
- rapidata/rapidata_client/datapoints/__init__.py +0 -1
- rapidata/rapidata_client/datapoints/_asset_uploader.py +71 -0
- rapidata/rapidata_client/datapoints/_datapoint.py +58 -171
- rapidata/rapidata_client/datapoints/_datapoint_uploader.py +95 -0
- rapidata/rapidata_client/datapoints/assets/__init__.py +0 -11
- rapidata/rapidata_client/datapoints/metadata/_media_asset_metadata.py +10 -7
- rapidata/rapidata_client/demographic/demographic_manager.py +21 -8
- rapidata/rapidata_client/exceptions/failed_upload_exception.py +0 -62
- rapidata/rapidata_client/order/_rapidata_order_builder.py +0 -10
- rapidata/rapidata_client/order/dataset/_rapidata_dataset.py +67 -187
- rapidata/rapidata_client/order/rapidata_order_manager.py +58 -116
- rapidata/rapidata_client/settings/translation_behaviour.py +1 -1
- rapidata/rapidata_client/validation/rapidata_validation_set.py +9 -5
- rapidata/rapidata_client/validation/rapids/_validation_rapid_uploader.py +101 -0
- rapidata/rapidata_client/validation/rapids/box.py +35 -11
- rapidata/rapidata_client/validation/rapids/rapids.py +26 -128
- rapidata/rapidata_client/validation/rapids/rapids_manager.py +123 -104
- rapidata/rapidata_client/validation/validation_set_manager.py +25 -34
- rapidata/rapidata_client/workflow/_ranking_workflow.py +14 -17
- rapidata/rapidata_client/workflow/_select_words_workflow.py +3 -16
- rapidata/service/openapi_service.py +8 -3
- {rapidata-2.41.2.dist-info → rapidata-2.42.0.dist-info}/METADATA +1 -1
- {rapidata-2.41.2.dist-info → rapidata-2.42.0.dist-info}/RECORD +68 -59
- {rapidata-2.41.2.dist-info → rapidata-2.42.0.dist-info}/WHEEL +1 -1
- rapidata/rapidata_client/datapoints/assets/_base_asset.py +0 -13
- rapidata/rapidata_client/datapoints/assets/_media_asset.py +0 -318
- rapidata/rapidata_client/datapoints/assets/_multi_asset.py +0 -61
- rapidata/rapidata_client/datapoints/assets/_sessions.py +0 -40
- rapidata/rapidata_client/datapoints/assets/_text_asset.py +0 -34
- rapidata/rapidata_client/datapoints/assets/data_type_enum.py +0 -8
- rapidata/rapidata_client/order/dataset/_progress_tracker.py +0 -100
- {rapidata-2.41.2.dist-info → rapidata-2.42.0.dist-info}/licenses/LICENSE +0 -0
|
@@ -16,10 +16,10 @@ from rapidata.api_client.models.query_model import QueryModel
|
|
|
16
16
|
from rapidata.api_client.models.page_info import PageInfo
|
|
17
17
|
from rapidata.api_client.models.source_url_metadata_model import SourceUrlMetadataModel
|
|
18
18
|
from rapidata.api_client.models.submit_prompt_model import SubmitPromptModel
|
|
19
|
-
from rapidata.api_client.models.
|
|
20
|
-
|
|
19
|
+
from rapidata.api_client.models.create_demographic_rapid_model_asset import (
|
|
20
|
+
CreateDemographicRapidModelAsset,
|
|
21
21
|
)
|
|
22
|
-
from rapidata.api_client.models.
|
|
22
|
+
from rapidata.api_client.models.existing_asset_input import ExistingAssetInput
|
|
23
23
|
|
|
24
24
|
from rapidata.rapidata_client.benchmark._detail_mapper import DetailMapper
|
|
25
25
|
from rapidata.rapidata_client.benchmark.leaderboard.rapidata_leaderboard import (
|
|
@@ -28,7 +28,7 @@ from rapidata.rapidata_client.benchmark.leaderboard.rapidata_leaderboard import
|
|
|
28
28
|
from rapidata.rapidata_client.benchmark.participant._participant import (
|
|
29
29
|
BenchmarkParticipant,
|
|
30
30
|
)
|
|
31
|
-
from rapidata.rapidata_client.datapoints.
|
|
31
|
+
from rapidata.rapidata_client.datapoints._asset_uploader import AssetUploader
|
|
32
32
|
from rapidata.rapidata_client.filter import RapidataFilter
|
|
33
33
|
from rapidata.rapidata_client.config import logger, managed_print, tracer
|
|
34
34
|
from rapidata.rapidata_client.settings import RapidataSetting
|
|
@@ -50,96 +50,92 @@ class RapidataBenchmark:
|
|
|
50
50
|
def __init__(self, name: str, id: str, openapi_service: OpenAPIService):
|
|
51
51
|
self.name = name
|
|
52
52
|
self.id = id
|
|
53
|
-
self.
|
|
53
|
+
self._openapi_service = openapi_service
|
|
54
54
|
self.__prompts: list[str | None] = []
|
|
55
55
|
self.__prompt_assets: list[str | None] = []
|
|
56
56
|
self.__leaderboards: list[RapidataLeaderboard] = []
|
|
57
57
|
self.__identifiers: list[str] = []
|
|
58
58
|
self.__tags: list[list[str]] = []
|
|
59
59
|
self.__benchmark_page: str = (
|
|
60
|
-
f"https://app.{self.
|
|
60
|
+
f"https://app.{self._openapi_service.environment}/mri/benchmarks/{self.id}"
|
|
61
61
|
)
|
|
62
|
+
self._asset_uploader = AssetUploader(openapi_service)
|
|
62
63
|
|
|
63
64
|
def __instantiate_prompts(self) -> None:
|
|
64
|
-
|
|
65
|
-
|
|
65
|
+
with tracer.start_as_current_span("RapidataBenchmark.__instantiate_prompts"):
|
|
66
|
+
current_page = 1
|
|
67
|
+
total_pages = None
|
|
66
68
|
|
|
67
|
-
|
|
68
|
-
|
|
69
|
-
self.__openapi_service.benchmark_api.benchmark_benchmark_id_prompts_get(
|
|
69
|
+
while True:
|
|
70
|
+
prompts_result = self._openapi_service.benchmark_api.benchmark_benchmark_id_prompts_get(
|
|
70
71
|
benchmark_id=self.id,
|
|
71
72
|
request=QueryModel(page=PageInfo(index=current_page, size=100)),
|
|
72
73
|
)
|
|
73
|
-
)
|
|
74
74
|
|
|
75
|
-
|
|
76
|
-
|
|
77
|
-
|
|
78
|
-
|
|
75
|
+
if prompts_result.total_pages is None:
|
|
76
|
+
raise ValueError(
|
|
77
|
+
"An error occurred while fetching prompts: total_pages is None"
|
|
78
|
+
)
|
|
79
79
|
|
|
80
|
-
|
|
80
|
+
total_pages = prompts_result.total_pages
|
|
81
81
|
|
|
82
|
-
|
|
83
|
-
|
|
84
|
-
|
|
85
|
-
|
|
86
|
-
|
|
87
|
-
|
|
88
|
-
|
|
89
|
-
|
|
90
|
-
|
|
91
|
-
|
|
92
|
-
|
|
93
|
-
|
|
94
|
-
|
|
95
|
-
|
|
82
|
+
for prompt in prompts_result.items:
|
|
83
|
+
self.__prompts.append(prompt.prompt)
|
|
84
|
+
self.__identifiers.append(prompt.identifier)
|
|
85
|
+
if prompt.prompt_asset is None:
|
|
86
|
+
self.__prompt_assets.append(None)
|
|
87
|
+
else:
|
|
88
|
+
assert isinstance(
|
|
89
|
+
prompt.prompt_asset.actual_instance, FileAssetModel
|
|
90
|
+
)
|
|
91
|
+
source_url = prompt.prompt_asset.actual_instance.metadata[
|
|
92
|
+
"sourceUrl"
|
|
93
|
+
].actual_instance
|
|
94
|
+
assert isinstance(source_url, SourceUrlMetadataModel)
|
|
95
|
+
self.__prompt_assets.append(source_url.url)
|
|
96
96
|
|
|
97
|
-
|
|
98
|
-
|
|
99
|
-
|
|
97
|
+
self.__tags.append(prompt.tags)
|
|
98
|
+
if current_page >= total_pages:
|
|
99
|
+
break
|
|
100
100
|
|
|
101
|
-
|
|
101
|
+
current_page += 1
|
|
102
102
|
|
|
103
103
|
@property
|
|
104
104
|
def identifiers(self) -> list[str]:
|
|
105
|
-
|
|
106
|
-
|
|
107
|
-
self.__instantiate_prompts()
|
|
105
|
+
if not self.__identifiers:
|
|
106
|
+
self.__instantiate_prompts()
|
|
108
107
|
|
|
109
|
-
|
|
108
|
+
return self.__identifiers
|
|
110
109
|
|
|
111
110
|
@property
|
|
112
111
|
def prompts(self) -> list[str | None]:
|
|
113
112
|
"""
|
|
114
113
|
Returns the prompts that are registered for the leaderboard.
|
|
115
114
|
"""
|
|
116
|
-
|
|
117
|
-
|
|
118
|
-
self.__instantiate_prompts()
|
|
115
|
+
if not self.__prompts:
|
|
116
|
+
self.__instantiate_prompts()
|
|
119
117
|
|
|
120
|
-
|
|
118
|
+
return self.__prompts
|
|
121
119
|
|
|
122
120
|
@property
|
|
123
121
|
def prompt_assets(self) -> list[str | None]:
|
|
124
122
|
"""
|
|
125
123
|
Returns the prompt assets that are registered for the benchmark.
|
|
126
124
|
"""
|
|
127
|
-
|
|
128
|
-
|
|
129
|
-
self.__instantiate_prompts()
|
|
125
|
+
if not self.__prompt_assets:
|
|
126
|
+
self.__instantiate_prompts()
|
|
130
127
|
|
|
131
|
-
|
|
128
|
+
return self.__prompt_assets
|
|
132
129
|
|
|
133
130
|
@property
|
|
134
131
|
def tags(self) -> list[list[str]]:
|
|
135
132
|
"""
|
|
136
133
|
Returns the tags that are registered for the benchmark.
|
|
137
134
|
"""
|
|
138
|
-
|
|
139
|
-
|
|
140
|
-
self.__instantiate_prompts()
|
|
135
|
+
if not self.__tags:
|
|
136
|
+
self.__instantiate_prompts()
|
|
141
137
|
|
|
142
|
-
|
|
138
|
+
return self.__tags
|
|
143
139
|
|
|
144
140
|
@property
|
|
145
141
|
def leaderboards(self) -> list[RapidataLeaderboard]:
|
|
@@ -152,7 +148,7 @@ class RapidataBenchmark:
|
|
|
152
148
|
total_pages = None
|
|
153
149
|
|
|
154
150
|
while True:
|
|
155
|
-
leaderboards_result = self.
|
|
151
|
+
leaderboards_result = self._openapi_service.benchmark_api.benchmark_benchmark_id_leaderboards_get(
|
|
156
152
|
benchmark_id=self.id,
|
|
157
153
|
request=QueryModel(
|
|
158
154
|
page=PageInfo(index=current_page, size=100),
|
|
@@ -178,7 +174,7 @@ class RapidataBenchmark:
|
|
|
178
174
|
leaderboard.min_responses,
|
|
179
175
|
self.id,
|
|
180
176
|
leaderboard.id,
|
|
181
|
-
self.
|
|
177
|
+
self._openapi_service,
|
|
182
178
|
)
|
|
183
179
|
for leaderboard in leaderboards_result.items
|
|
184
180
|
]
|
|
@@ -239,9 +235,6 @@ class RapidataBenchmark:
|
|
|
239
235
|
if identifier in self.identifiers:
|
|
240
236
|
raise ValueError("Identifier already exists in the benchmark.")
|
|
241
237
|
|
|
242
|
-
if prompt_asset is not None and not re.match(r"^https?://", prompt_asset):
|
|
243
|
-
raise ValueError("Prompt asset must be a link to the asset.")
|
|
244
|
-
|
|
245
238
|
if tags is not None and (
|
|
246
239
|
not isinstance(tags, list)
|
|
247
240
|
or not all(isinstance(tag, str) for tag in tags)
|
|
@@ -263,14 +256,17 @@ class RapidataBenchmark:
|
|
|
263
256
|
self.__prompts.append(prompt)
|
|
264
257
|
self.__prompt_assets.append(prompt_asset)
|
|
265
258
|
|
|
266
|
-
self.
|
|
259
|
+
self._openapi_service.benchmark_api.benchmark_benchmark_id_prompt_post(
|
|
267
260
|
benchmark_id=self.id,
|
|
268
261
|
submit_prompt_model=SubmitPromptModel(
|
|
269
262
|
identifier=identifier,
|
|
270
263
|
prompt=prompt,
|
|
271
264
|
promptAsset=(
|
|
272
|
-
|
|
273
|
-
|
|
265
|
+
CreateDemographicRapidModelAsset(
|
|
266
|
+
actual_instance=ExistingAssetInput(
|
|
267
|
+
_t="ExistingAssetInput",
|
|
268
|
+
name=self._asset_uploader.upload_asset(prompt_asset),
|
|
269
|
+
)
|
|
274
270
|
)
|
|
275
271
|
if prompt_asset is not None
|
|
276
272
|
else None
|
|
@@ -328,32 +324,30 @@ class RapidataBenchmark:
|
|
|
328
324
|
settings,
|
|
329
325
|
)
|
|
330
326
|
|
|
331
|
-
leaderboard_result = (
|
|
332
|
-
|
|
333
|
-
|
|
334
|
-
|
|
335
|
-
|
|
336
|
-
|
|
337
|
-
|
|
338
|
-
|
|
339
|
-
|
|
340
|
-
|
|
341
|
-
|
|
342
|
-
|
|
343
|
-
|
|
344
|
-
|
|
345
|
-
|
|
346
|
-
|
|
347
|
-
|
|
348
|
-
|
|
349
|
-
|
|
350
|
-
|
|
351
|
-
|
|
352
|
-
|
|
353
|
-
|
|
354
|
-
|
|
355
|
-
),
|
|
356
|
-
)
|
|
327
|
+
leaderboard_result = self._openapi_service.leaderboard_api.leaderboard_post(
|
|
328
|
+
create_leaderboard_model=CreateLeaderboardModel(
|
|
329
|
+
benchmarkId=self.id,
|
|
330
|
+
name=name,
|
|
331
|
+
instruction=instruction,
|
|
332
|
+
showPrompt=show_prompt,
|
|
333
|
+
showPromptAsset=show_prompt_asset,
|
|
334
|
+
isInversed=inverse_ranking,
|
|
335
|
+
minResponses=min_responses_per_matchup,
|
|
336
|
+
responseBudget=DetailMapper.get_budget(level_of_detail),
|
|
337
|
+
validationSetId=validation_set_id,
|
|
338
|
+
filters=(
|
|
339
|
+
[
|
|
340
|
+
AndUserFilterModelFiltersInner(filter._to_model())
|
|
341
|
+
for filter in filters
|
|
342
|
+
]
|
|
343
|
+
if filters
|
|
344
|
+
else None
|
|
345
|
+
),
|
|
346
|
+
featureFlags=(
|
|
347
|
+
[setting._to_feature_flag() for setting in settings]
|
|
348
|
+
if settings
|
|
349
|
+
else None
|
|
350
|
+
),
|
|
357
351
|
)
|
|
358
352
|
)
|
|
359
353
|
|
|
@@ -373,7 +367,7 @@ class RapidataBenchmark:
|
|
|
373
367
|
min_responses_per_matchup,
|
|
374
368
|
self.id,
|
|
375
369
|
leaderboard_result.id,
|
|
376
|
-
self.
|
|
370
|
+
self._openapi_service,
|
|
377
371
|
)
|
|
378
372
|
|
|
379
373
|
def evaluate_model(
|
|
@@ -421,12 +415,7 @@ class RapidataBenchmark:
|
|
|
421
415
|
"All identifiers/prompts must be in the registered identifiers/prompts list. To see the registered identifiers/prompts, use the identifiers/prompts property."
|
|
422
416
|
)
|
|
423
417
|
|
|
424
|
-
|
|
425
|
-
assets: list[MediaAsset] = []
|
|
426
|
-
for media_path in media:
|
|
427
|
-
assets.append(MediaAsset(media_path))
|
|
428
|
-
|
|
429
|
-
participant_result = self.__openapi_service.benchmark_api.benchmark_benchmark_id_participants_post(
|
|
418
|
+
participant_result = self._openapi_service.benchmark_api.benchmark_benchmark_id_participants_post(
|
|
430
419
|
benchmark_id=self.id,
|
|
431
420
|
create_benchmark_participant_model=CreateBenchmarkParticipantModel(
|
|
432
421
|
name=name,
|
|
@@ -436,20 +425,20 @@ class RapidataBenchmark:
|
|
|
436
425
|
logger.info(f"Participant created: {participant_result.participant_id}")
|
|
437
426
|
|
|
438
427
|
participant = BenchmarkParticipant(
|
|
439
|
-
name, participant_result.participant_id, self.
|
|
428
|
+
name, participant_result.participant_id, self._openapi_service
|
|
440
429
|
)
|
|
441
430
|
|
|
442
431
|
with tracer.start_as_current_span("upload_media_for_participant"):
|
|
443
432
|
logger.info(
|
|
444
|
-
f"Uploading {len(
|
|
433
|
+
f"Uploading {len(media)} media assets to participant {participant.id}"
|
|
445
434
|
)
|
|
446
435
|
|
|
447
436
|
successful_uploads, failed_uploads = participant.upload_media(
|
|
448
|
-
|
|
437
|
+
media,
|
|
449
438
|
identifiers,
|
|
450
439
|
)
|
|
451
440
|
|
|
452
|
-
total_uploads = len(
|
|
441
|
+
total_uploads = len(media)
|
|
453
442
|
success_rate = (
|
|
454
443
|
(len(successful_uploads) / total_uploads * 100)
|
|
455
444
|
if total_uploads > 0
|
|
@@ -460,9 +449,7 @@ class RapidataBenchmark:
|
|
|
460
449
|
)
|
|
461
450
|
|
|
462
451
|
if failed_uploads:
|
|
463
|
-
logger.error(
|
|
464
|
-
f"Failed uploads for media: {[asset.path for asset in failed_uploads]}"
|
|
465
|
-
)
|
|
452
|
+
logger.error(f"Failed uploads for media: {failed_uploads}")
|
|
466
453
|
logger.warning(
|
|
467
454
|
"Some uploads failed. The model evaluation may be incomplete."
|
|
468
455
|
)
|
|
@@ -472,9 +459,9 @@ class RapidataBenchmark:
|
|
|
472
459
|
"No uploads were successful. The model evaluation will not be completed."
|
|
473
460
|
)
|
|
474
461
|
|
|
475
|
-
|
|
476
|
-
|
|
477
|
-
|
|
462
|
+
self._openapi_service.participant_api.participants_participant_id_submit_post(
|
|
463
|
+
participant_id=participant_result.participant_id
|
|
464
|
+
)
|
|
478
465
|
|
|
479
466
|
def view(self) -> None:
|
|
480
467
|
"""
|
|
@@ -0,0 +1,71 @@
|
|
|
1
|
+
import re
|
|
2
|
+
import os
|
|
3
|
+
from rapidata.api_client.models.existing_asset_input import ExistingAssetInput
|
|
4
|
+
from rapidata.api_client.models.multi_asset_input import (
|
|
5
|
+
MultiAssetInput,
|
|
6
|
+
MultiAssetInputAssetsInner,
|
|
7
|
+
)
|
|
8
|
+
from rapidata.api_client.models.text_asset_input import TextAssetInput
|
|
9
|
+
from rapidata.service.openapi_service import OpenAPIService
|
|
10
|
+
from rapidata.rapidata_client.config import logger
|
|
11
|
+
from rapidata.rapidata_client.config import tracer
|
|
12
|
+
|
|
13
|
+
|
|
14
|
+
class AssetUploader:
|
|
15
|
+
def __init__(self, openapi_service: OpenAPIService):
|
|
16
|
+
self.openapi_service = openapi_service
|
|
17
|
+
|
|
18
|
+
def upload_asset(self, asset: str) -> str:
|
|
19
|
+
with tracer.start_as_current_span("AssetUploader.upload_asset"):
|
|
20
|
+
logger.debug("Uploading asset: %s", asset)
|
|
21
|
+
assert isinstance(asset, str), "Asset must be a string"
|
|
22
|
+
|
|
23
|
+
if re.match(r"^https?://", asset):
|
|
24
|
+
response = self.openapi_service.asset_api.asset_url_post(
|
|
25
|
+
url=asset,
|
|
26
|
+
)
|
|
27
|
+
else:
|
|
28
|
+
if not os.path.exists(asset):
|
|
29
|
+
raise FileNotFoundError(f"File not found: {asset}")
|
|
30
|
+
response = self.openapi_service.asset_api.asset_file_post(
|
|
31
|
+
file=asset,
|
|
32
|
+
)
|
|
33
|
+
return response.file_name
|
|
34
|
+
|
|
35
|
+
def get_uploaded_text_input(
|
|
36
|
+
self, assets: list[str] | str
|
|
37
|
+
) -> MultiAssetInput | TextAssetInput:
|
|
38
|
+
if isinstance(assets, list):
|
|
39
|
+
return MultiAssetInput(
|
|
40
|
+
_t="MultiAssetInput",
|
|
41
|
+
assets=[
|
|
42
|
+
MultiAssetInputAssetsInner(
|
|
43
|
+
actual_instance=TextAssetInput(_t="TextAssetInput", text=asset)
|
|
44
|
+
)
|
|
45
|
+
for asset in assets
|
|
46
|
+
],
|
|
47
|
+
)
|
|
48
|
+
else:
|
|
49
|
+
return TextAssetInput(_t="TextAssetInput", text=assets)
|
|
50
|
+
|
|
51
|
+
def get_uploaded_asset_input(
|
|
52
|
+
self, assets: list[str] | str
|
|
53
|
+
) -> MultiAssetInput | ExistingAssetInput:
|
|
54
|
+
if isinstance(assets, list):
|
|
55
|
+
return MultiAssetInput(
|
|
56
|
+
_t="MultiAssetInput",
|
|
57
|
+
assets=[
|
|
58
|
+
MultiAssetInputAssetsInner(
|
|
59
|
+
actual_instance=ExistingAssetInput(
|
|
60
|
+
_t="ExistingAssetInput",
|
|
61
|
+
name=self.upload_asset(asset),
|
|
62
|
+
),
|
|
63
|
+
)
|
|
64
|
+
for asset in assets
|
|
65
|
+
],
|
|
66
|
+
)
|
|
67
|
+
else:
|
|
68
|
+
return ExistingAssetInput(
|
|
69
|
+
_t="ExistingAssetInput",
|
|
70
|
+
name=self.upload_asset(assets),
|
|
71
|
+
)
|
|
@@ -1,196 +1,83 @@
|
|
|
1
|
-
from typing import
|
|
2
|
-
|
|
3
|
-
|
|
4
|
-
|
|
5
|
-
MultiAsset,
|
|
6
|
-
BaseAsset,
|
|
7
|
-
)
|
|
1
|
+
from typing import Literal
|
|
2
|
+
|
|
3
|
+
from pydantic import BaseModel, model_validator, field_validator
|
|
4
|
+
from typing_extensions import Self
|
|
8
5
|
from rapidata.rapidata_client.datapoints.assets.constants import (
|
|
9
6
|
ALLOWED_VIDEO_EXTENSIONS,
|
|
10
7
|
ALLOWED_IMAGE_EXTENSIONS,
|
|
11
8
|
ALLOWED_AUDIO_EXTENSIONS,
|
|
12
9
|
)
|
|
13
|
-
from rapidata.rapidata_client.datapoints.metadata import Metadata
|
|
14
|
-
from rapidata.api_client.models.dataset_dataset_id_datapoints_post_request_metadata_inner import (
|
|
15
|
-
DatasetDatasetIdDatapointsPostRequestMetadataInner,
|
|
16
|
-
)
|
|
17
|
-
from rapidata.api_client.models.create_datapoint_from_text_sources_model import (
|
|
18
|
-
CreateDatapointFromTextSourcesModel,
|
|
19
|
-
)
|
|
20
|
-
from pydantic import StrictStr, StrictBytes
|
|
21
10
|
from rapidata.api_client.models.asset_type import AssetType
|
|
22
11
|
from rapidata.api_client.models.prompt_type import PromptType
|
|
23
|
-
from rapidata.rapidata_client.datapoints.metadata._media_asset_metadata import (
|
|
24
|
-
MediaAssetMetadata,
|
|
25
|
-
)
|
|
26
|
-
from rapidata.rapidata_client.datapoints.metadata._prompt_metadata import PromptMetadata
|
|
27
12
|
from rapidata.rapidata_client.config import logger
|
|
28
13
|
|
|
29
14
|
|
|
30
|
-
class Datapoint:
|
|
31
|
-
|
|
32
|
-
|
|
33
|
-
|
|
34
|
-
|
|
35
|
-
|
|
36
|
-
|
|
37
|
-
raise TypeError(
|
|
38
|
-
"Asset must be of type MediaAsset, TextAsset, or MultiAsset."
|
|
39
|
-
)
|
|
40
|
-
|
|
41
|
-
if metadata and not isinstance(metadata, Sequence):
|
|
42
|
-
raise TypeError("Metadata must be a list of Metadata objects.")
|
|
43
|
-
|
|
44
|
-
if metadata and not all(isinstance(m, Metadata) for m in metadata):
|
|
45
|
-
raise TypeError("All metadata objects must be of type Metadata.")
|
|
46
|
-
|
|
47
|
-
self.asset = asset
|
|
48
|
-
self.metadata = metadata
|
|
49
|
-
|
|
50
|
-
def _get_effective_asset_type(self) -> type:
|
|
51
|
-
"""Get the effective asset type, handling MultiAsset by looking at its first asset."""
|
|
52
|
-
if isinstance(self.asset, MultiAsset):
|
|
53
|
-
return type(self.asset.assets[0])
|
|
54
|
-
return type(self.asset)
|
|
15
|
+
class Datapoint(BaseModel):
|
|
16
|
+
asset: str | list[str]
|
|
17
|
+
data_type: Literal["text", "media"]
|
|
18
|
+
context: str | None = None
|
|
19
|
+
media_context: str | None = None
|
|
20
|
+
sentence: str | None = None
|
|
21
|
+
private_note: str | None = None
|
|
55
22
|
|
|
56
|
-
|
|
57
|
-
|
|
58
|
-
|
|
59
|
-
|
|
23
|
+
@field_validator("context")
|
|
24
|
+
@classmethod
|
|
25
|
+
def context_not_empty(cls, v: str | None) -> str | None:
|
|
26
|
+
if v is not None and v == "":
|
|
27
|
+
raise ValueError(
|
|
28
|
+
"context cannot be an empty string. If not needed, set to None."
|
|
29
|
+
)
|
|
30
|
+
return v
|
|
60
31
|
|
|
61
|
-
|
|
62
|
-
|
|
63
|
-
|
|
64
|
-
|
|
32
|
+
@field_validator("media_context")
|
|
33
|
+
@classmethod
|
|
34
|
+
def media_context_not_empty(cls, v: str | None) -> str | None:
|
|
35
|
+
if v is not None and v == "":
|
|
36
|
+
raise ValueError(
|
|
37
|
+
"media_context cannot be an empty string. If not needed, set to None."
|
|
38
|
+
)
|
|
39
|
+
return v
|
|
40
|
+
|
|
41
|
+
@field_validator("sentence")
|
|
42
|
+
@classmethod
|
|
43
|
+
def sentence_has_space(cls, v: str | None) -> str | None:
|
|
44
|
+
if v is not None and len(v.split()) <= 1:
|
|
45
|
+
raise ValueError("sentence must contain at least two words.")
|
|
46
|
+
return v
|
|
47
|
+
|
|
48
|
+
@model_validator(mode="after")
|
|
49
|
+
def check_sentence_and_context(self) -> Self:
|
|
50
|
+
if isinstance(self.sentence, str) and isinstance(self.context, str):
|
|
51
|
+
raise ValueError(
|
|
52
|
+
"Both 'sentence' and 'context' cannot be strings at the same time."
|
|
53
|
+
)
|
|
54
|
+
return self
|
|
65
55
|
|
|
66
56
|
def get_asset_type(self) -> AssetType:
|
|
67
|
-
|
|
68
|
-
if self.is_text_asset():
|
|
57
|
+
if self.data_type == "text":
|
|
69
58
|
return AssetType.TEXT
|
|
70
|
-
|
|
71
|
-
|
|
72
|
-
|
|
73
|
-
|
|
74
|
-
|
|
75
|
-
|
|
76
|
-
|
|
77
|
-
|
|
78
|
-
elif any(asset.path.endswith(ext) for ext in ALLOWED_VIDEO_EXTENSIONS):
|
|
79
|
-
return AssetType.VIDEO
|
|
80
|
-
elif any(asset.path.endswith(ext) for ext in ALLOWED_AUDIO_EXTENSIONS):
|
|
81
|
-
return AssetType.AUDIO
|
|
82
|
-
else:
|
|
83
|
-
logger.debug(
|
|
84
|
-
f"Cannot get asset type for asset type: {type(self.asset)}"
|
|
85
|
-
)
|
|
86
|
-
return AssetType.NONE
|
|
59
|
+
|
|
60
|
+
evaluation_asset = self.asset[0] if isinstance(self.asset, list) else self.asset
|
|
61
|
+
if any(evaluation_asset.endswith(ext) for ext in ALLOWED_IMAGE_EXTENSIONS):
|
|
62
|
+
return AssetType.IMAGE
|
|
63
|
+
elif any(evaluation_asset.endswith(ext) for ext in ALLOWED_VIDEO_EXTENSIONS):
|
|
64
|
+
return AssetType.VIDEO
|
|
65
|
+
elif any(evaluation_asset.endswith(ext) for ext in ALLOWED_AUDIO_EXTENSIONS):
|
|
66
|
+
return AssetType.AUDIO
|
|
87
67
|
else:
|
|
88
|
-
logger.debug(
|
|
68
|
+
logger.debug(
|
|
69
|
+
f"Cannot get asset type for asset type: {type(evaluation_asset)}"
|
|
70
|
+
)
|
|
89
71
|
return AssetType.NONE
|
|
90
72
|
|
|
91
73
|
def get_prompt_type(self) -> list[PromptType]:
|
|
92
|
-
"""Get the prompt type of the datapoint."""
|
|
93
74
|
prompt_types = []
|
|
94
|
-
|
|
95
|
-
|
|
96
|
-
|
|
97
|
-
|
|
98
|
-
prompt_types.append(PromptType.TEXT)
|
|
75
|
+
if self.context:
|
|
76
|
+
prompt_types.append(PromptType.TEXT)
|
|
77
|
+
if self.media_context:
|
|
78
|
+
prompt_types.append(PromptType.ASSET)
|
|
99
79
|
|
|
100
80
|
if len(prompt_types) == 0:
|
|
101
81
|
return [PromptType.NONE]
|
|
102
82
|
|
|
103
83
|
return prompt_types
|
|
104
|
-
|
|
105
|
-
def get_texts(self) -> list[str]:
|
|
106
|
-
"""Extract text content from the asset(s)."""
|
|
107
|
-
if isinstance(self.asset, TextAsset):
|
|
108
|
-
return [self.asset.text]
|
|
109
|
-
elif isinstance(self.asset, MultiAsset):
|
|
110
|
-
texts = []
|
|
111
|
-
for asset in self.asset.assets:
|
|
112
|
-
if isinstance(asset, TextAsset):
|
|
113
|
-
texts.append(asset.text)
|
|
114
|
-
return texts
|
|
115
|
-
else:
|
|
116
|
-
raise ValueError(f"Cannot extract text from asset type: {type(self.asset)}")
|
|
117
|
-
|
|
118
|
-
def get_media_assets(self) -> list[MediaAsset]:
|
|
119
|
-
"""Extract media assets from the datapoint."""
|
|
120
|
-
if isinstance(self.asset, MediaAsset):
|
|
121
|
-
return [self.asset]
|
|
122
|
-
elif isinstance(self.asset, MultiAsset):
|
|
123
|
-
media_assets = []
|
|
124
|
-
for asset in self.asset.assets:
|
|
125
|
-
if isinstance(asset, MediaAsset):
|
|
126
|
-
media_assets.append(asset)
|
|
127
|
-
return media_assets
|
|
128
|
-
else:
|
|
129
|
-
raise ValueError(
|
|
130
|
-
f"Cannot extract media assets from asset type: {type(self.asset)}"
|
|
131
|
-
)
|
|
132
|
-
|
|
133
|
-
def get_local_file_paths(
|
|
134
|
-
self,
|
|
135
|
-
) -> list[StrictStr | tuple[StrictStr, StrictBytes] | StrictBytes]:
|
|
136
|
-
"""Get local file paths for media assets that are stored locally."""
|
|
137
|
-
if not self.is_media_asset():
|
|
138
|
-
return []
|
|
139
|
-
|
|
140
|
-
media_assets = self.get_media_assets()
|
|
141
|
-
return [asset.to_file() for asset in media_assets if asset.is_local()]
|
|
142
|
-
|
|
143
|
-
def get_urls(self) -> list[str]:
|
|
144
|
-
"""Get URLs for media assets that are remote."""
|
|
145
|
-
if not self.is_media_asset():
|
|
146
|
-
return []
|
|
147
|
-
|
|
148
|
-
media_assets = self.get_media_assets()
|
|
149
|
-
return [asset.path for asset in media_assets if not asset.is_local()]
|
|
150
|
-
|
|
151
|
-
def get_prepared_metadata(
|
|
152
|
-
self,
|
|
153
|
-
) -> list[DatasetDatasetIdDatapointsPostRequestMetadataInner]:
|
|
154
|
-
"""Prepare metadata for API upload."""
|
|
155
|
-
metadata: list[DatasetDatasetIdDatapointsPostRequestMetadataInner] = []
|
|
156
|
-
if self.metadata:
|
|
157
|
-
for meta in self.metadata:
|
|
158
|
-
meta_model = meta.to_model() if meta else None
|
|
159
|
-
if meta_model:
|
|
160
|
-
metadata.append(
|
|
161
|
-
DatasetDatasetIdDatapointsPostRequestMetadataInner(meta_model)
|
|
162
|
-
)
|
|
163
|
-
return metadata
|
|
164
|
-
|
|
165
|
-
def create_text_upload_model(
|
|
166
|
-
self, index: int
|
|
167
|
-
) -> CreateDatapointFromTextSourcesModel:
|
|
168
|
-
"""Create the model for uploading text datapoints."""
|
|
169
|
-
if not self.is_text_asset():
|
|
170
|
-
raise ValueError("Cannot create text upload model for non-text asset")
|
|
171
|
-
|
|
172
|
-
texts = self.get_texts()
|
|
173
|
-
metadata = self.get_prepared_metadata()
|
|
174
|
-
|
|
175
|
-
return CreateDatapointFromTextSourcesModel(
|
|
176
|
-
textSources=texts,
|
|
177
|
-
sortIndex=index,
|
|
178
|
-
metadata=metadata,
|
|
179
|
-
)
|
|
180
|
-
|
|
181
|
-
def get_datapoint_string(self) -> str:
|
|
182
|
-
"""Get the datapoint string for the datapoint."""
|
|
183
|
-
if isinstance(self.asset, MediaAsset):
|
|
184
|
-
return self.asset.path
|
|
185
|
-
elif isinstance(self.asset, TextAsset):
|
|
186
|
-
return self.asset.text
|
|
187
|
-
else:
|
|
188
|
-
raise ValueError(
|
|
189
|
-
f"Cannot get datapoint string for asset type: {type(self.asset)}"
|
|
190
|
-
)
|
|
191
|
-
|
|
192
|
-
def __str__(self):
|
|
193
|
-
return f"Datapoint(asset={self.asset})"
|
|
194
|
-
|
|
195
|
-
def __repr__(self):
|
|
196
|
-
return self.__str__()
|