rapidata 2.41.2__py3-none-any.whl → 2.42.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of rapidata might be problematic; see the package's registry page for more details.

Files changed (75)
  1. rapidata/__init__.py +1 -5
  2. rapidata/api_client/__init__.py +14 -14
  3. rapidata/api_client/api/__init__.py +1 -0
  4. rapidata/api_client/api/asset_api.py +851 -0
  5. rapidata/api_client/api/benchmark_api.py +298 -0
  6. rapidata/api_client/api/customer_rapid_api.py +29 -43
  7. rapidata/api_client/api/dataset_api.py +163 -1143
  8. rapidata/api_client/api/participant_api.py +28 -74
  9. rapidata/api_client/api/validation_set_api.py +283 -0
  10. rapidata/api_client/models/__init__.py +13 -14
  11. rapidata/api_client/models/add_validation_rapid_model.py +3 -3
  12. rapidata/api_client/models/add_validation_rapid_new_model.py +152 -0
  13. rapidata/api_client/models/add_validation_rapid_new_model_asset.py +182 -0
  14. rapidata/api_client/models/compare_workflow_model.py +3 -3
  15. rapidata/api_client/models/create_datapoint_from_files_model.py +3 -3
  16. rapidata/api_client/models/create_datapoint_from_text_sources_model.py +3 -3
  17. rapidata/api_client/models/create_datapoint_from_urls_model.py +3 -3
  18. rapidata/api_client/models/create_datapoint_model.py +108 -0
  19. rapidata/api_client/models/create_datapoint_model_asset.py +182 -0
  20. rapidata/api_client/models/create_demographic_rapid_model.py +13 -2
  21. rapidata/api_client/models/create_demographic_rapid_model_asset.py +188 -0
  22. rapidata/api_client/models/create_demographic_rapid_model_new.py +119 -0
  23. rapidata/api_client/models/create_sample_model.py +8 -2
  24. rapidata/api_client/models/create_sample_model_asset.py +182 -0
  25. rapidata/api_client/models/create_sample_model_obsolete.py +87 -0
  26. rapidata/api_client/models/file_asset_input_file.py +8 -22
  27. rapidata/api_client/models/fork_benchmark_result.py +87 -0
  28. rapidata/api_client/models/form_file_wrapper.py +17 -2
  29. rapidata/api_client/models/get_asset_metadata_result.py +100 -0
  30. rapidata/api_client/models/multi_asset_input_assets_inner.py +10 -24
  31. rapidata/api_client/models/prompt_asset_metadata_input.py +3 -3
  32. rapidata/api_client/models/proxy_file_wrapper.py +17 -2
  33. rapidata/api_client/models/stream_file_wrapper.py +25 -3
  34. rapidata/api_client/models/submit_prompt_model.py +3 -3
  35. rapidata/api_client/models/text_metadata.py +6 -1
  36. rapidata/api_client/models/text_metadata_model.py +7 -2
  37. rapidata/api_client/models/upload_file_from_url_result.py +87 -0
  38. rapidata/api_client/models/upload_file_result.py +87 -0
  39. rapidata/api_client/models/zip_entry_file_wrapper.py +33 -2
  40. rapidata/api_client_README.md +28 -25
  41. rapidata/rapidata_client/__init__.py +0 -1
  42. rapidata/rapidata_client/benchmark/participant/_participant.py +24 -22
  43. rapidata/rapidata_client/benchmark/rapidata_benchmark.py +89 -102
  44. rapidata/rapidata_client/datapoints/__init__.py +0 -1
  45. rapidata/rapidata_client/datapoints/_asset_uploader.py +71 -0
  46. rapidata/rapidata_client/datapoints/_datapoint.py +58 -171
  47. rapidata/rapidata_client/datapoints/_datapoint_uploader.py +95 -0
  48. rapidata/rapidata_client/datapoints/assets/__init__.py +0 -11
  49. rapidata/rapidata_client/datapoints/metadata/_media_asset_metadata.py +10 -7
  50. rapidata/rapidata_client/demographic/demographic_manager.py +21 -8
  51. rapidata/rapidata_client/exceptions/failed_upload_exception.py +0 -62
  52. rapidata/rapidata_client/order/_rapidata_order_builder.py +0 -10
  53. rapidata/rapidata_client/order/dataset/_rapidata_dataset.py +67 -187
  54. rapidata/rapidata_client/order/rapidata_order_manager.py +58 -116
  55. rapidata/rapidata_client/settings/translation_behaviour.py +1 -1
  56. rapidata/rapidata_client/validation/rapidata_validation_set.py +9 -5
  57. rapidata/rapidata_client/validation/rapids/_validation_rapid_uploader.py +101 -0
  58. rapidata/rapidata_client/validation/rapids/box.py +35 -11
  59. rapidata/rapidata_client/validation/rapids/rapids.py +26 -128
  60. rapidata/rapidata_client/validation/rapids/rapids_manager.py +123 -104
  61. rapidata/rapidata_client/validation/validation_set_manager.py +25 -34
  62. rapidata/rapidata_client/workflow/_ranking_workflow.py +14 -17
  63. rapidata/rapidata_client/workflow/_select_words_workflow.py +3 -16
  64. rapidata/service/openapi_service.py +8 -3
  65. {rapidata-2.41.2.dist-info → rapidata-2.42.0.dist-info}/METADATA +1 -1
  66. {rapidata-2.41.2.dist-info → rapidata-2.42.0.dist-info}/RECORD +68 -59
  67. {rapidata-2.41.2.dist-info → rapidata-2.42.0.dist-info}/WHEEL +1 -1
  68. rapidata/rapidata_client/datapoints/assets/_base_asset.py +0 -13
  69. rapidata/rapidata_client/datapoints/assets/_media_asset.py +0 -318
  70. rapidata/rapidata_client/datapoints/assets/_multi_asset.py +0 -61
  71. rapidata/rapidata_client/datapoints/assets/_sessions.py +0 -40
  72. rapidata/rapidata_client/datapoints/assets/_text_asset.py +0 -34
  73. rapidata/rapidata_client/datapoints/assets/data_type_enum.py +0 -8
  74. rapidata/rapidata_client/order/dataset/_progress_tracker.py +0 -100
  75. {rapidata-2.41.2.dist-info → rapidata-2.42.0.dist-info}/licenses/LICENSE +0 -0
