rapidata 2.41.3__py3-none-any.whl → 2.42.1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of rapidata might be problematic; see the registry's release page for more details.

Files changed (74)
  1. rapidata/__init__.py +1 -5
  2. rapidata/api_client/__init__.py +14 -14
  3. rapidata/api_client/api/__init__.py +1 -0
  4. rapidata/api_client/api/asset_api.py +851 -0
  5. rapidata/api_client/api/benchmark_api.py +298 -0
  6. rapidata/api_client/api/customer_rapid_api.py +29 -43
  7. rapidata/api_client/api/dataset_api.py +163 -1143
  8. rapidata/api_client/api/participant_api.py +28 -74
  9. rapidata/api_client/api/validation_set_api.py +283 -0
  10. rapidata/api_client/models/__init__.py +13 -14
  11. rapidata/api_client/models/add_validation_rapid_model.py +3 -3
  12. rapidata/api_client/models/add_validation_rapid_new_model.py +152 -0
  13. rapidata/api_client/models/add_validation_rapid_new_model_asset.py +182 -0
  14. rapidata/api_client/models/compare_workflow_model.py +3 -3
  15. rapidata/api_client/models/create_datapoint_from_files_model.py +3 -3
  16. rapidata/api_client/models/create_datapoint_from_text_sources_model.py +3 -3
  17. rapidata/api_client/models/create_datapoint_from_urls_model.py +3 -3
  18. rapidata/api_client/models/create_datapoint_model.py +108 -0
  19. rapidata/api_client/models/create_datapoint_model_asset.py +182 -0
  20. rapidata/api_client/models/create_demographic_rapid_model.py +13 -2
  21. rapidata/api_client/models/create_demographic_rapid_model_asset.py +188 -0
  22. rapidata/api_client/models/create_demographic_rapid_model_new.py +119 -0
  23. rapidata/api_client/models/create_sample_model.py +8 -2
  24. rapidata/api_client/models/create_sample_model_asset.py +182 -0
  25. rapidata/api_client/models/create_sample_model_obsolete.py +87 -0
  26. rapidata/api_client/models/file_asset_input_file.py +8 -22
  27. rapidata/api_client/models/fork_benchmark_result.py +87 -0
  28. rapidata/api_client/models/form_file_wrapper.py +17 -2
  29. rapidata/api_client/models/get_asset_metadata_result.py +100 -0
  30. rapidata/api_client/models/multi_asset_input_assets_inner.py +10 -24
  31. rapidata/api_client/models/prompt_asset_metadata_input.py +3 -3
  32. rapidata/api_client/models/proxy_file_wrapper.py +17 -2
  33. rapidata/api_client/models/stream_file_wrapper.py +25 -3
  34. rapidata/api_client/models/submit_prompt_model.py +3 -3
  35. rapidata/api_client/models/text_metadata.py +6 -1
  36. rapidata/api_client/models/text_metadata_model.py +7 -2
  37. rapidata/api_client/models/upload_file_from_url_result.py +87 -0
  38. rapidata/api_client/models/upload_file_result.py +87 -0
  39. rapidata/api_client/models/zip_entry_file_wrapper.py +33 -2
  40. rapidata/api_client_README.md +28 -25
  41. rapidata/rapidata_client/__init__.py +0 -1
  42. rapidata/rapidata_client/benchmark/participant/_participant.py +25 -24
  43. rapidata/rapidata_client/benchmark/rapidata_benchmark.py +89 -102
  44. rapidata/rapidata_client/datapoints/__init__.py +0 -1
  45. rapidata/rapidata_client/datapoints/_asset_uploader.py +71 -0
  46. rapidata/rapidata_client/datapoints/_datapoint.py +58 -171
  47. rapidata/rapidata_client/datapoints/_datapoint_uploader.py +95 -0
  48. rapidata/rapidata_client/datapoints/assets/__init__.py +0 -11
  49. rapidata/rapidata_client/datapoints/metadata/_media_asset_metadata.py +10 -7
  50. rapidata/rapidata_client/demographic/demographic_manager.py +21 -8
  51. rapidata/rapidata_client/exceptions/failed_upload_exception.py +0 -62
  52. rapidata/rapidata_client/order/_rapidata_order_builder.py +0 -10
  53. rapidata/rapidata_client/order/dataset/_rapidata_dataset.py +65 -187
  54. rapidata/rapidata_client/order/rapidata_order_manager.py +62 -124
  55. rapidata/rapidata_client/validation/rapidata_validation_set.py +9 -5
  56. rapidata/rapidata_client/validation/rapids/_validation_rapid_uploader.py +101 -0
  57. rapidata/rapidata_client/validation/rapids/box.py +35 -11
  58. rapidata/rapidata_client/validation/rapids/rapids.py +26 -128
  59. rapidata/rapidata_client/validation/rapids/rapids_manager.py +123 -104
  60. rapidata/rapidata_client/validation/validation_set_manager.py +41 -38
  61. rapidata/rapidata_client/workflow/_ranking_workflow.py +14 -17
  62. rapidata/rapidata_client/workflow/_select_words_workflow.py +3 -16
  63. rapidata/service/openapi_service.py +8 -3
  64. {rapidata-2.41.3.dist-info → rapidata-2.42.1.dist-info}/METADATA +1 -1
  65. {rapidata-2.41.3.dist-info → rapidata-2.42.1.dist-info}/RECORD +67 -58
  66. {rapidata-2.41.3.dist-info → rapidata-2.42.1.dist-info}/WHEEL +1 -1
  67. rapidata/rapidata_client/datapoints/assets/_base_asset.py +0 -13
  68. rapidata/rapidata_client/datapoints/assets/_media_asset.py +0 -318
  69. rapidata/rapidata_client/datapoints/assets/_multi_asset.py +0 -61
  70. rapidata/rapidata_client/datapoints/assets/_sessions.py +0 -40
  71. rapidata/rapidata_client/datapoints/assets/_text_asset.py +0 -34
  72. rapidata/rapidata_client/datapoints/assets/data_type_enum.py +0 -8
  73. rapidata/rapidata_client/order/dataset/_progress_tracker.py +0 -100
  74. {rapidata-2.41.3.dist-info → rapidata-2.42.1.dist-info}/licenses/LICENSE +0 -0
