rapidata 2.1.4__py3-none-any.whl → 2.2.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of rapidata might be problematic. Click here for more details.

Files changed (35) hide show
  1. rapidata/api_client/__init__.py +7 -0
  2. rapidata/api_client/models/__init__.py +7 -0
  3. rapidata/api_client/models/add_validation_rapid_model_payload.py +30 -16
  4. rapidata/api_client/models/add_validation_rapid_model_truth.py +30 -16
  5. rapidata/api_client/models/query_validation_rapids_result.py +29 -2
  6. rapidata/api_client/models/query_validation_rapids_result_payload.py +252 -0
  7. rapidata/api_client/models/query_validation_rapids_result_truth.py +258 -0
  8. rapidata/api_client/models/rapid_answer_result.py +31 -17
  9. rapidata/api_client/models/rapid_result_model_result.py +31 -17
  10. rapidata/api_client/models/scrub_payload.py +96 -0
  11. rapidata/api_client/models/scrub_range.py +89 -0
  12. rapidata/api_client/models/scrub_rapid_blueprint.py +96 -0
  13. rapidata/api_client/models/scrub_result.py +98 -0
  14. rapidata/api_client/models/scrub_truth.py +104 -0
  15. rapidata/api_client/models/simple_workflow_config_model_blueprint.py +30 -16
  16. rapidata/api_client/models/simple_workflow_model_blueprint.py +30 -16
  17. rapidata/api_client/models/validation_import_post_request_blueprint.py +30 -16
  18. rapidata/api_client_README.md +7 -0
  19. rapidata/rapidata_client/assets/_media_asset.py +45 -0
  20. rapidata/rapidata_client/order/_rapidata_dataset.py +4 -1
  21. rapidata/rapidata_client/order/rapidata_order_manager.py +46 -1
  22. rapidata/rapidata_client/selection/rapidata_selections.py +4 -2
  23. rapidata/rapidata_client/validation/_validation_rapid_parts.py +5 -0
  24. rapidata/rapidata_client/validation/_validation_set_builder.py +107 -7
  25. rapidata/rapidata_client/validation/rapidata_validation_set.py +4 -0
  26. rapidata/rapidata_client/validation/rapids/rapids.py +34 -0
  27. rapidata/rapidata_client/validation/rapids/rapids_manager.py +29 -1
  28. rapidata/rapidata_client/validation/validation_set_manager.py +50 -0
  29. rapidata/rapidata_client/workflow/__init__.py +1 -0
  30. rapidata/rapidata_client/workflow/_select_words_workflow.py +2 -2
  31. rapidata/rapidata_client/workflow/_timestamp_workflow.py +34 -0
  32. {rapidata-2.1.4.dist-info → rapidata-2.2.0.dist-info}/METADATA +2 -1
  33. {rapidata-2.1.4.dist-info → rapidata-2.2.0.dist-info}/RECORD +35 -27
  34. {rapidata-2.1.4.dist-info → rapidata-2.2.0.dist-info}/WHEEL +1 -1
  35. {rapidata-2.1.4.dist-info → rapidata-2.2.0.dist-info}/LICENSE +0 -0
@@ -25,12 +25,13 @@ from rapidata.api_client.models.line_rapid_blueprint import LineRapidBlueprint
25
25
  from rapidata.api_client.models.locate_rapid_blueprint import LocateRapidBlueprint
26
26
  from rapidata.api_client.models.named_entity_rapid_blueprint import NamedEntityRapidBlueprint
27
27
  from rapidata.api_client.models.polygon_rapid_blueprint import PolygonRapidBlueprint
28
+ from rapidata.api_client.models.scrub_rapid_blueprint import ScrubRapidBlueprint
28
29
  from rapidata.api_client.models.transcription_rapid_blueprint import TranscriptionRapidBlueprint
29
30
  from pydantic import StrictStr, Field
30
31
  from typing import Union, List, Set, Optional, Dict
31
32
  from typing_extensions import Literal, Self
32
33
 
33
- VALIDATIONIMPORTPOSTREQUESTBLUEPRINT_ONE_OF_SCHEMAS = ["AttachCategoryRapidBlueprint", "BoundingBoxRapidBlueprint", "CompareRapidBlueprint", "FreeTextRapidBlueprint", "LineRapidBlueprint", "LocateRapidBlueprint", "NamedEntityRapidBlueprint", "PolygonRapidBlueprint", "TranscriptionRapidBlueprint"]
34
+ VALIDATIONIMPORTPOSTREQUESTBLUEPRINT_ONE_OF_SCHEMAS = ["AttachCategoryRapidBlueprint", "BoundingBoxRapidBlueprint", "CompareRapidBlueprint", "FreeTextRapidBlueprint", "LineRapidBlueprint", "LocateRapidBlueprint", "NamedEntityRapidBlueprint", "PolygonRapidBlueprint", "ScrubRapidBlueprint", "TranscriptionRapidBlueprint"]
34
35
 
35
36
  class ValidationImportPostRequestBlueprint(BaseModel):
36
37
  """
@@ -38,24 +39,26 @@ class ValidationImportPostRequestBlueprint(BaseModel):
38
39
  """
39
40
  # data type: TranscriptionRapidBlueprint