@@ -2,16 +2,10 @@ from typing import Sequence, Optional, Literal
2
2
  from itertools import zip_longest
3
3
 
4
4
  from rapidata.rapidata_client.config.tracer import tracer
5
+ from rapidata.rapidata_client.datapoints.metadata._base_metadata import Metadata
5
6
  from rapidata.service.openapi_service import OpenAPIService
6
7
  from rapidata.rapidata_client.order.rapidata_order import RapidataOrder
7
8
  from rapidata.rapidata_client.order._rapidata_order_builder import RapidataOrderBuilder
8
- from rapidata.rapidata_client.datapoints.metadata import (
9
- PromptMetadata,
10
- SelectWordsMetadata,
11
- PrivateTextMetadata,
12
- MediaAssetMetadata,
13
- Metadata,
14
- )
15
9
  from rapidata.rapidata_client.referee._naive_referee import NaiveReferee
16
10
  from rapidata.rapidata_client.referee._early_stopping_referee import (
17
11
  EarlyStoppingReferee,
@@ -28,13 +22,17 @@ from rapidata.rapidata_client.workflow import (
28
22
  TimestampWorkflow,
29
23
  RankingWorkflow,
30
24
  )
31
- from rapidata.rapidata_client.datapoints.assets import MediaAsset, TextAsset, MultiAsset
32
25
  from rapidata.rapidata_client.datapoints._datapoint import Datapoint
26
+ from rapidata.rapidata_client.datapoints.metadata import (
27
+ PromptMetadata,
28
+ MediaAssetMetadata,
29
+ )
33
30
  from rapidata.rapidata_client.filter import RapidataFilter
34
31
  from rapidata.rapidata_client.filter.rapidata_filters import RapidataFilters
35
32
  from rapidata.rapidata_client.settings import RapidataSettings, RapidataSetting
36
33
  from rapidata.rapidata_client.selection.rapidata_selections import RapidataSelections
37
34
  from rapidata.rapidata_client.config import logger, rapidata_config
35
+ from rapidata.rapidata_client.datapoints._asset_uploader import AssetUploader
38
36
 
39
37
  from rapidata.api_client.models.query_model import QueryModel
40
38
  from rapidata.api_client.models.page_info import PageInfo
@@ -64,13 +62,15 @@ class RapidataOrderManager:
64
62
  self.selections = RapidataSelections
65
63
  self.__priority: int | None = None
66
64
  self.__sticky_state: Literal["None", "Temporary", "Permanent"] | None = None
65
+ self.__asset_uploader = AssetUploader(openapi_service)
67
66
  logger.debug("RapidataOrderManager initialized")
68
67
 
69
68
  def _create_general_order(
70
69
  self,
71
70
  name: str,
72
71
  workflow: Workflow,
73
- assets: list[MediaAsset] | list[TextAsset] | list[MultiAsset],
72
+ assets: list[str] | list[list[str]],
73
+ data_type: Literal["media", "text"] = "media",
74
74
  responses_per_datapoint: int = 10,
75
75
  contexts: list[str] | None = None,
76
76
  media_contexts: list[str] | None = None,
@@ -92,11 +92,6 @@ class RapidataOrderManager:
92
92
  if media_contexts and len(media_contexts) != len(assets):
93
93
  raise ValueError("Number of media contexts must match number of datapoints")
94
94
 
95
- if media_contexts:
96
- for media_context in media_contexts:
97
- if not media_context.startswith("http"):
98
- raise ValueError("Media contexts must all be URLs")
99
-
100
95
  if sentences and len(sentences) != len(assets):
101
96
  raise ValueError("Number of sentences must match number of datapoints")
102
97
 
@@ -114,11 +109,15 @@ class RapidataOrderManager:
114
109
  max_vote_count=responses_per_datapoint,
115
110
  )
116
111
 
112
+ if data_type not in ["media", "text"]:
113
+ raise ValueError("Data type must be one of 'media' or 'text'")
114
+
117
115
  logger.debug(
118
- "Creating order with parameters: name %s, workflow %s, assets %s, responses_per_datapoint %s, contexts %s, media_contexts %s, validation_set_id %s, confidence_threshold %s, filters %s, settings %s, sentences %s, selections %s, private_notes %s",
116
+ "Creating order with parameters: name %s, workflow %s, datapoints %s, data_type %s, responses_per_datapoint %s, contexts %s, media_contexts %s, validation_set_id %s, confidence_threshold %s, filters %s, settings %s, sentences %s, selections %s, private_notes %s",
119
117
  name,
120
118
  workflow,
121
119
  assets,
120
+ data_type,
122
121
  responses_per_datapoint,
123
122
  contexts,
124
123
  media_contexts,
@@ -140,45 +139,25 @@ class RapidataOrderManager:
140
139
  "Warning: Both selections and validation_set_id provided. Ignoring validation_set_id."
141
140
  )
142
141
 
143
- prompts_metadata = (
144
- [PromptMetadata(prompt=prompt) for prompt in contexts] if contexts else None
145
- )
146
- sentence_metadata = (
147
- [SelectWordsMetadata(select_words=sentence) for sentence in sentences]
148
- if sentences
149
- else None
150
- )
151
-
152
- if prompts_metadata and sentence_metadata:
153
- raise ValueError("You can only use contexts or sentences, not both")
154
-
155
- asset_metadata: Sequence[Metadata] = (
156
- [MediaAssetMetadata(url=context) for context in media_contexts]
157
- if media_contexts
158
- else []
159
- )
160
- prompt_metadata: Sequence[Metadata] = (
161
- prompts_metadata or sentence_metadata or []
162
- )
163
- private_notes_metadata: Sequence[Metadata] = (
164
- [PrivateTextMetadata(text=text) for text in private_notes]
165
- if private_notes
166
- else []
167
- )
168
-
169
- multi_metadata = [
170
- [item for item in items if item is not None]
171
- for items in zip_longest(
172
- prompt_metadata, asset_metadata, private_notes_metadata
173
- )
174
- ]
175
-
176
142
  order = (
177
143
  order_builder._workflow(workflow)
178
144
  ._datapoints(
179
145
  datapoints=[
180
- Datapoint(asset=asset, metadata=metadata)
181
- for asset, metadata in zip_longest(assets, multi_metadata)
146
+ Datapoint(
147
+ asset=asset,
148
+ data_type=data_type,
149
+ context=context,
150
+ media_context=media_context,
151
+ sentence=sentence,
152
+ private_note=private_note,
153
+ )
154
+ for asset, context, media_context, sentence, private_note in zip_longest(
155
+ assets,
156
+ contexts or [],
157
+ media_contexts or [],
158
+ sentences or [],
159
+ private_notes or [],
160
+ )
182
161
  ]
183
162
  )
184
163
  ._referee(referee)
@@ -261,21 +240,18 @@ class RapidataOrderManager:
261
240
  with tracer.start_as_current_span(
262
241
  "RapidataOrderManager.create_classification_order"
263
242
  ):
264
- if data_type == "media":
265
- assets = [MediaAsset(path=path) for path in datapoints]
266
- elif data_type == "text":
267
- assets = [TextAsset(text=text) for text in datapoints]
268
- else:
269
- raise ValueError(
270
- f"Unsupported data type: {data_type}, must be one of 'media' or 'text'"
271
- )
243
+ if not isinstance(datapoints, list) or not all(
244
+ isinstance(datapoint, str) for datapoint in datapoints
245
+ ):
246
+ raise ValueError("Datapoints must be a list of strings")
272
247
 
273
248
  return self._create_general_order(
274
249
  name=name,
275
250
  workflow=ClassifyWorkflow(
276
251
  instruction=instruction, answer_options=answer_options
277
252
  ),
278
- assets=assets,
253
+ assets=datapoints,
254
+ data_type=data_type,
279
255
  responses_per_datapoint=responses_per_datapoint,
280
256
  contexts=contexts,
281
257
  media_contexts=media_contexts,
@@ -342,7 +318,7 @@ class RapidataOrderManager:
342
318
  This will NOT be shown to the labelers but will be included in the result purely for your own reference.
343
319
  """
344
320
  with tracer.start_as_current_span("RapidataOrderManager.create_compare_order"):
345
- if any(type(datapoint) != list for datapoint in datapoints):
321
+ if any(not isinstance(datapoint, list) for datapoint in datapoints):
346
322
  raise ValueError("Each datapoint must be a list of 2 paths/texts")
347
323
 
348
324
  if any(len(datapoint) != 2 for datapoint in datapoints):
@@ -353,25 +329,11 @@ class RapidataOrderManager:
353
329
  "A_B_naming must be a list of exactly two strings or None"
354
330
  )
355
331
 
356
- if data_type == "media":
357
- assets = [
358
- MultiAsset([MediaAsset(path=path) for path in datapoint])
359
- for datapoint in datapoints
360
- ]
361
- elif data_type == "text":
362
- assets = [
363
- MultiAsset([TextAsset(text=text) for text in datapoint])
364
- for datapoint in datapoints
365
- ]
366
- else:
367
- raise ValueError(
368
- f"Unsupported data type: {data_type}, must be one of 'media' or 'text'"
369
- )
370
-
371
332
  return self._create_general_order(
372
333
  name=name,
373
334
  workflow=CompareWorkflow(instruction=instruction, a_b_names=a_b_names),
374
- assets=assets,
335
+ assets=datapoints,
336
+ data_type=data_type,
375
337
  responses_per_datapoint=responses_per_datapoint,
376
338
  contexts=contexts,
377
339
  media_contexts=media_contexts,
@@ -393,6 +355,7 @@ class RapidataOrderManager:
393
355
  data_type: Literal["media", "text"] = "media",
394
356
  random_comparisons_ratio: float = 0.5,
395
357
  context: Optional[str] = None,
358
+ media_context: Optional[str] = None,
396
359
  validation_set_id: Optional[str] = None,
397
360
  filters: Sequence[RapidataFilter] = [],
398
361
  settings: Sequence[RapidataSetting] = [],
@@ -416,6 +379,8 @@ class RapidataOrderManager:
416
379
  The rest will focus on pairing similarly ranked datapoints. Defaults to 0.5 and can be left untouched.
417
380
  context (str, optional): The context for all the comparison. Defaults to None.\n
418
381
  If provided will be shown in addition to the instruction for all the matchups.
382
+ media_context (str, optional): The media context for all the comparison. Defaults to None.\n
383
+ If provided will be shown in addition to the instruction for all the matchups.
419
384
  validation_set_id (str, optional): The ID of the validation set. Defaults to None.\n
420
385
  If provided, one validation task will be shown infront of the datapoints that will be labeled.
421
386
  filters (Sequence[RapidataFilter], optional): The list of filters for the order. Defaults to []. Decides who the tasks should be shown to.
@@ -427,13 +392,14 @@ class RapidataOrderManager:
427
392
  if len(datapoints) < 2:
428
393
  raise ValueError("At least two datapoints are required")
429
394
 
430
- if data_type == "media":
431
- assets = [MediaAsset(path=path) for path in datapoints]
432
- elif data_type == "text":
433
- assets = [TextAsset(text=text) for text in datapoints]
434
- else:
435
- raise ValueError(
436
- f"Unsupported data type: {data_type}, must be one of 'media' or 'text'"
395
+ metadatas: list[Metadata] = []
396
+ if context:
397
+ metadatas.append(PromptMetadata(context))
398
+ if media_context:
399
+ metadatas.append(
400
+ MediaAssetMetadata(
401
+ self.__asset_uploader.upload_asset(media_context)
402
+ )
437
403
  )
438
404
 
439
405
  return self._create_general_order(
@@ -442,9 +408,10 @@ class RapidataOrderManager:
442
408
  criteria=instruction,
443
409
  total_comparison_budget=total_comparison_budget,
444
410
  random_comparisons_ratio=random_comparisons_ratio,
445
- context=context,
411
+ metadatas=metadatas,
446
412
  ),
447
- assets=assets,
413
+ assets=datapoints,
414
+ data_type=data_type,
448
415
  responses_per_datapoint=responses_per_comparison,
449
416
  validation_set_id=validation_set_id,
450
417
  filters=filters,
@@ -494,20 +461,11 @@ class RapidataOrderManager:
494
461
  with tracer.start_as_current_span(
495
462
  "RapidataOrderManager.create_free_text_order"
496
463
  ):
497
-
498
- if data_type == "media":
499
- assets = [MediaAsset(path=path) for path in datapoints]
500
- elif data_type == "text":
501
- assets = [TextAsset(text=text) for text in datapoints]
502
- else:
503
- raise ValueError(
504
- f"Unsupported data type: {data_type}, must be one of 'media' or 'text'"
505
- )
506
-
507
464
  return self._create_general_order(
508
465
  name=name,
509
466
  workflow=FreeTextWorkflow(instruction=instruction),
510
- assets=assets,
467
+ assets=datapoints,
468
+ data_type=data_type,
511
469
  responses_per_datapoint=responses_per_datapoint,
512
470
  contexts=contexts,
513
471
  media_contexts=media_contexts,
@@ -555,14 +513,12 @@ class RapidataOrderManager:
555
513
  with tracer.start_as_current_span(
556
514
  "RapidataOrderManager.create_select_words_order"
557
515
  ):
558
- assets = [MediaAsset(path=path) for path in datapoints]
559
-
560
516
  return self._create_general_order(
561
517
  name=name,
562
518
  workflow=SelectWordsWorkflow(
563
519
  instruction=instruction,
564
520
  ),
565
- assets=assets,
521
+ assets=datapoints,
566
522
  responses_per_datapoint=responses_per_datapoint,
567
523
  validation_set_id=validation_set_id,
568
524
  filters=filters,
@@ -611,12 +567,11 @@ class RapidataOrderManager:
611
567
  This will NOT be shown to the labelers but will be included in the result purely for your own reference.
612
568
  """
613
569
  with tracer.start_as_current_span("RapidataOrderManager.create_locate_order"):
614
- assets = [MediaAsset(path=path) for path in datapoints]
615
570
 
616
571
  return self._create_general_order(
617
572
  name=name,
618
573
  workflow=LocateWorkflow(target=instruction),
619
- assets=assets,
574
+ assets=datapoints,
620
575
  responses_per_datapoint=responses_per_datapoint,
621
576
  contexts=contexts,
622
577
  media_contexts=media_contexts,
@@ -666,12 +621,11 @@ class RapidataOrderManager:
666
621
  This will NOT be shown to the labelers but will be included in the result purely for your own reference.
667
622
  """
668
623
  with tracer.start_as_current_span("RapidataOrderManager.create_draw_order"):
669
- assets = [MediaAsset(path=path) for path in datapoints]
670
624
 
671
625
  return self._create_general_order(
672
626
  name=name,
673
627
  workflow=DrawWorkflow(target=instruction),
674
- assets=assets,
628
+ assets=datapoints,
675
629
  responses_per_datapoint=responses_per_datapoint,
676
630
  contexts=contexts,
677
631
  media_contexts=media_contexts,
@@ -727,22 +681,10 @@ class RapidataOrderManager:
727
681
  with tracer.start_as_current_span(
728
682
  "RapidataOrderManager.create_timestamp_order"
729
683
  ):
730
- assets = [MediaAsset(path=path) for path in datapoints]
731
-
732
- for asset in tqdm(
733
- assets,
734
- desc="Downloading assets and checking duration",
735
- disable=rapidata_config.logging.silent_mode,
736
- ):
737
- if not asset.get_duration():
738
- raise ValueError(
739
- "The datapoints for this order must have a duration. (e.g. video or audio)"
740
- )
741
-
742
684
  return self._create_general_order(
743
685
  name=name,
744
686
  workflow=TimestampWorkflow(instruction=instruction),
745
- assets=assets,
687
+ assets=datapoints,
746
688
  responses_per_datapoint=responses_per_datapoint,
747
689
  contexts=contexts,
748
690
  media_contexts=media_contexts,
@@ -17,4 +17,4 @@ class TranslationBehaviour(RapidataSetting):
17
17
  if not isinstance(value, TranslationBehaviourOptions):
18
18
  raise ValueError("The value must be a TranslationBehaviourOptions.")
19
19
 
20
- super().__init__(key="translation_behaviour", value=value)
20
+ super().__init__(key="translation_behaviour", value=value.value)
@@ -8,6 +8,9 @@ from rapidata.api_client.models.update_validation_set_model import (
8
8
  UpdateValidationSetModel,
9
9
  )
10
10
  from rapidata.api_client.models.update_should_alert_model import UpdateShouldAlertModel
11
+ from rapidata.rapidata_client.validation.rapids._validation_rapid_uploader import (
12
+ ValidationRapidUploader,
13
+ )
11
14
 
12
15
 
13
16
  class RapidataValidationSet:
@@ -28,7 +31,8 @@ class RapidataValidationSet:
28
31
  self.validation_set_details_page = (
29
32
  f"https://app.{openapi_service.environment}/validation-set/detail/{self.id}"
30
33
  )
31
- self.__openapi_service = openapi_service
34
+ self._openapi_service = openapi_service
35
+ self.validation_rapid_uploader = ValidationRapidUploader(openapi_service)
32
36
 
33
37
  def add_rapid(self, rapid: Rapid):
34
38
  """Add a Rapid to the validation set.
@@ -38,7 +42,7 @@ class RapidataValidationSet:
38
42
  """
39
43
  with tracer.start_as_current_span("RapidataValidationSet.add_rapid"):
40
44
  logger.debug("Adding rapid %s to validation set %s", rapid, self.id)
41
- rapid._add_to_validation_set(self.id, self.__openapi_service)
45
+ self.validation_rapid_uploader.upload_rapid(rapid, self.id)
42
46
  return self
43
47
 
44
48
  def update_dimensions(self, dimensions: list[str]):
@@ -51,7 +55,7 @@ class RapidataValidationSet:
51
55
  logger.debug(
52
56
  "Updating dimensions for validation set %s to %s", self.id, dimensions
53
57
  )
54
- self.__openapi_service.validation_api.validation_set_validation_set_id_patch(
58
+ self._openapi_service.validation_api.validation_set_validation_set_id_patch(
55
59
  self.id, UpdateValidationSetModel(dimensions=dimensions)
56
60
  )
57
61
  return self
@@ -69,7 +73,7 @@ class RapidataValidationSet:
69
73
  logger.debug(
70
74
  "Setting shouldAlert for validation set %s to %s", self.id, should_alert
71
75
  )
72
- self.__openapi_service.validation_api.validation_set_validation_set_id_patch(
76
+ self._openapi_service.validation_api.validation_set_validation_set_id_patch(
73
77
  self.id, UpdateValidationSetModel(shouldAlert=should_alert)
74
78
  )
75
79
  return self
@@ -97,7 +101,7 @@ class RapidataValidationSet:
97
101
  """Deletes the validation set"""
98
102
  with tracer.start_as_current_span("RapidataValidationSet.delete"):
99
103
  logger.info("Deleting ValidationSet '%s'", self)
100
- self.__openapi_service.validation_api.validation_set_validation_set_id_delete(
104
+ self._openapi_service.validation_api.validation_set_validation_set_id_delete(
101
105
  self.id
102
106
  )
103
107
  logger.debug("ValidationSet '%s' has been deleted.", self)
@@ -0,0 +1,101 @@
1
+ from rapidata.rapidata_client.validation.rapids.rapids import Rapid
2
+ from rapidata.service.openapi_service import OpenAPIService
3
+ from rapidata.api_client.models.multi_asset_input_assets_inner import (
4
+ MultiAssetInput,
5
+ MultiAssetInputAssetsInner,
6
+ )
7
+ from rapidata.api_client.models.add_validation_rapid_new_model import (
8
+ AddValidationRapidNewModel,
9
+ )
10
+ from rapidata.api_client.models.add_validation_rapid_model_truth import (
11
+ AddValidationRapidModelTruth,
12
+ )
13
+ from rapidata.api_client.models.create_datapoint_from_files_model_metadata_inner import (
14
+ CreateDatapointFromFilesModelMetadataInner,
15
+ )
16
+ from rapidata.api_client.models.existing_asset_input import ExistingAssetInput
17
+ from rapidata.rapidata_client.datapoints._asset_uploader import AssetUploader
18
+ from rapidata.rapidata_client.datapoints.metadata import (
19
+ PromptMetadata,
20
+ MediaAssetMetadata,
21
+ SelectWordsMetadata,
22
+ Metadata,
23
+ )
24
+ from rapidata.api_client.models.add_validation_rapid_new_model_asset import (
25
+ AddValidationRapidNewModelAsset,
26
+ )
27
+ from rapidata.api_client.models.text_asset_input import TextAssetInput
28
+ from rapidata.api_client.models.add_validation_rapid_model_payload import (
29
+ AddValidationRapidModelPayload,
30
+ )
31
+
32
+
33
class ValidationRapidUploader:
    """Upload a `Rapid` (asset, metadata, payload, truth) into a validation set.

    The rapid's asset and optional media context are uploaded through an
    `AssetUploader` first. The resulting references are then posted to the
    validation API via `validation_set_validation_set_id_rapid_new_post`.
    """

    def __init__(self, openapi_service: OpenAPIService):
        """
        Args:
            openapi_service: Service used both for the asset uploads and for
                the validation-set API call.
        """
        self.openapi_service = openapi_service
        self.asset_uploader = AssetUploader(openapi_service)

    def upload_rapid(self, rapid: Rapid, validation_set_id: str) -> None:
        """Upload ``rapid`` and add it to the validation set ``validation_set_id``.

        Args:
            rapid: The rapid to add. Its ``data_type`` decides whether
                ``rapid.asset`` is uploaded as a media asset or as text.
            validation_set_id: ID of the target validation set.

        Raises:
            ValueError: If ``rapid.data_type`` is neither ``"media"`` nor ``"text"``.
        """
        # Pick the asset handler before anything is uploaded. An unsupported
        # data type then fails fast rather than being silently treated as
        # text after the media context has already been uploaded.
        if rapid.data_type == "media":
            handle_asset = self._handle_media_rapid
        elif rapid.data_type == "text":
            handle_asset = self._handle_text_rapid
        else:
            raise ValueError(
                f"Unsupported data type: {rapid.data_type}, must be one of 'media' or 'text'"
            )

        metadata = self._get_metadata(rapid)
        uploaded_asset = handle_asset(rapid)

        self.openapi_service.validation_api.validation_set_validation_set_id_rapid_new_post(
            validation_set_id=validation_set_id,
            add_validation_rapid_new_model=AddValidationRapidNewModel(
                asset=uploaded_asset,
                metadata=metadata,
                payload=self._get_payload(rapid),
                truth=AddValidationRapidModelTruth(actual_instance=rapid.truth),
                featureFlags=(
                    [setting._to_feature_flag() for setting in rapid.settings]
                    if rapid.settings
                    else None
                ),
            ),
        )

    def _get_payload(self, rapid: Rapid) -> AddValidationRapidModelPayload:
        """Wrap the rapid's payload; non-dict payloads are serialized via ``to_dict()``."""
        if isinstance(rapid.payload, dict):
            return AddValidationRapidModelPayload(actual_instance=rapid.payload)
        return AddValidationRapidModelPayload(actual_instance=rapid.payload.to_dict())

    def _get_metadata(
        self, rapid: Rapid
    ) -> list[CreateDatapointFromFilesModelMetadataInner]:
        """Build the API metadata list from the rapid's optional context fields.

        Order: prompt context, select-words sentence, then media context.
        The media context is uploaded here, and only the returned internal
        file name is stored in the metadata.
        """
        rapid_metadata: list[Metadata] = []
        if rapid.context:
            rapid_metadata.append(PromptMetadata(prompt=rapid.context))
        if rapid.sentence:
            rapid_metadata.append(SelectWordsMetadata(select_words=rapid.sentence))
        if rapid.media_context:
            rapid_metadata.append(
                MediaAssetMetadata(
                    internal_file_name=self.asset_uploader.upload_asset(
                        rapid.media_context
                    )
                )
            )

        return [
            CreateDatapointFromFilesModelMetadataInner(actual_instance=item.to_model())
            for item in rapid_metadata
        ]

    def _handle_text_rapid(self, rapid: Rapid) -> AddValidationRapidNewModelAsset:
        """Wrap ``rapid.asset`` as an uploaded text input."""
        return AddValidationRapidNewModelAsset(
            actual_instance=self.asset_uploader.get_uploaded_text_input(rapid.asset),
        )

    def _handle_media_rapid(self, rapid: Rapid) -> AddValidationRapidNewModelAsset:
        """Upload ``rapid.asset`` as media and wrap the resulting asset input."""
        return AddValidationRapidNewModelAsset(
            actual_instance=self.asset_uploader.get_uploaded_asset_input(rapid.asset),
        )
@@ -1,19 +1,43 @@
1
1
  from rapidata.api_client.models.box_shape import BoxShape
2
+ from pydantic import BaseModel, field_validator, model_validator
2
3
 
3
4
 
4
class Box(BaseModel):
    """
    Used in the Locate and Draw Validation sets. All coordinates are in ratio of the image size (0.0 to 1.0).

    Can be constructed positionally, ``Box(0.1, 0.1, 0.5, 0.5)``, or by keyword.

    Args:
        x_min (float): The minimum x value of the box in ratio of the image size.
        y_min (float): The minimum y value of the box in ratio of the image size.
        x_max (float): The maximum x value of the box in ratio of the image size.
        y_max (float): The maximum y value of the box in ratio of the image size.

    Raises:
        pydantic.ValidationError: If a coordinate lies outside [0, 1], or if
            x_min >= x_max or y_min >= y_max.
    """

    x_min: float
    y_min: float
    x_max: float
    y_max: float

    def __init__(
        self,
        x_min: float,
        y_min: float,
        x_max: float,
        y_max: float,
        **data: object,
    ) -> None:
        # pydantic's BaseModel.__init__ only accepts keyword arguments. Forwarding
        # explicitly keeps the positional form Box(x_min, y_min, x_max, y_max)
        # working, as the earlier plain-class constructor allowed.
        super().__init__(x_min=x_min, y_min=y_min, x_max=x_max, y_max=y_max, **data)

    @field_validator("x_min", "y_min", "x_max", "y_max")
    @classmethod
    def coordinates_between_zero_and_one(cls, v: float) -> float:
        """Reject coordinates outside the inclusive range [0, 1]."""
        if not (0.0 <= v <= 1.0):
            raise ValueError("Box coordinates must be between 0 and 1")
        return v

    @model_validator(mode="after")
    def check_min_less_than_max(self) -> "Box":
        """Reject empty or inverted boxes (min must be strictly below max)."""
        if self.x_min >= self.x_max:
            raise ValueError("x_min must be less than x_max")
        if self.y_min >= self.y_max:
            raise ValueError("y_min must be less than y_max")
        return self

    def to_model(self) -> BoxShape:
        """Convert to the API ``BoxShape``.

        Coordinates are scaled by 100. This presumably means BoxShape expects
        percent of the image size (0-100) — confirm against the API model.
        """
        return BoxShape(
            _t="BoxShape",
            xMin=self.x_min * 100,
            yMin=self.y_min * 100,
            xMax=self.x_max * 100,
            yMax=self.y_max * 100,
        )