rapidata 2.41.3__py3-none-any.whl → 2.42.1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of rapidata might be problematic. More details are available on the original diff page.

Files changed (74)
  1. rapidata/__init__.py +1 -5
  2. rapidata/api_client/__init__.py +14 -14
  3. rapidata/api_client/api/__init__.py +1 -0
  4. rapidata/api_client/api/asset_api.py +851 -0
  5. rapidata/api_client/api/benchmark_api.py +298 -0
  6. rapidata/api_client/api/customer_rapid_api.py +29 -43
  7. rapidata/api_client/api/dataset_api.py +163 -1143
  8. rapidata/api_client/api/participant_api.py +28 -74
  9. rapidata/api_client/api/validation_set_api.py +283 -0
  10. rapidata/api_client/models/__init__.py +13 -14
  11. rapidata/api_client/models/add_validation_rapid_model.py +3 -3
  12. rapidata/api_client/models/add_validation_rapid_new_model.py +152 -0
  13. rapidata/api_client/models/add_validation_rapid_new_model_asset.py +182 -0
  14. rapidata/api_client/models/compare_workflow_model.py +3 -3
  15. rapidata/api_client/models/create_datapoint_from_files_model.py +3 -3
  16. rapidata/api_client/models/create_datapoint_from_text_sources_model.py +3 -3
  17. rapidata/api_client/models/create_datapoint_from_urls_model.py +3 -3
  18. rapidata/api_client/models/create_datapoint_model.py +108 -0
  19. rapidata/api_client/models/create_datapoint_model_asset.py +182 -0
  20. rapidata/api_client/models/create_demographic_rapid_model.py +13 -2
  21. rapidata/api_client/models/create_demographic_rapid_model_asset.py +188 -0
  22. rapidata/api_client/models/create_demographic_rapid_model_new.py +119 -0
  23. rapidata/api_client/models/create_sample_model.py +8 -2
  24. rapidata/api_client/models/create_sample_model_asset.py +182 -0
  25. rapidata/api_client/models/create_sample_model_obsolete.py +87 -0
  26. rapidata/api_client/models/file_asset_input_file.py +8 -22
  27. rapidata/api_client/models/fork_benchmark_result.py +87 -0
  28. rapidata/api_client/models/form_file_wrapper.py +17 -2
  29. rapidata/api_client/models/get_asset_metadata_result.py +100 -0
  30. rapidata/api_client/models/multi_asset_input_assets_inner.py +10 -24
  31. rapidata/api_client/models/prompt_asset_metadata_input.py +3 -3
  32. rapidata/api_client/models/proxy_file_wrapper.py +17 -2
  33. rapidata/api_client/models/stream_file_wrapper.py +25 -3
  34. rapidata/api_client/models/submit_prompt_model.py +3 -3
  35. rapidata/api_client/models/text_metadata.py +6 -1
  36. rapidata/api_client/models/text_metadata_model.py +7 -2
  37. rapidata/api_client/models/upload_file_from_url_result.py +87 -0
  38. rapidata/api_client/models/upload_file_result.py +87 -0
  39. rapidata/api_client/models/zip_entry_file_wrapper.py +33 -2
  40. rapidata/api_client_README.md +28 -25
  41. rapidata/rapidata_client/__init__.py +0 -1
  42. rapidata/rapidata_client/benchmark/participant/_participant.py +25 -24
  43. rapidata/rapidata_client/benchmark/rapidata_benchmark.py +89 -102
  44. rapidata/rapidata_client/datapoints/__init__.py +0 -1
  45. rapidata/rapidata_client/datapoints/_asset_uploader.py +71 -0
  46. rapidata/rapidata_client/datapoints/_datapoint.py +58 -171
  47. rapidata/rapidata_client/datapoints/_datapoint_uploader.py +95 -0
  48. rapidata/rapidata_client/datapoints/assets/__init__.py +0 -11
  49. rapidata/rapidata_client/datapoints/metadata/_media_asset_metadata.py +10 -7
  50. rapidata/rapidata_client/demographic/demographic_manager.py +21 -8
  51. rapidata/rapidata_client/exceptions/failed_upload_exception.py +0 -62
  52. rapidata/rapidata_client/order/_rapidata_order_builder.py +0 -10
  53. rapidata/rapidata_client/order/dataset/_rapidata_dataset.py +65 -187
  54. rapidata/rapidata_client/order/rapidata_order_manager.py +62 -124
  55. rapidata/rapidata_client/validation/rapidata_validation_set.py +9 -5
  56. rapidata/rapidata_client/validation/rapids/_validation_rapid_uploader.py +101 -0
  57. rapidata/rapidata_client/validation/rapids/box.py +35 -11
  58. rapidata/rapidata_client/validation/rapids/rapids.py +26 -128
  59. rapidata/rapidata_client/validation/rapids/rapids_manager.py +123 -104
  60. rapidata/rapidata_client/validation/validation_set_manager.py +41 -38
  61. rapidata/rapidata_client/workflow/_ranking_workflow.py +14 -17
  62. rapidata/rapidata_client/workflow/_select_words_workflow.py +3 -16
  63. rapidata/service/openapi_service.py +8 -3
  64. {rapidata-2.41.3.dist-info → rapidata-2.42.1.dist-info}/METADATA +1 -1
  65. {rapidata-2.41.3.dist-info → rapidata-2.42.1.dist-info}/RECORD +67 -58
  66. {rapidata-2.41.3.dist-info → rapidata-2.42.1.dist-info}/WHEEL +1 -1
  67. rapidata/rapidata_client/datapoints/assets/_base_asset.py +0 -13
  68. rapidata/rapidata_client/datapoints/assets/_media_asset.py +0 -318
  69. rapidata/rapidata_client/datapoints/assets/_multi_asset.py +0 -61
  70. rapidata/rapidata_client/datapoints/assets/_sessions.py +0 -40
  71. rapidata/rapidata_client/datapoints/assets/_text_asset.py +0 -34
  72. rapidata/rapidata_client/datapoints/assets/data_type_enum.py +0 -8
  73. rapidata/rapidata_client/order/dataset/_progress_tracker.py +0 -100
  74. {rapidata-2.41.3.dist-info → rapidata-2.42.1.dist-info}/licenses/LICENSE +0 -0