40
41
  oneof_schema_1_validator: Optional[TranscriptionRapidBlueprint] = None
42
+ # data type: ScrubRapidBlueprint
43
+ oneof_schema_2_validator: Optional[ScrubRapidBlueprint] = None
41
44
  # data type: PolygonRapidBlueprint
42
- oneof_schema_2_validator: Optional[PolygonRapidBlueprint] = None
45
+ oneof_schema_3_validator: Optional[PolygonRapidBlueprint] = None
43
46
  # data type: NamedEntityRapidBlueprint
44
- oneof_schema_3_validator: Optional[NamedEntityRapidBlueprint] = None
47
+ oneof_schema_4_validator: Optional[NamedEntityRapidBlueprint] = None
45
48
  # data type: LocateRapidBlueprint
46
- oneof_schema_4_validator: Optional[LocateRapidBlueprint] = None
49
+ oneof_schema_5_validator: Optional[LocateRapidBlueprint] = None
47
50
  # data type: LineRapidBlueprint
48
- oneof_schema_5_validator: Optional[LineRapidBlueprint] = None
51
+ oneof_schema_6_validator: Optional[LineRapidBlueprint] = None
49
52
  # data type: FreeTextRapidBlueprint
50
- oneof_schema_6_validator: Optional[FreeTextRapidBlueprint] = None
53
+ oneof_schema_7_validator: Optional[FreeTextRapidBlueprint] = None
51
54
  # data type: CompareRapidBlueprint
52
- oneof_schema_7_validator: Optional[CompareRapidBlueprint] = None
55
+ oneof_schema_8_validator: Optional[CompareRapidBlueprint] = None
53
56
  # data type: AttachCategoryRapidBlueprint
54
- oneof_schema_8_validator: Optional[AttachCategoryRapidBlueprint] = None
57
+ oneof_schema_9_validator: Optional[AttachCategoryRapidBlueprint] = None
55
58
  # data type: BoundingBoxRapidBlueprint
56
- oneof_schema_9_validator: Optional[BoundingBoxRapidBlueprint] = None
57
- actual_instance: Optional[Union[AttachCategoryRapidBlueprint, BoundingBoxRapidBlueprint, CompareRapidBlueprint, FreeTextRapidBlueprint, LineRapidBlueprint, LocateRapidBlueprint, NamedEntityRapidBlueprint, PolygonRapidBlueprint, TranscriptionRapidBlueprint]] = None
58
- one_of_schemas: Set[str] = { "AttachCategoryRapidBlueprint", "BoundingBoxRapidBlueprint", "CompareRapidBlueprint", "FreeTextRapidBlueprint", "LineRapidBlueprint", "LocateRapidBlueprint", "NamedEntityRapidBlueprint", "PolygonRapidBlueprint", "TranscriptionRapidBlueprint" }
59
+ oneof_schema_10_validator: Optional[BoundingBoxRapidBlueprint] = None
60
+ actual_instance: Optional[Union[AttachCategoryRapidBlueprint, BoundingBoxRapidBlueprint, CompareRapidBlueprint, FreeTextRapidBlueprint, LineRapidBlueprint, LocateRapidBlueprint, NamedEntityRapidBlueprint, PolygonRapidBlueprint, ScrubRapidBlueprint, TranscriptionRapidBlueprint]] = None
61
+ one_of_schemas: Set[str] = { "AttachCategoryRapidBlueprint", "BoundingBoxRapidBlueprint", "CompareRapidBlueprint", "FreeTextRapidBlueprint", "LineRapidBlueprint", "LocateRapidBlueprint", "NamedEntityRapidBlueprint", "PolygonRapidBlueprint", "ScrubRapidBlueprint", "TranscriptionRapidBlueprint" }
59
62
 