@@ -2,16 +2,10 @@ from typing import Sequence, Optional, Literal
2
2
  from itertools import zip_longest
3
3
 
4
4
  from rapidata.rapidata_client.config.tracer import tracer
5
+ from rapidata.rapidata_client.datapoints.metadata._base_metadata import Metadata
5
6
  from rapidata.service.openapi_service import OpenAPIService
6
7
  from rapidata.rapidata_client.order.rapidata_order import RapidataOrder
7
8
  from rapidata.rapidata_client.order._rapidata_order_builder import RapidataOrderBuilder
8
- from rapidata.rapidata_client.datapoints.metadata import (
9
- PromptMetadata,
10
- SelectWordsMetadata,
11
- PrivateTextMetadata,
12
- MediaAssetMetadata,
13
- Metadata,
14
- )
15
9
  from rapidata.rapidata_client.referee._naive_referee import NaiveReferee
16
10
  from rapidata.rapidata_client.referee._early_stopping_referee import (
17
11
  EarlyStoppingReferee,
@@ -28,13 +22,17 @@ from rapidata.rapidata_client.workflow import (
28
22
  TimestampWorkflow,
29
23
  RankingWorkflow,
30
24
  )
31
- from rapidata.rapidata_client.datapoints.assets import MediaAsset, TextAsset, MultiAsset
32
25
  from rapidata.rapidata_client.datapoints._datapoint import Datapoint
26
+ from rapidata.rapidata_client.datapoints.metadata import (
27
+ PromptMetadata,
28
+ MediaAssetMetadata,
29
+ )
33
30
  from rapidata.rapidata_client.filter import RapidataFilter
34
31
  from rapidata.rapidata_client.filter.rapidata_filters import RapidataFilters
35
32
  from rapidata.rapidata_client.settings import RapidataSettings, RapidataSetting
36
33
  from rapidata.rapidata_client.selection.rapidata_selections import RapidataSelections
37
34
  from rapidata.rapidata_client.config import logger, rapidata_config
35
+ from rapidata.rapidata_client.datapoints._asset_uploader import AssetUploader
38
36
 
39
37
  from rapidata.api_client.models.query_model import QueryModel
40
38
  from rapidata.api_client.models.page_info import PageInfo
@@ -64,13 +62,15 @@ class RapidataOrderManager:
64
62
  self.selections = RapidataSelections
65
63
  self.__priority: int | None = None
66
64
  self.__sticky_state: Literal["None", "Temporary", "Permanent"] | None = None
65
+ self.__asset_uploader = AssetUploader(openapi_service)
67
66
  logger.debug("RapidataOrderManager initialized")
68
67
 
69
68
  def _create_general_order(
70
69
  self,
71
70
  name: str,
72
71
  workflow: Workflow,
73
- assets: list[MediaAsset] | list[TextAsset] | list[MultiAsset],
72
+ assets: list[str] | list[list[str]],
73
+ data_type: Literal["media", "text"] = "media",
74
74
  responses_per_datapoint: int = 10,
75
75
  contexts: list[str] | None = None,
76
76
  media_contexts: list[str] | None = None,
@@ -89,22 +89,9 @@ class RapidataOrderManager:
89
89
  if contexts and len(contexts) != len(assets):
90
90
  raise ValueError("Number of contexts must match number of datapoints")
91
91
 
92
- if contexts:
93
- if any(not isinstance(context, str) for context in contexts) or any(
94
- len(context) == 0 for context in contexts
95
- ):
96
- raise ValueError(
97
- "Contexts must all be strings that are not empty\nProvide list of strings or set contexts to None"
98
- )
99
-
100
92
  if media_contexts and len(media_contexts) != len(assets):
101
93
  raise ValueError("Number of media contexts must match number of datapoints")
102
94
 
103
- if media_contexts:
104
- for media_context in media_contexts:
105
- if not media_context.startswith("http"):
106
- raise ValueError("Media contexts must all be URLs")
107
-
108
95
  if sentences and len(sentences) != len(assets):
109
96
  raise ValueError("Number of sentences must match number of datapoints")
110
97
 
@@ -122,11 +109,15 @@ class RapidataOrderManager:
122
109
  max_vote_count=responses_per_datapoint,
123
110
  )
124
111
 
112
+ if data_type not in ["media", "text"]:
113
+ raise ValueError("Data type must be one of 'media' or 'text'")
114
+
125
115
  logger.debug(
126
- "Creating order with parameters: name %s, workflow %s, assets %s, responses_per_datapoint %s, contexts %s, media_contexts %s, validation_set_id %s, confidence_threshold %s, filters %s, settings %s, sentences %s, selections %s, private_notes %s",
116
+ "Creating order with parameters: name %s, workflow %s, datapoints %s, data_type %s, responses_per_datapoint %s, contexts %s, media_contexts %s, validation_set_id %s, confidence_threshold %s, filters %s, settings %s, sentences %s, selections %s, private_notes %s",
127
117
  name,
128
118
  workflow,
129
119
  assets,
120
+ data_type,
130
121
  responses_per_datapoint,
131
122
  contexts,
132
123
  media_contexts,
@@ -148,45 +139,25 @@ class RapidataOrderManager:
148
139
  "Warning: Both selections and validation_set_id provided. Ignoring validation_set_id."
149
140
  )
150
141
 
151
- prompts_metadata = (
152
- [PromptMetadata(prompt=prompt) for prompt in contexts] if contexts else None
153
- )
154
- sentence_metadata = (
155
- [SelectWordsMetadata(select_words=sentence) for sentence in sentences]
156
- if sentences
157
- else None
158
- )
159
-
160
- if prompts_metadata and sentence_metadata:
161
- raise ValueError("You can only use contexts or sentences, not both")
162
-
163
- asset_metadata: Sequence[Metadata] = (
164
- [MediaAssetMetadata(url=context) for context in media_contexts]
165
- if media_contexts
166
- else []
167
- )
168
- prompt_metadata: Sequence[Metadata] = (
169
- prompts_metadata or sentence_metadata or []
170
- )
171
- private_notes_metadata: Sequence[Metadata] = (
172
- [PrivateTextMetadata(text=text) for text in private_notes]
173
- if private_notes
174
- else []
175
- )
176
-
177
- multi_metadata = [
178
- [item for item in items if item is not None]
179
- for items in zip_longest(
180
- prompt_metadata, asset_metadata, private_notes_metadata
181
- )
182
- ]
183
-
184
142
  order = (
185
143
  order_builder._workflow(workflow)
186
144
  ._datapoints(
187
145
  datapoints=[
188
- Datapoint(asset=asset, metadata=metadata)
189
- for asset, metadata in zip_longest(assets, multi_metadata)
146
+ Datapoint(
147
+ asset=asset,
148
+ data_type=data_type,
149
+ context=context,
150
+ media_context=media_context,
151
+ sentence=sentence,
152
+ private_note=private_note,
153
+ )
154
+ for asset, context, media_context, sentence, private_note in zip_longest(
155
+ assets,
156
+ contexts or [],
157
+ media_contexts or [],
158
+ sentences or [],
159
+ private_notes or [],
160
+ )
190
161
  ]
191
162
  )
192
163
  ._referee(referee)
@@ -269,21 +240,18 @@ class RapidataOrderManager:
269
240
  with tracer.start_as_current_span(
270
241
  "RapidataOrderManager.create_classification_order"
271
242
  ):
272
- if data_type == "media":
273
- assets = [MediaAsset(path=path) for path in datapoints]
274
- elif data_type == "text":
275
- assets = [TextAsset(text=text) for text in datapoints]
276
- else:
277
- raise ValueError(
278
- f"Unsupported data type: {data_type}, must be one of 'media' or 'text'"
279
- )
243
+ if not isinstance(datapoints, list) or not all(
244
+ isinstance(datapoint, str) for datapoint in datapoints
245
+ ):
246
+ raise ValueError("Datapoints must be a list of strings")
280
247
 
281
248
  return self._create_general_order(
282
249
  name=name,
283
250
  workflow=ClassifyWorkflow(
284
251
  instruction=instruction, answer_options=answer_options
285
252
  ),
286
- assets=assets,
253
+ assets=datapoints,
254
+ data_type=data_type,
287
255
  responses_per_datapoint=responses_per_datapoint,
288
256
  contexts=contexts,
289
257
  media_contexts=media_contexts,
@@ -350,7 +318,7 @@ class RapidataOrderManager:
350
318
  This will NOT be shown to the labelers but will be included in the result purely for your own reference.
351
319
  """
352
320
  with tracer.start_as_current_span("RapidataOrderManager.create_compare_order"):
353
- if any(type(datapoint) != list for datapoint in datapoints):
321
+ if any(not isinstance(datapoint, list) for datapoint in datapoints):
354
322
  raise ValueError("Each datapoint must be a list of 2 paths/texts")
355
323
 
356
324
  if any(len(datapoint) != 2 for datapoint in datapoints):
@@ -361,25 +329,11 @@ class RapidataOrderManager:
361
329
  "A_B_naming must be a list of exactly two strings or None"
362
330
  )
363
331
 
364
- if data_type == "media":
365
- assets = [
366
- MultiAsset([MediaAsset(path=path) for path in datapoint])
367
- for datapoint in datapoints
368
- ]
369
- elif data_type == "text":
370
- assets = [
371
- MultiAsset([TextAsset(text=text) for text in datapoint])
372
- for datapoint in datapoints
373
- ]
374
- else:
375
- raise ValueError(
376
- f"Unsupported data type: {data_type}, must be one of 'media' or 'text'"
377
- )
378
-
379
332
  return self._create_general_order(
380
333
  name=name,
381
334
  workflow=CompareWorkflow(instruction=instruction, a_b_names=a_b_names),
382
- assets=assets,
335
+ assets=datapoints,
336
+ data_type=data_type,
383
337
  responses_per_datapoint=responses_per_datapoint,
384
338
  contexts=contexts,
385
339
  media_contexts=media_contexts,
@@ -401,6 +355,7 @@ class RapidataOrderManager:
401
355
  data_type: Literal["media", "text"] = "media",
402
356
  random_comparisons_ratio: float = 0.5,
403
357
  context: Optional[str] = None,
358
+ media_context: Optional[str] = None,
404
359
  validation_set_id: Optional[str] = None,
405
360
  filters: Sequence[RapidataFilter] = [],
406
361
  settings: Sequence[RapidataSetting] = [],
@@ -424,6 +379,8 @@ class RapidataOrderManager:
424
379
  The rest will focus on pairing similarly ranked datapoints. Defaults to 0.5 and can be left untouched.
425
380
  context (str, optional): The context for all the comparison. Defaults to None.\n
426
381
  If provided will be shown in addition to the instruction for all the matchups.
382
+ media_context (str, optional): The media context for all the comparison. Defaults to None.\n
383
+ If provided will be shown in addition to the instruction for all the matchups.
427
384
  validation_set_id (str, optional): The ID of the validation set. Defaults to None.\n
428
385
  If provided, one validation task will be shown infront of the datapoints that will be labeled.
429
386
  filters (Sequence[RapidataFilter], optional): The list of filters for the order. Defaults to []. Decides who the tasks should be shown to.
@@ -435,13 +392,18 @@ class RapidataOrderManager:
435
392
  if len(datapoints) < 2:
436
393
  raise ValueError("At least two datapoints are required")
437
394
 
438
- if data_type == "media":
439
- assets = [MediaAsset(path=path) for path in datapoints]
440
- elif data_type == "text":
441
- assets = [TextAsset(text=text) for text in datapoints]
442
- else:
443
- raise ValueError(
444
- f"Unsupported data type: {data_type}, must be one of 'media' or 'text'"
395
+ metadatas: list[Metadata] = []
396
+ if context:
397
+ if not isinstance(context, str) or context == "":
398
+ raise ValueError("Context must be a non-empty string")
399
+ metadatas.append(PromptMetadata(context))
400
+ if media_context:
401
+ if not isinstance(media_context, str) or media_context == "":
402
+ raise ValueError("Media context must be a non-empty string")
403
+ metadatas.append(
404
+ MediaAssetMetadata(
405
+ self.__asset_uploader.upload_asset(media_context)
406
+ )
445
407
  )
446
408
 
447
409
  return self._create_general_order(
@@ -450,9 +412,10 @@ class RapidataOrderManager:
450
412
  criteria=instruction,
451
413
  total_comparison_budget=total_comparison_budget,
452
414
  random_comparisons_ratio=random_comparisons_ratio,
453
- context=context,
415
+ metadatas=metadatas,
454
416
  ),
455
- assets=assets,
417
+ assets=datapoints,
418
+ data_type=data_type,
456
419
  responses_per_datapoint=responses_per_comparison,
457
420
  validation_set_id=validation_set_id,
458
421
  filters=filters,
@@ -502,20 +465,11 @@ class RapidataOrderManager:
502
465
  with tracer.start_as_current_span(
503
466
  "RapidataOrderManager.create_free_text_order"
504
467
  ):
505
-
506
- if data_type == "media":
507
- assets = [MediaAsset(path=path) for path in datapoints]
508
- elif data_type == "text":
509
- assets = [TextAsset(text=text) for text in datapoints]
510
- else:
511
- raise ValueError(
512
- f"Unsupported data type: {data_type}, must be one of 'media' or 'text'"
513
- )
514
-
515
468
  return self._create_general_order(
516
469
  name=name,
517
470
  workflow=FreeTextWorkflow(instruction=instruction),
518
- assets=assets,
471
+ assets=datapoints,
472
+ data_type=data_type,
519
473
  responses_per_datapoint=responses_per_datapoint,
520
474
  contexts=contexts,
521
475
  media_contexts=media_contexts,
@@ -563,14 +517,12 @@ class RapidataOrderManager:
563
517
  with tracer.start_as_current_span(
564
518
  "RapidataOrderManager.create_select_words_order"
565
519
  ):
566
- assets = [MediaAsset(path=path) for path in datapoints]
567
-
568
520
  return self._create_general_order(
569
521
  name=name,
570
522
  workflow=SelectWordsWorkflow(
571
523
  instruction=instruction,
572
524
  ),
573
- assets=assets,
525
+ assets=datapoints,
574
526
  responses_per_datapoint=responses_per_datapoint,
575
527
  validation_set_id=validation_set_id,
576
528
  filters=filters,
@@ -619,12 +571,11 @@ class RapidataOrderManager:
619
571
  This will NOT be shown to the labelers but will be included in the result purely for your own reference.
620
572
  """
621
573
  with tracer.start_as_current_span("RapidataOrderManager.create_locate_order"):
622
- assets = [MediaAsset(path=path) for path in datapoints]
623
574
 
624
575
  return self._create_general_order(
625
576
  name=name,
626
577
  workflow=LocateWorkflow(target=instruction),
627
- assets=assets,
578
+ assets=datapoints,
628
579
  responses_per_datapoint=responses_per_datapoint,
629
580
  contexts=contexts,
630
581
  media_contexts=media_contexts,
@@ -674,12 +625,11 @@ class RapidataOrderManager:
674
625
  This will NOT be shown to the labelers but will be included in the result purely for your own reference.
675
626
  """
676
627
  with tracer.start_as_current_span("RapidataOrderManager.create_draw_order"):
677
- assets = [MediaAsset(path=path) for path in datapoints]
678
628
 
679
629
  return self._create_general_order(
680
630
  name=name,
681
631
  workflow=DrawWorkflow(target=instruction),
682
- assets=assets,
632
+ assets=datapoints,
683
633
  responses_per_datapoint=responses_per_datapoint,
684
634
  contexts=contexts,
685
635
  media_contexts=media_contexts,
@@ -735,22 +685,10 @@ class RapidataOrderManager:
735
685
  with tracer.start_as_current_span(
736
686
  "RapidataOrderManager.create_timestamp_order"
737
687
  ):
738
- assets = [MediaAsset(path=path) for path in datapoints]
739
-
740
- for asset in tqdm(
741
- assets,
742
- desc="Downloading assets and checking duration",
743
- disable=rapidata_config.logging.silent_mode,
744
- ):
745
- if not asset.get_duration():
746
- raise ValueError(
747
- "The datapoints for this order must have a duration. (e.g. video or audio)"
748
- )
749
-
750
688
  return self._create_general_order(
751
689
  name=name,
752
690
  workflow=TimestampWorkflow(instruction=instruction),
753
- assets=assets,
691
+ assets=datapoints,
754
692
  responses_per_datapoint=responses_per_datapoint,
755
693
  contexts=contexts,
756
694
  media_contexts=media_contexts,
@@ -8,6 +8,9 @@ from rapidata.api_client.models.update_validation_set_model import (
8
8
  UpdateValidationSetModel,
9
9
  )
10
10
  from rapidata.api_client.models.update_should_alert_model import UpdateShouldAlertModel
11
+ from rapidata.rapidata_client.validation.rapids._validation_rapid_uploader import (
12
+ ValidationRapidUploader,
13
+ )
11
14
 
12
15
 
13
16
  class RapidataValidationSet:
@@ -28,7 +31,8 @@ class RapidataValidationSet:
28
31
  self.validation_set_details_page = (
29
32
  f"https://app.{openapi_service.environment}/validation-set/detail/{self.id}"
30
33
  )
31
- self.__openapi_service = openapi_service
34
+ self._openapi_service = openapi_service
35
+ self.validation_rapid_uploader = ValidationRapidUploader(openapi_service)
32
36
 
33
37
  def add_rapid(self, rapid: Rapid):
34
38
  """Add a Rapid to the validation set.
@@ -38,7 +42,7 @@ class RapidataValidationSet:
38
42
  """
39
43
  with tracer.start_as_current_span("RapidataValidationSet.add_rapid"):
40
44
  logger.debug("Adding rapid %s to validation set %s", rapid, self.id)
41
- rapid._add_to_validation_set(self.id, self.__openapi_service)
45
+ self.validation_rapid_uploader.upload_rapid(rapid, self.id)
42
46
  return self
43
47
 
44
48
  def update_dimensions(self, dimensions: list[str]):
@@ -51,7 +55,7 @@ class RapidataValidationSet:
51
55
  logger.debug(
52
56
  "Updating dimensions for validation set %s to %s", self.id, dimensions
53
57
  )
54
- self.__openapi_service.validation_api.validation_set_validation_set_id_patch(
58
+ self._openapi_service.validation_api.validation_set_validation_set_id_patch(
55
59
  self.id, UpdateValidationSetModel(dimensions=dimensions)
56
60
  )
57
61
  return self
@@ -69,7 +73,7 @@ class RapidataValidationSet:
69
73
  logger.debug(
70
74
  "Setting shouldAlert for validation set %s to %s", self.id, should_alert
71
75
  )
72
- self.__openapi_service.validation_api.validation_set_validation_set_id_patch(
76
+ self._openapi_service.validation_api.validation_set_validation_set_id_patch(
73
77
  self.id, UpdateValidationSetModel(shouldAlert=should_alert)
74
78
  )
75
79
  return self
@@ -97,7 +101,7 @@ class RapidataValidationSet:
97
101
  """Deletes the validation set"""
98
102
  with tracer.start_as_current_span("RapidataValidationSet.delete"):
99
103
  logger.info("Deleting ValidationSet '%s'", self)
100
- self.__openapi_service.validation_api.validation_set_validation_set_id_delete(
104
+ self._openapi_service.validation_api.validation_set_validation_set_id_delete(
101
105
  self.id
102
106
  )
103
107
  logger.debug("ValidationSet '%s' has been deleted.", self)
@@ -0,0 +1,101 @@
1
+ from rapidata.rapidata_client.validation.rapids.rapids import Rapid
2
+ from rapidata.service.openapi_service import OpenAPIService
3
+ from rapidata.api_client.models.multi_asset_input_assets_inner import (
4
+ MultiAssetInput,
5
+ MultiAssetInputAssetsInner,
6
+ )
7
+ from rapidata.api_client.models.add_validation_rapid_new_model import (
8
+ AddValidationRapidNewModel,
9
+ )
10
+ from rapidata.api_client.models.add_validation_rapid_model_truth import (
11
+ AddValidationRapidModelTruth,
12
+ )
13
+ from rapidata.api_client.models.create_datapoint_from_files_model_metadata_inner import (
14
+ CreateDatapointFromFilesModelMetadataInner,
15
+ )
16
+ from rapidata.api_client.models.existing_asset_input import ExistingAssetInput
17
+ from rapidata.rapidata_client.datapoints._asset_uploader import AssetUploader
18
+ from rapidata.rapidata_client.datapoints.metadata import (
19
+ PromptMetadata,
20
+ MediaAssetMetadata,
21
+ SelectWordsMetadata,
22
+ Metadata,
23
+ )
24
+ from rapidata.api_client.models.add_validation_rapid_new_model_asset import (
25
+ AddValidationRapidNewModelAsset,
26
+ )
27
+ from rapidata.api_client.models.text_asset_input import TextAssetInput
28
+ from rapidata.api_client.models.add_validation_rapid_model_payload import (
29
+ AddValidationRapidModelPayload,
30
+ )
31
+
32
+
33
class ValidationRapidUploader:
    """Upload a single validation rapid to a validation set.

    The rapid's main asset and any media context are first uploaded through an
    ``AssetUploader``. The resulting references are then sent, together with
    metadata, payload, truth and feature flags, in one
    ``validation_set_validation_set_id_rapid_new_post`` request.
    """

    def __init__(self, openapi_service: OpenAPIService):
        """Store the API service and create an asset uploader bound to it.

        Args:
            openapi_service: Service whose ``validation_api`` receives the rapid.
                It is also passed to the ``AssetUploader``.
        """
        self.openapi_service = openapi_service
        self.asset_uploader = AssetUploader(openapi_service)

    def upload_rapid(self, rapid: Rapid, validation_set_id: str) -> None:
        """Upload ``rapid`` and attach it to the validation set ``validation_set_id``.

        Args:
            rapid: The rapid to upload. ``rapid.data_type`` selects whether
                ``rapid.asset`` is uploaded as a media asset or as text.
            validation_set_id: ID of the validation set that receives the rapid.

        Raises:
            ValueError: If ``rapid.data_type`` is neither ``"media"`` nor ``"text"``.
        """
        # Reject unknown data types up front. Previously anything that was not
        # "media" was silently uploaded as text. Checking first also avoids
        # uploading a media-context asset for a rapid that is then rejected.
        if rapid.data_type not in ("media", "text"):
            raise ValueError(
                f"Unsupported data type: {rapid.data_type}, must be one of 'media' or 'text'"
            )

        metadata = self._get_metadata(rapid)

        uploaded_asset = (
            self._handle_media_rapid(rapid)
            if rapid.data_type == "media"
            else self._handle_text_rapid(rapid)
        )

        self.openapi_service.validation_api.validation_set_validation_set_id_rapid_new_post(
            validation_set_id=validation_set_id,
            add_validation_rapid_new_model=AddValidationRapidNewModel(
                asset=uploaded_asset,
                metadata=metadata,
                payload=self._get_payload(rapid),
                truth=AddValidationRapidModelTruth(actual_instance=rapid.truth),
                # An empty or missing settings collection is sent as None
                # rather than an empty list.
                featureFlags=(
                    [setting._to_feature_flag() for setting in rapid.settings]
                    if rapid.settings
                    else None
                ),
            ),
        )

    def _get_payload(self, rapid: Rapid) -> AddValidationRapidModelPayload:
        """Wrap the rapid's payload in the API payload model.

        A dict payload is passed through as-is. Any other payload is converted
        with its ``to_dict()`` method first.
        """
        if isinstance(rapid.payload, dict):
            return AddValidationRapidModelPayload(actual_instance=rapid.payload)
        return AddValidationRapidModelPayload(actual_instance=rapid.payload.to_dict())

    def _get_metadata(
        self, rapid: Rapid
    ) -> list[CreateDatapointFromFilesModelMetadataInner]:
        """Build the API metadata list from the rapid's optional context fields.

        Entries are produced in this order when the corresponding field is
        truthy:

        - ``context`` → ``PromptMetadata``
        - ``sentence`` → ``SelectWordsMetadata``
        - ``media_context`` → ``MediaAssetMetadata``

        The media context is uploaded first, and its returned file name is
        referenced by the metadata entry.
        """
        rapid_metadata: list[Metadata] = []
        if rapid.context:
            rapid_metadata.append(PromptMetadata(prompt=rapid.context))
        if rapid.sentence:
            rapid_metadata.append(SelectWordsMetadata(select_words=rapid.sentence))
        if rapid.media_context:
            rapid_metadata.append(
                MediaAssetMetadata(
                    internal_file_name=self.asset_uploader.upload_asset(
                        rapid.media_context
                    )
                )
            )

        return [
            CreateDatapointFromFilesModelMetadataInner(actual_instance=item.to_model())
            for item in rapid_metadata
        ]

    def _handle_text_rapid(self, rapid: Rapid) -> AddValidationRapidNewModelAsset:
        """Wrap the text asset input produced by the asset uploader for ``rapid.asset``."""
        return AddValidationRapidNewModelAsset(
            actual_instance=self.asset_uploader.get_uploaded_text_input(rapid.asset),
        )

    def _handle_media_rapid(self, rapid: Rapid) -> AddValidationRapidNewModelAsset:
        """Upload ``rapid.asset`` as media and wrap the resulting asset input."""
        return AddValidationRapidNewModelAsset(
            actual_instance=self.asset_uploader.get_uploaded_asset_input(rapid.asset),
        )
@@ -1,19 +1,43 @@
1
1
  from rapidata.api_client.models.box_shape import BoxShape
2
+ from pydantic import BaseModel, field_validator, model_validator
2
3
 
3
4
 
4
class Box(BaseModel):
    """
    Axis-aligned bounding box for the Locate and Draw Validation sets.

    Coordinates are given as fractions of the image size, from 0.0 to 1.0.
    Each min must be strictly smaller than the matching max.

    Args:
        x_min (float): Left edge, as a fraction of the image width.
        y_min (float): Top edge, as a fraction of the image height.
        x_max (float): Right edge, as a fraction of the image width.
        y_max (float): Bottom edge, as a fraction of the image height.
    """

    x_min: float
    y_min: float
    x_max: float
    y_max: float

    @field_validator("x_min", "y_min", "x_max", "y_max")
    @classmethod
    def coordinates_between_zero_and_one(cls, value: float) -> float:
        # The chained comparison is False for NaN, so NaN is rejected as well.
        if 0.0 <= value <= 1.0:
            return value
        raise ValueError("Box coordinates must be between 0 and 1")

    @model_validator(mode="after")
    def check_min_less_than_max(self) -> "Box":
        # The x axis is checked before the y axis, so an x error wins when
        # both axes are invalid.
        axis_checks = (
            (self.x_min, self.x_max, "x_min must be less than x_max"),
            (self.y_min, self.y_max, "y_min must be less than y_max"),
        )
        for lower, upper, message in axis_checks:
            if lower >= upper:
                raise ValueError(message)
        return self

    def to_model(self) -> BoxShape:
        # The API model receives the ratios scaled by 100
        # (presumably percent of the image size; confirm against the API).
        scale = 100
        return BoxShape(
            _t="BoxShape",
            xMin=self.x_min * scale,
            yMin=self.y_min * scale,
            xMax=self.x_max * scale,
            yMax=self.y_max * scale,
        )