@@ -16,10 +16,10 @@ from rapidata.api_client.models.query_model import QueryModel
16
16
  from rapidata.api_client.models.page_info import PageInfo
17
17
  from rapidata.api_client.models.source_url_metadata_model import SourceUrlMetadataModel
18
18
  from rapidata.api_client.models.submit_prompt_model import SubmitPromptModel
19
- from rapidata.api_client.models.submit_prompt_model_prompt_asset import (
20
- SubmitPromptModelPromptAsset,
19
+ from rapidata.api_client.models.create_demographic_rapid_model_asset import (
20
+ CreateDemographicRapidModelAsset,
21
21
  )
22
- from rapidata.api_client.models.url_asset_input import UrlAssetInput
22
+ from rapidata.api_client.models.existing_asset_input import ExistingAssetInput
23
23
 
24
24
  from rapidata.rapidata_client.benchmark._detail_mapper import DetailMapper
25
25
  from rapidata.rapidata_client.benchmark.leaderboard.rapidata_leaderboard import (
@@ -28,7 +28,7 @@ from rapidata.rapidata_client.benchmark.leaderboard.rapidata_leaderboard import
28
28
  from rapidata.rapidata_client.benchmark.participant._participant import (
29
29
  BenchmarkParticipant,
30
30
  )
31
- from rapidata.rapidata_client.datapoints.assets import MediaAsset
31
+ from rapidata.rapidata_client.datapoints._asset_uploader import AssetUploader
32
32
  from rapidata.rapidata_client.filter import RapidataFilter
33
33
  from rapidata.rapidata_client.config import logger, managed_print, tracer
34
34
  from rapidata.rapidata_client.settings import RapidataSetting
@@ -50,96 +50,92 @@ class RapidataBenchmark:
50
50
  def __init__(self, name: str, id: str, openapi_service: OpenAPIService):
51
51
  self.name = name
52
52
  self.id = id
53
- self.__openapi_service = openapi_service
53
+ self._openapi_service = openapi_service
54
54
  self.__prompts: list[str | None] = []
55
55
  self.__prompt_assets: list[str | None] = []
56
56
  self.__leaderboards: list[RapidataLeaderboard] = []
57
57
  self.__identifiers: list[str] = []
58
58
  self.__tags: list[list[str]] = []
59
59
  self.__benchmark_page: str = (
60
- f"https://app.{self.__openapi_service.environment}/mri/benchmarks/{self.id}"
60
+ f"https://app.{self._openapi_service.environment}/mri/benchmarks/{self.id}"
61
61
  )
62
+ self._asset_uploader = AssetUploader(openapi_service)
62
63
 
63
64
  def __instantiate_prompts(self) -> None:
64
- current_page = 1
65
- total_pages = None
65
+ with tracer.start_as_current_span("RapidataBenchmark.__instantiate_prompts"):
66
+ current_page = 1
67
+ total_pages = None
66
68
 
67
- while True:
68
- prompts_result = (
69
- self.__openapi_service.benchmark_api.benchmark_benchmark_id_prompts_get(
69
+ while True:
70
+ prompts_result = self._openapi_service.benchmark_api.benchmark_benchmark_id_prompts_get(
70
71
  benchmark_id=self.id,
71
72
  request=QueryModel(page=PageInfo(index=current_page, size=100)),
72
73
  )
73
- )
74
74
 
75
- if prompts_result.total_pages is None:
76
- raise ValueError(
77
- "An error occurred while fetching prompts: total_pages is None"
78
- )
75
+ if prompts_result.total_pages is None:
76
+ raise ValueError(
77
+ "An error occurred while fetching prompts: total_pages is None"
78
+ )
79
79
 
80
- total_pages = prompts_result.total_pages
80
+ total_pages = prompts_result.total_pages
81
81
 
82
- for prompt in prompts_result.items:
83
- self.__prompts.append(prompt.prompt)
84
- self.__identifiers.append(prompt.identifier)
85
- if prompt.prompt_asset is None:
86
- self.__prompt_assets.append(None)
87
- else:
88
- assert isinstance(
89
- prompt.prompt_asset.actual_instance, FileAssetModel
90
- )
91
- source_url = prompt.prompt_asset.actual_instance.metadata[
92
- "sourceUrl"
93
- ].actual_instance
94
- assert isinstance(source_url, SourceUrlMetadataModel)
95
- self.__prompt_assets.append(source_url.url)
82
+ for prompt in prompts_result.items:
83
+ self.__prompts.append(prompt.prompt)
84
+ self.__identifiers.append(prompt.identifier)
85
+ if prompt.prompt_asset is None:
86
+ self.__prompt_assets.append(None)
87
+ else:
88
+ assert isinstance(
89
+ prompt.prompt_asset.actual_instance, FileAssetModel
90
+ )
91
+ source_url = prompt.prompt_asset.actual_instance.metadata[
92
+ "sourceUrl"
93
+ ].actual_instance
94
+ assert isinstance(source_url, SourceUrlMetadataModel)
95
+ self.__prompt_assets.append(source_url.url)
96
96
 
97
- self.__tags.append(prompt.tags)
98
- if current_page >= total_pages:
99
- break
97
+ self.__tags.append(prompt.tags)
98
+ if current_page >= total_pages:
99
+ break
100
100
 
101
- current_page += 1
101
+ current_page += 1
102
102
 
103
103
  @property
104
104
  def identifiers(self) -> list[str]:
105
- with tracer.start_as_current_span("RapidataBenchmark.identifiers"):
106
- if not self.__identifiers:
107
- self.__instantiate_prompts()
105
+ if not self.__identifiers:
106
+ self.__instantiate_prompts()
108
107
 
109
- return self.__identifiers
108
+ return self.__identifiers
110
109
 
111
110
  @property
112
111
  def prompts(self) -> list[str | None]:
113
112
  """
114
113
  Returns the prompts that are registered for the leaderboard.
115
114
  """
116
- with tracer.start_as_current_span("RapidataBenchmark.prompts"):
117
- if not self.__prompts:
118
- self.__instantiate_prompts()
115
+ if not self.__prompts:
116
+ self.__instantiate_prompts()
119
117
 
120
- return self.__prompts
118
+ return self.__prompts
121
119
 
122
120
  @property
123
121
  def prompt_assets(self) -> list[str | None]:
124
122
  """
125
123
  Returns the prompt assets that are registered for the benchmark.
126
124
  """
127
- with tracer.start_as_current_span("RapidataBenchmark.prompt_assets"):
128
- if not self.__prompt_assets:
129
- self.__instantiate_prompts()
125
+ if not self.__prompt_assets:
126
+ self.__instantiate_prompts()
130
127
 
131
- return self.__prompt_assets
128
+ return self.__prompt_assets
132
129
 
133
130
  @property
134
131
  def tags(self) -> list[list[str]]:
135
132
  """
136
133
  Returns the tags that are registered for the benchmark.
137
134
  """
138
- with tracer.start_as_current_span("RapidataBenchmark.tags"):
139
- if not self.__tags:
140
- self.__instantiate_prompts()
135
+ if not self.__tags:
136
+ self.__instantiate_prompts()
141
137
 
142
- return self.__tags
138
+ return self.__tags
143
139
 
144
140
  @property
145
141
  def leaderboards(self) -> list[RapidataLeaderboard]:
@@ -152,7 +148,7 @@ class RapidataBenchmark:
152
148
  total_pages = None
153
149
 
154
150
  while True:
155
- leaderboards_result = self.__openapi_service.benchmark_api.benchmark_benchmark_id_leaderboards_get(
151
+ leaderboards_result = self._openapi_service.benchmark_api.benchmark_benchmark_id_leaderboards_get(
156
152
  benchmark_id=self.id,
157
153
  request=QueryModel(
158
154
  page=PageInfo(index=current_page, size=100),
@@ -178,7 +174,7 @@ class RapidataBenchmark:
178
174
  leaderboard.min_responses,
179
175
  self.id,
180
176
  leaderboard.id,
181
- self.__openapi_service,
177
+ self._openapi_service,
182
178
  )
183
179
  for leaderboard in leaderboards_result.items
184
180
  ]
@@ -239,9 +235,6 @@ class RapidataBenchmark:
239
235
  if identifier in self.identifiers:
240
236
  raise ValueError("Identifier already exists in the benchmark.")
241
237
 
242
- if prompt_asset is not None and not re.match(r"^https?://", prompt_asset):
243
- raise ValueError("Prompt asset must be a link to the asset.")
244
-
245
238
  if tags is not None and (
246
239
  not isinstance(tags, list)
247
240
  or not all(isinstance(tag, str) for tag in tags)
@@ -263,14 +256,17 @@ class RapidataBenchmark:
263
256
  self.__prompts.append(prompt)
264
257
  self.__prompt_assets.append(prompt_asset)
265
258
 
266
- self.__openapi_service.benchmark_api.benchmark_benchmark_id_prompt_post(
259
+ self._openapi_service.benchmark_api.benchmark_benchmark_id_prompt_post(
267
260
  benchmark_id=self.id,
268
261
  submit_prompt_model=SubmitPromptModel(
269
262
  identifier=identifier,
270
263
  prompt=prompt,
271
264
  promptAsset=(
272
- SubmitPromptModelPromptAsset(
273
- UrlAssetInput(_t="UrlAssetInput", url=prompt_asset)
265
+ CreateDemographicRapidModelAsset(
266
+ actual_instance=ExistingAssetInput(
267
+ _t="ExistingAssetInput",
268
+ name=self._asset_uploader.upload_asset(prompt_asset),
269
+ )
274
270
  )
275
271
  if prompt_asset is not None
276
272
  else None
@@ -328,32 +324,30 @@ class RapidataBenchmark:
328
324
  settings,
329
325
  )
330
326
 
331
- leaderboard_result = (
332
- self.__openapi_service.leaderboard_api.leaderboard_post(
333
- create_leaderboard_model=CreateLeaderboardModel(
334
- benchmarkId=self.id,
335
- name=name,
336
- instruction=instruction,
337
- showPrompt=show_prompt,
338
- showPromptAsset=show_prompt_asset,
339
- isInversed=inverse_ranking,
340
- minResponses=min_responses_per_matchup,
341
- responseBudget=DetailMapper.get_budget(level_of_detail),
342
- validationSetId=validation_set_id,
343
- filters=(
344
- [
345
- AndUserFilterModelFiltersInner(filter._to_model())
346
- for filter in filters
347
- ]
348
- if filters
349
- else None
350
- ),
351
- featureFlags=(
352
- [setting._to_feature_flag() for setting in settings]
353
- if settings
354
- else None
355
- ),
356
- )
327
+ leaderboard_result = self._openapi_service.leaderboard_api.leaderboard_post(
328
+ create_leaderboard_model=CreateLeaderboardModel(
329
+ benchmarkId=self.id,
330
+ name=name,
331
+ instruction=instruction,
332
+ showPrompt=show_prompt,
333
+ showPromptAsset=show_prompt_asset,
334
+ isInversed=inverse_ranking,
335
+ minResponses=min_responses_per_matchup,
336
+ responseBudget=DetailMapper.get_budget(level_of_detail),
337
+ validationSetId=validation_set_id,
338
+ filters=(
339
+ [
340
+ AndUserFilterModelFiltersInner(filter._to_model())
341
+ for filter in filters
342
+ ]
343
+ if filters
344
+ else None
345
+ ),
346
+ featureFlags=(
347
+ [setting._to_feature_flag() for setting in settings]
348
+ if settings
349
+ else None
350
+ ),
357
351
  )
358
352
  )
359
353
 
@@ -373,7 +367,7 @@ class RapidataBenchmark:
373
367
  min_responses_per_matchup,
374
368
  self.id,
375
369
  leaderboard_result.id,
376
- self.__openapi_service,
370
+ self._openapi_service,
377
371
  )
378
372
 
379
373
  def evaluate_model(
@@ -421,12 +415,7 @@ class RapidataBenchmark:
421
415
  "All identifiers/prompts must be in the registered identifiers/prompts list. To see the registered identifiers/prompts, use the identifiers/prompts property."
422
416
  )
423
417
 
424
- # happens before the creation of the participant to ensure all media paths are valid
425
- assets: list[MediaAsset] = []
426
- for media_path in media:
427
- assets.append(MediaAsset(media_path))
428
-
429
- participant_result = self.__openapi_service.benchmark_api.benchmark_benchmark_id_participants_post(
418
+ participant_result = self._openapi_service.benchmark_api.benchmark_benchmark_id_participants_post(
430
419
  benchmark_id=self.id,
431
420
  create_benchmark_participant_model=CreateBenchmarkParticipantModel(
432
421
  name=name,
@@ -436,20 +425,20 @@ class RapidataBenchmark:
436
425
  logger.info(f"Participant created: {participant_result.participant_id}")
437
426
 
438
427
  participant = BenchmarkParticipant(
439
- name, participant_result.participant_id, self.__openapi_service
428
+ name, participant_result.participant_id, self._openapi_service
440
429
  )
441
430
 
442
431
  with tracer.start_as_current_span("upload_media_for_participant"):
443
432
  logger.info(
444
- f"Uploading {len(assets)} media assets to participant {participant.id}"
433
+ f"Uploading {len(media)} media assets to participant {participant.id}"
445
434
  )
446
435
 
447
436
  successful_uploads, failed_uploads = participant.upload_media(
448
- assets,
437
+ media,
449
438
  identifiers,
450
439
  )
451
440
 
452
- total_uploads = len(assets)
441
+ total_uploads = len(media)
453
442
  success_rate = (
454
443
  (len(successful_uploads) / total_uploads * 100)
455
444
  if total_uploads > 0
@@ -460,9 +449,7 @@ class RapidataBenchmark:
460
449
  )
461
450
 
462
451
  if failed_uploads:
463
- logger.error(
464
- f"Failed uploads for media: {[asset.path for asset in failed_uploads]}"
465
- )
452
+ logger.error(f"Failed uploads for media: {failed_uploads}")
466
453
  logger.warning(
467
454
  "Some uploads failed. The model evaluation may be incomplete."
468
455
  )
@@ -472,9 +459,9 @@ class RapidataBenchmark:
472
459
  "No uploads were successful. The model evaluation will not be completed."
473
460
  )
474
461
 
475
- self.__openapi_service.participant_api.participants_participant_id_submit_post(
476
- participant_id=participant_result.participant_id
477
- )
462
+ self._openapi_service.participant_api.participants_participant_id_submit_post(
463
+ participant_id=participant_result.participant_id
464
+ )
478
465
 
479
466
  def view(self) -> None:
480
467
  """
@@ -1,5 +1,4 @@
1
1
  from ._datapoint import Datapoint
2
- from .assets import MediaAsset, MultiAsset, TextAsset
3
2
  from .metadata import (
4
3
  Metadata,
5
4
  PromptMetadata,
@@ -0,0 +1,71 @@
1
+ import re
2
+ import os
3
+ from rapidata.api_client.models.existing_asset_input import ExistingAssetInput
4
+ from rapidata.api_client.models.multi_asset_input import (
5
+ MultiAssetInput,
6
+ MultiAssetInputAssetsInner,
7
+ )
8
+ from rapidata.api_client.models.text_asset_input import TextAssetInput
9
+ from rapidata.service.openapi_service import OpenAPIService
10
+ from rapidata.rapidata_client.config import logger
11
+ from rapidata.rapidata_client.config import tracer
12
+
13
+
14
+ class AssetUploader:
15
+ def __init__(self, openapi_service: OpenAPIService):
16
+ self.openapi_service = openapi_service
17
+
18
+ def upload_asset(self, asset: str) -> str:
19
+ with tracer.start_as_current_span("AssetUploader.upload_asset"):
20
+ logger.debug("Uploading asset: %s", asset)
21
+ assert isinstance(asset, str), "Asset must be a string"
22
+
23
+ if re.match(r"^https?://", asset):
24
+ response = self.openapi_service.asset_api.asset_url_post(
25
+ url=asset,
26
+ )
27
+ else:
28
+ if not os.path.exists(asset):
29
+ raise FileNotFoundError(f"File not found: {asset}")
30
+ response = self.openapi_service.asset_api.asset_file_post(
31
+ file=asset,
32
+ )
33
+ return response.file_name
34
+
35
+ def get_uploaded_text_input(
36
+ self, assets: list[str] | str
37
+ ) -> MultiAssetInput | TextAssetInput:
38
+ if isinstance(assets, list):
39
+ return MultiAssetInput(
40
+ _t="MultiAssetInput",
41
+ assets=[
42
+ MultiAssetInputAssetsInner(
43
+ actual_instance=TextAssetInput(_t="TextAssetInput", text=asset)
44
+ )
45
+ for asset in assets
46
+ ],
47
+ )
48
+ else:
49
+ return TextAssetInput(_t="TextAssetInput", text=assets)
50
+
51
+ def get_uploaded_asset_input(
52
+ self, assets: list[str] | str
53
+ ) -> MultiAssetInput | ExistingAssetInput:
54
+ if isinstance(assets, list):
55
+ return MultiAssetInput(
56
+ _t="MultiAssetInput",
57
+ assets=[
58
+ MultiAssetInputAssetsInner(
59
+ actual_instance=ExistingAssetInput(
60
+ _t="ExistingAssetInput",
61
+ name=self.upload_asset(asset),
62
+ ),
63
+ )
64
+ for asset in assets
65
+ ],
66
+ )
67
+ else:
68
+ return ExistingAssetInput(
69
+ _t="ExistingAssetInput",
70
+ name=self.upload_asset(assets),
71
+ )
@@ -1,196 +1,83 @@
1
- from typing import Sequence, cast
2
- from rapidata.rapidata_client.datapoints.assets import (
3
- MediaAsset,
4
- TextAsset,
5
- MultiAsset,
6
- BaseAsset,
7
- )
1
+ from typing import Literal
2
+
3
+ from pydantic import BaseModel, model_validator, field_validator
4
+ from typing_extensions import Self
8
5
  from rapidata.rapidata_client.datapoints.assets.constants import (
9
6
  ALLOWED_VIDEO_EXTENSIONS,
10
7
  ALLOWED_IMAGE_EXTENSIONS,
11
8
  ALLOWED_AUDIO_EXTENSIONS,
12
9
  )
13
- from rapidata.rapidata_client.datapoints.metadata import Metadata
14
- from rapidata.api_client.models.dataset_dataset_id_datapoints_post_request_metadata_inner import (
15
- DatasetDatasetIdDatapointsPostRequestMetadataInner,
16
- )
17
- from rapidata.api_client.models.create_datapoint_from_text_sources_model import (
18
- CreateDatapointFromTextSourcesModel,
19
- )
20
- from pydantic import StrictStr, StrictBytes
21
10
  from rapidata.api_client.models.asset_type import AssetType
22
11
  from rapidata.api_client.models.prompt_type import PromptType
23
- from rapidata.rapidata_client.datapoints.metadata._media_asset_metadata import (
24
- MediaAssetMetadata,
25
- )
26
- from rapidata.rapidata_client.datapoints.metadata._prompt_metadata import PromptMetadata
27
12
  from rapidata.rapidata_client.config import logger
28
13
 
29
14
 
30
- class Datapoint:
31
- def __init__(
32
- self,
33
- asset: MediaAsset | TextAsset | MultiAsset,
34
- metadata: Sequence[Metadata] | None = None,
35
- ):
36
- if not isinstance(asset, (MediaAsset, TextAsset, MultiAsset)):
37
- raise TypeError(
38
- "Asset must be of type MediaAsset, TextAsset, or MultiAsset."
39
- )
40
-
41
- if metadata and not isinstance(metadata, Sequence):
42
- raise TypeError("Metadata must be a list of Metadata objects.")
43
-
44
- if metadata and not all(isinstance(m, Metadata) for m in metadata):
45
- raise TypeError("All metadata objects must be of type Metadata.")
46
-
47
- self.asset = asset
48
- self.metadata = metadata
49
-
50
- def _get_effective_asset_type(self) -> type:
51
- """Get the effective asset type, handling MultiAsset by looking at its first asset."""
52
- if isinstance(self.asset, MultiAsset):
53
- return type(self.asset.assets[0])
54
- return type(self.asset)
15
+ class Datapoint(BaseModel):
16
+ asset: str | list[str]
17
+ data_type: Literal["text", "media"]
18
+ context: str | None = None
19
+ media_context: str | None = None
20
+ sentence: str | None = None
21
+ private_note: str | None = None
55
22
 
56
- def is_media_asset(self) -> bool:
57
- """Check if this datapoint contains media assets."""
58
- effective_type = self._get_effective_asset_type()
59
- return issubclass(effective_type, MediaAsset)
23
+ @field_validator("context")
24
+ @classmethod
25
+ def context_not_empty(cls, v: str | None) -> str | None:
26
+ if v is not None and v == "":
27
+ raise ValueError(
28
+ "context cannot be an empty string. If not needed, set to None."
29
+ )
30
+ return v
60
31
 
61
- def is_text_asset(self) -> bool:
62
- """Check if this datapoint contains text assets."""
63
- effective_type = self._get_effective_asset_type()
64
- return issubclass(effective_type, TextAsset)
32
+ @field_validator("media_context")
33
+ @classmethod
34
+ def media_context_not_empty(cls, v: str | None) -> str | None:
35
+ if v is not None and v == "":
36
+ raise ValueError(
37
+ "media_context cannot be an empty string. If not needed, set to None."
38
+ )
39
+ return v
40
+
41
+ @field_validator("sentence")
42
+ @classmethod
43
+ def sentence_has_space(cls, v: str | None) -> str | None:
44
+ if v is not None and len(v.split()) <= 1:
45
+ raise ValueError("sentence must contain at least two words.")
46
+ return v
47
+
48
+ @model_validator(mode="after")
49
+ def check_sentence_and_context(self) -> Self:
50
+ if isinstance(self.sentence, str) and isinstance(self.context, str):
51
+ raise ValueError(
52
+ "Both 'sentence' and 'context' cannot be strings at the same time."
53
+ )
54
+ return self
65
55
 
66
56
  def get_asset_type(self) -> AssetType:
67
- """Get the asset type of the datapoint."""
68
- if self.is_text_asset():
57
+ if self.data_type == "text":
69
58
  return AssetType.TEXT
70
- elif self.is_media_asset():
71
- if isinstance(self.asset, MultiAsset):
72
- asset = self.asset.assets[0]
73
- else:
74
- asset = self.asset
75
- assert isinstance(asset, MediaAsset)
76
- if any(asset.path.endswith(ext) for ext in ALLOWED_IMAGE_EXTENSIONS):
77
- return AssetType.IMAGE
78
- elif any(asset.path.endswith(ext) for ext in ALLOWED_VIDEO_EXTENSIONS):
79
- return AssetType.VIDEO
80
- elif any(asset.path.endswith(ext) for ext in ALLOWED_AUDIO_EXTENSIONS):
81
- return AssetType.AUDIO
82
- else:
83
- logger.debug(
84
- f"Cannot get asset type for asset type: {type(self.asset)}"
85
- )
86
- return AssetType.NONE
59
+
60
+ evaluation_asset = self.asset[0] if isinstance(self.asset, list) else self.asset
61
+ if any(evaluation_asset.endswith(ext) for ext in ALLOWED_IMAGE_EXTENSIONS):
62
+ return AssetType.IMAGE
63
+ elif any(evaluation_asset.endswith(ext) for ext in ALLOWED_VIDEO_EXTENSIONS):
64
+ return AssetType.VIDEO
65
+ elif any(evaluation_asset.endswith(ext) for ext in ALLOWED_AUDIO_EXTENSIONS):
66
+ return AssetType.AUDIO
87
67
  else:
88
- logger.debug(f"Cannot get asset type for asset type: {type(self.asset)}")
68
+ logger.debug(
69
+ f"Cannot get asset type for asset type: {type(evaluation_asset)}"
70
+ )
89
71
  return AssetType.NONE
90
72
 
91
73
  def get_prompt_type(self) -> list[PromptType]:
92
- """Get the prompt type of the datapoint."""
93
74
  prompt_types = []
94
- for metadata in self.metadata or []:
95
- if isinstance(metadata, MediaAssetMetadata):
96
- prompt_types.append(PromptType.ASSET)
97
- elif isinstance(metadata, PromptMetadata):
98
- prompt_types.append(PromptType.TEXT)
75
+ if self.context:
76
+ prompt_types.append(PromptType.TEXT)
77
+ if self.media_context:
78
+ prompt_types.append(PromptType.ASSET)
99
79
 
100
80
  if len(prompt_types) == 0:
101
81
  return [PromptType.NONE]
102
82
 
103
83
  return prompt_types
104
-
105
- def get_texts(self) -> list[str]:
106
- """Extract text content from the asset(s)."""
107
- if isinstance(self.asset, TextAsset):
108
- return [self.asset.text]
109
- elif isinstance(self.asset, MultiAsset):
110
- texts = []
111
- for asset in self.asset.assets:
112
- if isinstance(asset, TextAsset):
113
- texts.append(asset.text)
114
- return texts
115
- else:
116
- raise ValueError(f"Cannot extract text from asset type: {type(self.asset)}")
117
-
118
- def get_media_assets(self) -> list[MediaAsset]:
119
- """Extract media assets from the datapoint."""
120
- if isinstance(self.asset, MediaAsset):
121
- return [self.asset]
122
- elif isinstance(self.asset, MultiAsset):
123
- media_assets = []
124
- for asset in self.asset.assets:
125
- if isinstance(asset, MediaAsset):
126
- media_assets.append(asset)
127
- return media_assets
128
- else:
129
- raise ValueError(
130
- f"Cannot extract media assets from asset type: {type(self.asset)}"
131
- )
132
-
133
- def get_local_file_paths(
134
- self,
135
- ) -> list[StrictStr | tuple[StrictStr, StrictBytes] | StrictBytes]:
136
- """Get local file paths for media assets that are stored locally."""
137
- if not self.is_media_asset():
138
- return []
139
-
140
- media_assets = self.get_media_assets()
141
- return [asset.to_file() for asset in media_assets if asset.is_local()]
142
-
143
- def get_urls(self) -> list[str]:
144
- """Get URLs for media assets that are remote."""
145
- if not self.is_media_asset():
146
- return []
147
-
148
- media_assets = self.get_media_assets()
149
- return [asset.path for asset in media_assets if not asset.is_local()]
150
-
151
- def get_prepared_metadata(
152
- self,
153
- ) -> list[DatasetDatasetIdDatapointsPostRequestMetadataInner]:
154
- """Prepare metadata for API upload."""
155
- metadata: list[DatasetDatasetIdDatapointsPostRequestMetadataInner] = []
156
- if self.metadata:
157
- for meta in self.metadata:
158
- meta_model = meta.to_model() if meta else None
159
- if meta_model:
160
- metadata.append(
161
- DatasetDatasetIdDatapointsPostRequestMetadataInner(meta_model)
162
- )
163
- return metadata
164
-
165
- def create_text_upload_model(
166
- self, index: int
167
- ) -> CreateDatapointFromTextSourcesModel:
168
- """Create the model for uploading text datapoints."""
169
- if not self.is_text_asset():
170
- raise ValueError("Cannot create text upload model for non-text asset")
171
-
172
- texts = self.get_texts()
173
- metadata = self.get_prepared_metadata()
174
-
175
- return CreateDatapointFromTextSourcesModel(
176
- textSources=texts,
177
- sortIndex=index,
178
- metadata=metadata,
179
- )
180
-
181
- def get_datapoint_string(self) -> str:
182
- """Get the datapoint string for the datapoint."""
183
- if isinstance(self.asset, MediaAsset):
184
- return self.asset.path
185
- elif isinstance(self.asset, TextAsset):
186
- return self.asset.text
187
- else:
188
- raise ValueError(
189
- f"Cannot get datapoint string for asset type: {type(self.asset)}"
190
- )
191
-
192
- def __str__(self):
193
- return f"Datapoint(asset={self.asset})"
194
-
195
- def __repr__(self):
196
- return self.__str__()