60
63
  model_config = ConfigDict(
61
64
  validate_assignment=True,
@@ -86,6 +89,11 @@ class ValidationImportPostRequestBlueprint(BaseModel):
86
89
  error_messages.append(f"Error! Input type `{type(v)}` is not `TranscriptionRapidBlueprint`")
87
90
  else:
88
91
  match += 1
92
+ # validate data type: ScrubRapidBlueprint
93
+ if not isinstance(v, ScrubRapidBlueprint):
94
+ error_messages.append(f"Error! Input type `{type(v)}` is not `ScrubRapidBlueprint`")
95
+ else:
96
+ match += 1
89
97
  # validate data type: PolygonRapidBlueprint
90
98
  if not isinstance(v, PolygonRapidBlueprint):
91
99
  error_messages.append(f"Error! Input type `{type(v)}` is not `PolygonRapidBlueprint`")
@@ -128,10 +136,10 @@ class ValidationImportPostRequestBlueprint(BaseModel):
128
136
  match += 1
129
137
  if match > 1:
130
138
  # more than 1 match
131
- raise ValueError("Multiple matches found when setting `actual_instance` in ValidationImportPostRequestBlueprint with oneOf schemas: AttachCategoryRapidBlueprint, BoundingBoxRapidBlueprint, CompareRapidBlueprint, FreeTextRapidBlueprint, LineRapidBlueprint, LocateRapidBlueprint, NamedEntityRapidBlueprint, PolygonRapidBlueprint, TranscriptionRapidBlueprint. Details: " + ", ".join(error_messages))
139
+ raise ValueError("Multiple matches found when setting `actual_instance` in ValidationImportPostRequestBlueprint with oneOf schemas: AttachCategoryRapidBlueprint, BoundingBoxRapidBlueprint, CompareRapidBlueprint, FreeTextRapidBlueprint, LineRapidBlueprint, LocateRapidBlueprint, NamedEntityRapidBlueprint, PolygonRapidBlueprint, ScrubRapidBlueprint, TranscriptionRapidBlueprint. Details: " + ", ".join(error_messages))
132
140
  elif match == 0:
133
141
  # no match
134
- raise ValueError("No match found when setting `actual_instance` in ValidationImportPostRequestBlueprint with oneOf schemas: AttachCategoryRapidBlueprint, BoundingBoxRapidBlueprint, CompareRapidBlueprint, FreeTextRapidBlueprint, LineRapidBlueprint, LocateRapidBlueprint, NamedEntityRapidBlueprint, PolygonRapidBlueprint, TranscriptionRapidBlueprint. Details: " + ", ".join(error_messages))
142
+ raise ValueError("No match found when setting `actual_instance` in ValidationImportPostRequestBlueprint with oneOf schemas: AttachCategoryRapidBlueprint, BoundingBoxRapidBlueprint, CompareRapidBlueprint, FreeTextRapidBlueprint, LineRapidBlueprint, LocateRapidBlueprint, NamedEntityRapidBlueprint, PolygonRapidBlueprint, ScrubRapidBlueprint, TranscriptionRapidBlueprint. Details: " + ", ".join(error_messages))
135
143
  else:
136
144
  return v
137
145
 
@@ -152,6 +160,12 @@ class ValidationImportPostRequestBlueprint(BaseModel):
152
160
  match += 1
153
161
  except (ValidationError, ValueError) as e:
154
162
  error_messages.append(str(e))
163
+ # deserialize data into ScrubRapidBlueprint
164
+ try:
165
+ instance.actual_instance = ScrubRapidBlueprint.from_json(json_str)
166
+ match += 1
167
+ except (ValidationError, ValueError) as e:
168
+ error_messages.append(str(e))
155
169
  # deserialize data into PolygonRapidBlueprint
156
170
  try:
157
171
  instance.actual_instance = PolygonRapidBlueprint.from_json(json_str)
@@ -203,10 +217,10 @@ class ValidationImportPostRequestBlueprint(BaseModel):
203
217
 
204
218
  if match > 1:
205
219
  # more than 1 match
206
- raise ValueError("Multiple matches found when deserializing the JSON string into ValidationImportPostRequestBlueprint with oneOf schemas: AttachCategoryRapidBlueprint, BoundingBoxRapidBlueprint, CompareRapidBlueprint, FreeTextRapidBlueprint, LineRapidBlueprint, LocateRapidBlueprint, NamedEntityRapidBlueprint, PolygonRapidBlueprint, TranscriptionRapidBlueprint. Details: " + ", ".join(error_messages))
220
+ raise ValueError("Multiple matches found when deserializing the JSON string into ValidationImportPostRequestBlueprint with oneOf schemas: AttachCategoryRapidBlueprint, BoundingBoxRapidBlueprint, CompareRapidBlueprint, FreeTextRapidBlueprint, LineRapidBlueprint, LocateRapidBlueprint, NamedEntityRapidBlueprint, PolygonRapidBlueprint, ScrubRapidBlueprint, TranscriptionRapidBlueprint. Details: " + ", ".join(error_messages))
207
221
  elif match == 0:
208
222
  # no match
209
- raise ValueError("No match found when deserializing the JSON string into ValidationImportPostRequestBlueprint with oneOf schemas: AttachCategoryRapidBlueprint, BoundingBoxRapidBlueprint, CompareRapidBlueprint, FreeTextRapidBlueprint, LineRapidBlueprint, LocateRapidBlueprint, NamedEntityRapidBlueprint, PolygonRapidBlueprint, TranscriptionRapidBlueprint. Details: " + ", ".join(error_messages))
223
+ raise ValueError("No match found when deserializing the JSON string into ValidationImportPostRequestBlueprint with oneOf schemas: AttachCategoryRapidBlueprint, BoundingBoxRapidBlueprint, CompareRapidBlueprint, FreeTextRapidBlueprint, LineRapidBlueprint, LocateRapidBlueprint, NamedEntityRapidBlueprint, PolygonRapidBlueprint, ScrubRapidBlueprint, TranscriptionRapidBlueprint. Details: " + ", ".join(error_messages))
210
224
  else:
211
225
  return instance
212
226
 
@@ -220,7 +234,7 @@ class ValidationImportPostRequestBlueprint(BaseModel):
220
234
  else:
221
235
  return json.dumps(self.actual_instance)
222
236
 
223
- def to_dict(self) -> Optional[Union[Dict[str, Any], AttachCategoryRapidBlueprint, BoundingBoxRapidBlueprint, CompareRapidBlueprint, FreeTextRapidBlueprint, LineRapidBlueprint, LocateRapidBlueprint, NamedEntityRapidBlueprint, PolygonRapidBlueprint, TranscriptionRapidBlueprint]]:
237
+ def to_dict(self) -> Optional[Union[Dict[str, Any], AttachCategoryRapidBlueprint, BoundingBoxRapidBlueprint, CompareRapidBlueprint, FreeTextRapidBlueprint, LineRapidBlueprint, LocateRapidBlueprint, NamedEntityRapidBlueprint, PolygonRapidBlueprint, ScrubRapidBlueprint, TranscriptionRapidBlueprint]]:
224
238
  """Returns the dict representation of the actual instance"""
225
239
  if self.actual_instance is None:
226
240
  return None
@@ -349,6 +349,8 @@ Class | Method | HTTP request | Description
349
349
  - [QueryValidationRapidsResult](rapidata/api_client/docs/QueryValidationRapidsResult.md)
350
350
  - [QueryValidationRapidsResultAsset](rapidata/api_client/docs/QueryValidationRapidsResultAsset.md)
351
351
  - [QueryValidationRapidsResultPagedResult](rapidata/api_client/docs/QueryValidationRapidsResultPagedResult.md)
352
+ - [QueryValidationRapidsResultPayload](rapidata/api_client/docs/QueryValidationRapidsResultPayload.md)
353
+ - [QueryValidationRapidsResultTruth](rapidata/api_client/docs/QueryValidationRapidsResultTruth.md)
352
354
  - [QueryValidationSetModel](rapidata/api_client/docs/QueryValidationSetModel.md)
353
355
  - [QueryWorkflowsModel](rapidata/api_client/docs/QueryWorkflowsModel.md)
354
356
  - [RankedDatapointModel](rapidata/api_client/docs/RankedDatapointModel.md)
@@ -361,6 +363,11 @@ Class | Method | HTTP request | Description
361
363
  - [RegisterTemporaryCustomerModel](rapidata/api_client/docs/RegisterTemporaryCustomerModel.md)
362
364
  - [RegisterTemporaryCustomerResult](rapidata/api_client/docs/RegisterTemporaryCustomerResult.md)
363
365
  - [RootFilter](rapidata/api_client/docs/RootFilter.md)
366
+ - [ScrubPayload](rapidata/api_client/docs/ScrubPayload.md)
367
+ - [ScrubRange](rapidata/api_client/docs/ScrubRange.md)
368
+ - [ScrubRapidBlueprint](rapidata/api_client/docs/ScrubRapidBlueprint.md)
369
+ - [ScrubResult](rapidata/api_client/docs/ScrubResult.md)
370
+ - [ScrubTruth](rapidata/api_client/docs/ScrubTruth.md)
364
371
  - [SendCompletionMailStepModel](rapidata/api_client/docs/SendCompletionMailStepModel.md)
365
372
  - [Shape](rapidata/api_client/docs/Shape.md)
366
373
  - [SimpleWorkflowConfig](rapidata/api_client/docs/SimpleWorkflowConfig.md)
@@ -9,6 +9,8 @@ from rapidata.rapidata_client.assets._base_asset import BaseAsset
9
9
  import requests
10
10
  import re
11
11
  from PIL import Image
12
+ from tinytag import TinyTag
13
+ import tempfile
12
14
 
13
15
  class MediaAsset(BaseAsset):
14
16
  """MediaAsset Class
@@ -55,6 +57,49 @@ class MediaAsset(BaseAsset):
55
57
  self.path: str | bytes = path
56
58
  self.name = path
57
59
 
60
def get_duration(self) -> int:
    """Return the duration of an audio/video asset in milliseconds.

    Assets whose name ends in a static image extension (``.jpg``,
    ``.jpeg``, ``.png``, ``.webp``, ``.gif``; case-insensitive) report 0
    without the file being inspected.

    Returns:
        int: Duration in milliseconds for audio/video, 0 for static images.

    Raises:
        ValueError: If the duration cannot be determined. The original
            error is preserved as ``__cause__``.
    """
    static_image_suffixes = ('.jpg', '.jpeg', '.png', '.webp', '.gif')
    if self.name.lower().endswith(static_image_suffixes):
        return 0

    try:
        if isinstance(self.path, bytes):
            # In-memory content is written to a temporary file so that it
            # can be handed to TinyTag by path.
            suffix = os.path.splitext(self.name)[1]
            tmp = tempfile.NamedTemporaryFile(suffix=suffix, delete=False)
            try:
                # Leaving the ``with`` block flushes and closes the handle
                # before the file is read back.
                with tmp:
                    tmp.write(self.path)
                tag = TinyTag.get(tmp.name)
            finally:
                # delete=False means cleanup is manual; doing it here also
                # covers a failed write, not only a failed parse.
                os.unlink(tmp.name)
        else:
            # Local files are read in place.
            tag = TinyTag.get(self.path)

        if tag.duration is None:
            raise ValueError("Could not read duration from file")

        # tag.duration is scaled by 1000 to yield milliseconds.
        return int(tag.duration * 1000)

    except Exception as e:
        # Single failure type for callers; keep the root cause attached.
        raise ValueError(f"Could not determine media duration: {str(e)}") from e
102
+
58
103
  def get_image_dimension(self) -> tuple[int, int] | None:
59
104
  """
60
105
  Get the dimensions (width, height) of an image file.
@@ -48,7 +48,10 @@ class RapidataDataset:
48
48
  textSources=texts
49
49
  )
50
50
 
51
- self.openapi_service.dataset_api.dataset_creat_text_datapoint_post(model)
51
+ upload_response = self.openapi_service.dataset_api.dataset_creat_text_datapoint_post(model)
52
+
53
+ if upload_response.errors:
54
+ raise ValueError(f"Error uploading text datapoint: {upload_response.errors}")
52
55
 
53
56
  total_uploads = len(text_assets)
54
57
  with ThreadPoolExecutor(max_workers=max_workers) as executor:
@@ -19,7 +19,8 @@ from rapidata.rapidata_client.workflow import (
19
19
  FreeTextWorkflow,
20
20
  SelectWordsWorkflow,
21
21
  LocateWorkflow,
22
- DrawWorkflow)
22
+ DrawWorkflow,
23
+ TimestampWorkflow)
23
24
  from rapidata.rapidata_client.selection.validation_selection import ValidationSelection
24
25
  from rapidata.rapidata_client.selection.labeling_selection import LabelingSelection
25
26
  from rapidata.rapidata_client.assets import MediaAsset, TextAsset, MultiAsset
@@ -413,6 +414,50 @@ class RapidataOrderManager:
413
414
  selections=selections,
414
415
  settings=settings
415
416
  )
417
+
418
def create_timestamp_order(self,
        name: str,
        instruction: str,
        datapoints: list[str],
        responses_per_datapoint: int = 10,
        contexts: list[str] | None = None,
        filters: Sequence[RapidataFilter] = [],
        settings: Sequence[RapidataSetting] = [],
        selections: Sequence[RapidataSelection] | None = None,
    ) -> RapidataOrder:
    """Create a timestamp order.

    Args:
        name (str): The name of the order.
        instruction (str): The instruction for the timestamp task. Will be shown alongside each datapoint.
        datapoints (list[str]): The datapoints to be timestamped - each datapoint will be labeled.
        responses_per_datapoint (int, optional): The number of responses that will be collected per datapoint. Defaults to 10.
        contexts (list[str], optional): The list of contexts for the timestamp task. Defaults to None.\n
            If provided has to be the same length as datapoints and will be shown in addition to the instruction. (Therefore will be different for each datapoint)
            Will be matched up with the datapoints using the list index.
        filters (Sequence[RapidataFilter], optional): The list of filters for the timestamp. Defaults to []. Decides who the tasks should be shown to.
        settings (Sequence[RapidataSetting], optional): The list of settings for the timestamp. Defaults to []. Decides how the tasks should be shown.
        selections (Sequence[RapidataSelection], optional): The list of selections for the timestamp. Defaults to None. Decides in what order the tasks should be shown.

    Raises:
        ValueError: If any datapoint reports a falsy duration.
    """
    media_assets = [MediaAsset(path=datapoint) for datapoint in datapoints]

    # Every asset needs a non-zero duration; checking stops at the first
    # asset that fails.
    if any(not media.get_duration() for media in media_assets):
        raise ValueError("The datapoints for this order must have a duration. (e.g. video or audio)")

    workflow = TimestampWorkflow(instruction=instruction)
    return self.__create_general_order(
        name=name,
        workflow=workflow,
        assets=media_assets,
        responses_per_datapoint=responses_per_datapoint,
        contexts=contexts,
        filters=filters,
        selections=selections,
        settings=settings
    )
416
461
 
417
462
  def get_order_by_id(self, order_id: str) -> RapidataOrder:
418
463
  """Get an order by ID.
@@ -8,14 +8,16 @@ from rapidata.rapidata_client.selection import (
8
8
  class RapidataSelections:
9
9
  """RapidataSelections Classes
10
10
 
11
+ Selections are used to define what type of tasks and in what order they are shown to the user.
12
+
11
13
  Attributes:
12
- demographic (DemographicSelection): The DemographicSelection instance.
13
14
  labeling (LabelingSelection): The LabelingSelection instance.
14
15
  validation (ValidationSelection): The ValidationSelection instance.
15
16
  conditional_validation (ConditionalValidationSelection): The ConditionalValidationSelection instance.
17
+ demographic (DemographicSelection): The DemographicSelection instance.
16
18
  capped (CappedSelection): The CappedSelection instance."""
17
- demographic = DemographicSelection
18
19
  labeling = LabelingSelection
19
20
  validation = ValidationSelection
20
21
  conditional_validation = ConditionalValidationSelection
22
+ demographic = DemographicSelection
21
23
  capped = CappedSelection
@@ -19,6 +19,9 @@ from rapidata.api_client.models.polygon_payload import PolygonPayload
19
19
  from rapidata.api_client.models.polygon_truth import PolygonTruth
20
20
  from rapidata.api_client.models.transcription_payload import TranscriptionPayload
21
21
  from rapidata.api_client.models.transcription_truth import TranscriptionTruth
22
+ from rapidata.api_client.models.scrub_payload import ScrubPayload
23
+ from rapidata.api_client.models.scrub_truth import ScrubTruth
24
+
22
25
  from rapidata.rapidata_client.assets._media_asset import MediaAsset
23
26
  from rapidata.rapidata_client.assets._multi_asset import MultiAsset
24
27
  from rapidata.rapidata_client.assets._text_asset import TextAsset
@@ -40,6 +43,7 @@ class ValidatioRapidParts:
40
43
  | NamedEntityPayload
41
44
  | PolygonPayload
42
45
  | TranscriptionPayload
46
+ | ScrubPayload
43
47
  )
44
48
  truths: (
45
49
  AttachCategoryTruth
@@ -51,6 +55,7 @@ class ValidatioRapidParts:
51
55
  | NamedEntityTruth
52
56
  | PolygonTruth
53
57
  | TranscriptionTruth
58
+ | ScrubTruth
54
59
  )
55
60
  metadata: Sequence[Metadata]
56
61
  randomCorrectProbability: float
@@ -6,6 +6,9 @@ from rapidata.api_client.models.compare_truth import CompareTruth
6
6
  from rapidata.api_client.models.transcription_payload import TranscriptionPayload
7
7
  from rapidata.api_client.models.transcription_truth import TranscriptionTruth
8
8
  from rapidata.api_client.models.transcription_word import TranscriptionWord
9
+ from rapidata.api_client.models.scrub_payload import ScrubPayload
10
+ from rapidata.api_client.models.scrub_truth import ScrubTruth
11
+ from rapidata.api_client.models.scrub_range import ScrubRange
9
12
  from rapidata.api_client.models.locate_payload import LocatePayload
10
13
  from rapidata.api_client.models.locate_box_truth import LocateBoxTruth
11
14
  from rapidata.api_client.models.line_payload import LinePayload
@@ -26,7 +29,8 @@ from rapidata.rapidata_client.validation.rapids.rapids import (
26
29
  CompareRapid,
27
30
  SelectWordsRapid,
28
31
  LocateRapid,
29
- DrawRapid
32
+ DrawRapid,
33
+ TimestampRapid
30
34
  )
31
35
  from typing import Sequence
32
36
 
@@ -106,10 +110,13 @@ class ValidationSetBuilder:
106
110
  self.__add_select_words_rapid(rapid.asset, rapid.instruction, rapid.sentence, rapid.truths, rapid.strict_grading)
107
111
 
108
112
  elif isinstance(rapid, LocateRapid):
109
- self.__add_locate_rapid(rapid.asset, rapid.instruction, rapid.truths)
113
+ self.__add_locate_rapid(rapid.asset, rapid.instruction, rapid.truths, rapid.metadata)
110
114
 
111
115
  elif isinstance(rapid, DrawRapid):
112
- self.__add_draw_rapid(rapid.asset, rapid.instruction, rapid.truths)
116
+ self.__add_draw_rapid(rapid.asset, rapid.instruction, rapid.truths, rapid.metadata)
117
+
118
+ elif isinstance(rapid, TimestampRapid):
119
+ self.__add_timestamp_rapid(rapid.asset, rapid.instruction, rapid.truths, rapid.metadata)
113
120
 
114
121
  else:
115
122
  raise ValueError("Unsupported rapid type")
@@ -262,7 +269,8 @@ class ValidationSetBuilder:
262
269
  self,
263
270
  asset: MediaAsset,
264
271
  instruction: str,
265
- truths: list[Box]
272
+ truths: list[Box],
273
+ metadata: Sequence[Metadata] = [],
266
274
  ):
267
275
  """Add a locate rapid to the validation set.
268
276
 
@@ -301,7 +309,7 @@ class ValidationSetBuilder:
301
309
  instruction=instruction,
302
310
  payload=payload,
303
311
  truths=model_truth,
304
- metadata=[],
312
+ metadata=metadata,
305
313
  randomCorrectProbability=coverage,
306
314
  asset=asset,
307
315
  )
@@ -311,7 +319,8 @@ class ValidationSetBuilder:
311
319
  self,
312
320
  asset: MediaAsset,
313
321
  instruction: str,
314
- truths: list[Box]
322
+ truths: list[Box],
323
+ metadata: Sequence[Metadata] = [],
315
324
  ):
316
325
  """Add a draw rapid to the validation set.
317
326
 
@@ -348,12 +357,62 @@ class ValidationSetBuilder:
348
357
  instruction=instruction,
349
358
  payload=payload,
350
359
  truths=model_truth,
351
- metadata=[],
360
+ metadata=metadata,
352
361
  randomCorrectProbability=coverage,
353
362
  asset=asset,
354
363
  )
355
364
  )
356
365
 
366
def __add_timestamp_rapid(
    self,
    asset: MediaAsset,
    instruction: str,
    truths: list[tuple[int, int]],
    metadata: Sequence[Metadata] = [],
):
    """Add a timestamp rapid to the validation set.

    The rapid parts are appended to ``self._rapid_parts``; nothing is returned.

    Args:
        asset (MediaAsset): The asset for the rapid.
        instruction (str): The instruction for the timestamp rapid.
        truths (list[tuple[int, int]]): The truths for the rapid.
            This is a list of tuples where the first element is the start of the interval and the second element is the end of the interval.
            The intervals are in milliseconds.
        metadata (Sequence[Metadata], optional): The metadata for the rapid. Defaults to an empty list.

    Raises:
        ValueError: If a truth does not have exactly two elements, or if its
            start is not strictly smaller than its end.
    """
    for truth in truths:
        if len(truth) != 2:
            raise ValueError("The truths per datapoint must be a tuple of exactly two integers.")
        # Strict comparison: an empty interval (start == end) is rejected,
        # matching the error text and the check done in TimestampRapid.
        if truth[0] >= truth[1]:
            raise ValueError("The start of the interval must be smaller than the end of the interval.")

    payload = ScrubPayload(
        _t="ScrubPayload",
        target=instruction
    )

    model_truth = ScrubTruth(
        _t="ScrubTruth",
        validRanges=[ScrubRange(start=start, end=end) for start, end in truths]
    )

    # Fraction of the asset's duration covered by the accepted intervals,
    # passed on as randomCorrectProbability.
    coverage = self._calculate_coverage_ratio(asset.get_duration(), truths)

    self._rapid_parts.append(
        ValidatioRapidParts(
            instruction=instruction,
            payload=payload,
            truths=model_truth,
            metadata=metadata,
            randomCorrectProbability=coverage,
            asset=asset,
        )
    )
415
+
357
416
 
358
417
  def _calculate_boxes_coverage(self, boxes: list[Box], image_width: int, image_height: int) -> float:
359
418
  if not boxes:
@@ -369,3 +428,44 @@ class ValidationSetBuilder:
369
428
 
370
429
  total_covered = len(pixels)
371
430
  return total_covered / (image_width * image_height)
431
+
432
def _calculate_coverage_ratio(self, total_duration: int, subsections: list[tuple[int, int]]) -> float:
    """
    Calculate the ratio of total_duration that is covered by subsections, handling overlaps.

    Subsections are clamped to ``[0, total_duration]``; parts that fall
    outside that window, and intervals that are empty after clamping,
    contribute nothing.

    Args:
        total_duration: The total duration to consider
        subsections: List of tuples containing (start, end) times

    Returns:
        float: Ratio of coverage (0 to 1). 0.0 when there are no
        subsections or when total_duration is not positive.
    """
    # A non-positive duration cannot be covered; returning early also
    # avoids dividing by zero below.
    if not subsections or total_duration <= 0:
        return 0.0

    # Clamp to the valid window, then order by start time.
    clamped = sorted(
        (max(0, start), min(end, total_duration))
        for start, end in subsections
    )

    covered = 0
    counted_until = 0  # right edge of the region already accounted for
    for start, end in clamped:
        # Only the part beyond what has been counted is new coverage, so
        # overlapping intervals are never double-counted.
        start = max(start, counted_until)
        if end > start:
            covered += end - start
            counted_until = end

    return covered / total_duration
@@ -34,6 +34,8 @@ from rapidata.api_client.models.polygon_truth import PolygonTruth
34
34
  from rapidata.api_client.models.transcription_payload import TranscriptionPayload
35
35
  from rapidata.api_client.models.transcription_truth import TranscriptionTruth
36
36
  from rapidata.api_client.models.transcription_word import TranscriptionWord
37
+ from rapidata.api_client.models.scrub_payload import ScrubPayload
38
+ from rapidata.api_client.models.scrub_truth import ScrubTruth
37
39
  from rapidata.rapidata_client.assets._media_asset import MediaAsset
38
40
  from rapidata.rapidata_client.assets._multi_asset import MultiAsset
39
41
  from rapidata.rapidata_client.assets._text_asset import TextAsset
@@ -90,6 +92,7 @@ class RapidataValidationSet:
90
92
  | NamedEntityPayload
91
93
  | PolygonPayload
92
94
  | TranscriptionPayload
95
+ | ScrubPayload
93
96
  ),
94
97
  truths: (
95
98
  AttachCategoryTruth
@@ -101,6 +104,7 @@ class RapidataValidationSet:
101
104
  | NamedEntityTruth
102
105
  | PolygonTruth
103
106
  | TranscriptionTruth
107
+ | ScrubTruth
104
108
  ),
105
109
  metadata: Sequence[Metadata],
106
110
  asset: MediaAsset | TextAsset | MultiAsset,
@@ -92,3 +92,37 @@ class DrawRapid(Rapid):
92
92
  self.instruction = instruction
93
93
  self.asset = asset
94
94
  self.truths = truths
95
+ self.metadata = metadata
96
+
97
class TimestampRapid(Rapid):
    """
    Used to have the labeler timestamp a video or audio file.

    Args:
        instruction (str): The instruction for the labeler.
        truths (list[tuple[int, int]]): The possible accepted timestamps intervals for the labeler (in milliseconds).
            The first element of the tuple is the start of the interval and the second element is the end of the interval.
        asset (MediaAsset): The asset that the labeler is timestamping.
        metadata (Sequence[Metadata]): The metadata that is attached to the rapid.

    Raises:
        ValueError: If the asset has no duration, if ``truths`` is not a list,
            or if any interval is malformed, empty, negative or extends past
            the end of the asset.
    """
    def __init__(self, instruction: str, truths: list[tuple[int, int]], asset: MediaAsset, metadata: Sequence[Metadata]):
        # Query the duration once and reuse it; get_duration() inspects the
        # media file on every call, so asking again per truth is wasted work.
        duration = asset.get_duration()
        if not duration:
            raise ValueError("The datapoints must have a duration. (e.g. video or audio)")

        if not isinstance(truths, list):
            raise ValueError("The truths must be a list of tuples.")

        for truth in truths:
            if len(truth) != 2 or not all(isinstance(x, int) for x in truth):
                raise ValueError("The truths per datapoint must be a tuple of exactly two integers.")
            start, end = truth
            if start >= end:
                raise ValueError("The start of the interval must be smaller than the end of the interval.")
            if start < 0:
                raise ValueError("The start of the interval must be greater than or equal to 0.")
            if end > duration:
                raise ValueError("The end of the interval can not be greater than the duration of the datapoint.")

        self.instruction = instruction
        self.truths = truths
        self.asset = asset
        self.metadata = metadata
@@ -4,7 +4,8 @@ from rapidata.rapidata_client.validation.rapids.rapids import (
4
4
  CompareRapid,
5
5
  SelectWordsRapid,
6
6
  LocateRapid,
7
- DrawRapid)
7
+ DrawRapid,
8
+ TimestampRapid)
8
9
  from rapidata.rapidata_client.assets import MediaAsset, TextAsset, MultiAsset
9
10
  from rapidata.rapidata_client.metadata import Metadata
10
11
  from rapidata.rapidata_client.validation.rapids.box import Box
@@ -161,3 +162,30 @@ class RapidsManager:
161
162
  asset=asset,
162
163
  metadata=metadata,
163
164
  )
165
+
166
def timestamp_rapid(self,
        instruction: str,
        truths: list[tuple[int, int]],
        datapoint: str,
        metadata: Sequence[Metadata] = []
    ) -> TimestampRapid:
    """Build a timestamp rapid

    Wraps ``datapoint`` in a MediaAsset and passes everything on to
    TimestampRapid.

    Args:
        instruction (str): The instruction for the labeler.
        truths (list[tuple[int, int]]): The possible accepted timestamps intervals for the labeler (in milliseconds).
            The first element of the tuple is the start of the interval and the second element is the end of the interval.
        datapoint (str): The asset that the labeler will be timestamping.
        metadata (Sequence[Metadata], optional): The metadata that is attached to the rapid. Defaults to [].

    Returns:
        TimestampRapid: The constructed rapid.
    """
    return TimestampRapid(
        instruction=instruction,
        truths=truths,
        asset=MediaAsset(datapoint),
        metadata=metadata,
    )
190
+
191
+
@@ -291,6 +291,56 @@ class ValidationSetManager:
291
291
 
292
292
  return validation_set_builder._submit(print_confirmation)
293
293
 
294
+
295
def create_timestamp_set(self,
        name: str,
        instruction: str,
        truths: list[list[tuple[int, int]]],
        datapoints: list[str],
        contexts: list[str] | None = None,
        print_confirmation: bool = True
    ) -> RapidataValidationSet:
    """Create a timestamp validation set.

    Args:
        name (str): The name of the validation set. (will not be shown to the labeler)
        instruction (str): The instruction to show to the labeler.
        truths (list[list[tuple[int, int]]]): The truths for each datapoint defined as start and endpoint based on milliseconds.
            Outer list is for each datapoint, inner list is for each truth.\n
            example:
                datapoints: ["datapoint1", "datapoint2"]
                truths: [[(0, 10)], [(20, 30)]] -> first datapoint the correct interval is from 0 to 10, second datapoint the correct interval is from 20 to 30
        datapoints (list[str]): The datapoints that will be used for validation.
        contexts (list[str], optional): The contexts for each datapoint. Defaults to None.
        print_confirmation (bool, optional): Whether to print a confirmation message that validation set has been created. Defaults to True.

    Returns:
        RapidataValidationSet: The result of submitting the validation set builder.

    Raises:
        ValueError: If datapoints and truths differ in length, if any entry
            of truths is not a list, or if contexts are given but differ in
            length from datapoints.
    """

    if len(datapoints) != len(truths):
        raise ValueError("The number of datapoints and truths must be equal")

    if not all(isinstance(truth, list) for truth in truths):
        raise ValueError("Truths must be a list of lists")

    if contexts and len(contexts) != len(datapoints):
        raise ValueError("The number of contexts and datapoints must be equal")

    # Build every rapid first so that a problem with any datapoint surfaces
    # before the builder is created.
    rapids = [
        self.rapid.timestamp_rapid(
            instruction=instruction,
            truths=truth,
            datapoint=datapoint,
            metadata=[PromptMetadata(contexts[i])] if contexts else []
        )
        for i, (datapoint, truth) in enumerate(zip(datapoints, truths))
    ]

    validation_set_builder = ValidationSetBuilder(name, self.__openapi_service)
    for rapid in rapids:
        validation_set_builder._add_rapid(rapid)

    return validation_set_builder._submit(print_confirmation)
343
+
294
344
  def create_mixed_set(self,
295
345
  name: str,
296
346
  rapids: Sequence[Rapid],
@@ -6,3 +6,4 @@ from ._compare_workflow import CompareWorkflow
6
6
  from ._free_text_workflow import FreeTextWorkflow
7
7
  from ._select_words_workflow import SelectWordsWorkflow
8
8
  from ._evaluation_workflow import EvaluationWorkflow
9
+ from ._timestamp_workflow import TimestampWorkflow