rapidata 0.5.1__py3-none-any.whl → 1.0.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (28) hide show
  1. rapidata/__init__.py +3 -0
  2. rapidata/api_client/__init__.py +3 -1
  3. rapidata/api_client/api/validation_api.py +276 -0
  4. rapidata/api_client/models/__init__.py +3 -1
  5. rapidata/api_client/models/add_campaign_model.py +3 -3
  6. rapidata/api_client/models/add_validation_text_rapid_model.py +118 -0
  7. rapidata/api_client/models/capped_selection.py +108 -0
  8. rapidata/api_client/models/capped_selection_selections_inner.py +198 -0
  9. rapidata/api_client/models/create_order_model.py +3 -3
  10. rapidata/api_client_README.md +4 -1
  11. rapidata/rapidata_client/__init__.py +1 -0
  12. rapidata/rapidata_client/assets/__init__.py +8 -0
  13. rapidata/rapidata_client/assets/base_asset.py +11 -0
  14. rapidata/rapidata_client/assets/media_asset.py +33 -0
  15. rapidata/rapidata_client/assets/multi_asset.py +44 -0
  16. rapidata/rapidata_client/assets/text_asset.py +25 -0
  17. rapidata/rapidata_client/dataset/rapidata_dataset.py +4 -4
  18. rapidata/rapidata_client/dataset/rapidata_validation_set.py +54 -27
  19. rapidata/rapidata_client/dataset/validation_rapid_parts.py +4 -1
  20. rapidata/rapidata_client/dataset/validation_set_builder.py +41 -24
  21. rapidata/rapidata_client/order/rapidata_order_builder.py +98 -33
  22. rapidata/rapidata_client/rapidata_client.py +24 -0
  23. rapidata/rapidata_client/simple_builders/simple_classification_builders.py +122 -0
  24. rapidata/rapidata_client/simple_builders/simple_compare_builders.py +86 -0
  25. {rapidata-0.5.1.dist-info → rapidata-1.0.0.dist-info}/METADATA +1 -1
  26. {rapidata-0.5.1.dist-info → rapidata-1.0.0.dist-info}/RECORD +28 -18
  27. {rapidata-0.5.1.dist-info → rapidata-1.0.0.dist-info}/LICENSE +0 -0
  28. {rapidata-0.5.1.dist-info → rapidata-1.0.0.dist-info}/WHEEL +0 -0
@@ -0,0 +1,198 @@
1
+ # coding: utf-8
2
+
3
+ """
4
+ Rapidata.Dataset
5
+
6
+ No description provided (generated by Openapi Generator https://github.com/openapitools/openapi-generator)
7
+
8
+ The version of the OpenAPI document: v1
9
+ Generated by OpenAPI Generator (https://openapi-generator.tech)
10
+
11
+ Do not edit the class manually.
12
+ """ # noqa: E501
13
+
14
+
15
+ from __future__ import annotations
16
+ import json
17
+ import pprint
18
+ from pydantic import BaseModel, ConfigDict, Field, StrictStr, ValidationError, field_validator
19
+ from typing import Any, List, Optional
20
+ from rapidata.api_client.models.conditional_validation_selection import ConditionalValidationSelection
21
+ from rapidata.api_client.models.demographic_selection import DemographicSelection
22
+ from rapidata.api_client.models.labeling_selection import LabelingSelection
23
+ from rapidata.api_client.models.static_selection import StaticSelection
24
+ from rapidata.api_client.models.validation_selection import ValidationSelection
25
+ from pydantic import StrictStr, Field
26
+ from typing import Union, List, Set, Optional, Dict
27
+ from typing_extensions import Literal, Self
28
+
29
+ CAPPEDSELECTIONSELECTIONSINNER_ONE_OF_SCHEMAS = ["CappedSelection", "ConditionalValidationSelection", "DemographicSelection", "LabelingSelection", "StaticSelection", "ValidationSelection"]
30
+
31
class CappedSelectionSelectionsInner(BaseModel):
    """
    CappedSelectionSelectionsInner

    oneOf wrapper: `actual_instance` holds a single selection model, which the
    validator requires to be an instance of exactly one of CappedSelection,
    ConditionalValidationSelection, DemographicSelection, LabelingSelection,
    StaticSelection or ValidationSelection.
    """
    # data type: CappedSelection
    oneof_schema_1_validator: Optional[CappedSelection] = None
    # data type: ConditionalValidationSelection
    oneof_schema_2_validator: Optional[ConditionalValidationSelection] = None
    # data type: DemographicSelection
    oneof_schema_3_validator: Optional[DemographicSelection] = None
    # data type: LabelingSelection
    oneof_schema_4_validator: Optional[LabelingSelection] = None
    # data type: StaticSelection
    oneof_schema_5_validator: Optional[StaticSelection] = None
    # data type: ValidationSelection
    oneof_schema_6_validator: Optional[ValidationSelection] = None
    # The wrapped value; checked by `actual_instance_must_validate_oneof`.
    actual_instance: Optional[Union[CappedSelection, ConditionalValidationSelection, DemographicSelection, LabelingSelection, StaticSelection, ValidationSelection]] = None
    one_of_schemas: Set[str] = { "CappedSelection", "ConditionalValidationSelection", "DemographicSelection", "LabelingSelection", "StaticSelection", "ValidationSelection" }

    # validate_assignment=True makes `obj.actual_instance = x` run the validator too,
    # which `from_json` relies on when probing each candidate schema.
    model_config = ConfigDict(
        validate_assignment=True,
        protected_namespaces=(),
    )


    discriminator_value_class_map: Dict[str, str] = {
    }

    def __init__(self, *args, **kwargs) -> None:
        """Build the wrapper from one positional value or from keyword arguments.

        Args:
            *args: At most one positional value, stored as `actual_instance`.
            **kwargs: Regular model fields; not allowed together with a positional value.

        Raises:
            ValueError: If more than one positional argument is given, or if
                positional and keyword arguments are mixed.
        """
        if not args:
            super().__init__(**kwargs)
            return
        if len(args) > 1:
            raise ValueError("If a position argument is used, only 1 is allowed to set `actual_instance`")
        if kwargs:
            raise ValueError("If a position argument is used, keyword arguments cannot be used.")
        super().__init__(actual_instance=args[0])

    @field_validator('actual_instance')
    def actual_instance_must_validate_oneof(cls, v):
        """Require `v` to be an instance of exactly one of the oneOf schemas.

        Returns:
            `v` unchanged when exactly one schema matches.

        Raises:
            ValueError: When no schema or more than one schema matches; the
                message lists every schema `v` is not an instance of.
        """
        # Built inside the method: CappedSelection is imported at the bottom of the
        # module (circular import), so it is not available while the class body runs.
        candidates = (
            ("CappedSelection", CappedSelection),
            ("ConditionalValidationSelection", ConditionalValidationSelection),
            ("DemographicSelection", DemographicSelection),
            ("LabelingSelection", LabelingSelection),
            ("StaticSelection", StaticSelection),
            ("ValidationSelection", ValidationSelection),
        )
        # One message per non-matching schema, in declaration order.
        error_messages = [
            f"Error! Input type `{type(v)}` is not `{schema_name}`"
            for schema_name, schema_type in candidates
            if not isinstance(v, schema_type)
        ]
        match = len(candidates) - len(error_messages)
        if match > 1:
            # more than 1 match
            raise ValueError("Multiple matches found when setting `actual_instance` in CappedSelectionSelectionsInner with oneOf schemas: CappedSelection, ConditionalValidationSelection, DemographicSelection, LabelingSelection, StaticSelection, ValidationSelection. Details: " + ", ".join(error_messages))
        if match == 0:
            # no match
            raise ValueError("No match found when setting `actual_instance` in CappedSelectionSelectionsInner with oneOf schemas: CappedSelection, ConditionalValidationSelection, DemographicSelection, LabelingSelection, StaticSelection, ValidationSelection. Details: " + ", ".join(error_messages))
        return v

    @classmethod
    def from_dict(cls, obj: Union[str, Dict[str, Any]]) -> Self:
        """Build the wrapper from an already-parsed value by round-tripping it through JSON."""
        return cls.from_json(json.dumps(obj))

    @classmethod
    def from_json(cls, json_str: str) -> Self:
        """Returns the object represented by the json string

        Every candidate schema is tried in declaration order; exactly one must
        accept the payload.

        Raises:
            ValueError: When no schema or more than one schema accepts the JSON;
                the message carries the collected per-schema errors.
        """
        instance = cls.model_construct()
        error_messages = []
        match = 0

        for schema_type in (
            CappedSelection,
            ConditionalValidationSelection,
            DemographicSelection,
            LabelingSelection,
            StaticSelection,
            ValidationSelection,
        ):
            # deserialize data into the candidate schema; a later success
            # overwrites an earlier one, the match count detects ambiguity.
            try:
                instance.actual_instance = schema_type.from_json(json_str)
                match += 1
            except (ValidationError, ValueError) as e:
                error_messages.append(str(e))

        if match > 1:
            # more than 1 match
            raise ValueError("Multiple matches found when deserializing the JSON string into CappedSelectionSelectionsInner with oneOf schemas: CappedSelection, ConditionalValidationSelection, DemographicSelection, LabelingSelection, StaticSelection, ValidationSelection. Details: " + ", ".join(error_messages))
        if match == 0:
            # no match
            raise ValueError("No match found when deserializing the JSON string into CappedSelectionSelectionsInner with oneOf schemas: CappedSelection, ConditionalValidationSelection, DemographicSelection, LabelingSelection, StaticSelection, ValidationSelection. Details: " + ", ".join(error_messages))
        return instance

    def to_json(self) -> str:
        """Returns the JSON representation of the actual instance"""
        if self.actual_instance is None:
            return "null"

        serialize = getattr(self.actual_instance, "to_json", None)
        if callable(serialize):
            return serialize()
        return json.dumps(self.actual_instance)

    def to_dict(self) -> Optional[Union[Dict[str, Any], CappedSelection, ConditionalValidationSelection, DemographicSelection, LabelingSelection, StaticSelection, ValidationSelection]]:
        """Returns the dict representation of the actual instance"""
        if self.actual_instance is None:
            return None

        as_dict = getattr(self.actual_instance, "to_dict", None)
        if callable(as_dict):
            return as_dict()
        # primitive type
        return self.actual_instance

    def to_str(self) -> str:
        """Returns the string representation of the actual instance"""
        return pprint.pformat(self.model_dump())
194
+
195
+ from rapidata.api_client.models.capped_selection import CappedSelection
196
+ # TODO: Rewrite to not use raise_errors
197
+ CappedSelectionSelectionsInner.model_rebuild(raise_errors=False)
198
+
@@ -19,8 +19,8 @@ import json
19
19
 
20
20
  from pydantic import BaseModel, ConfigDict, Field, StrictInt, StrictStr, field_validator
21
21
  from typing import Any, ClassVar, Dict, List, Optional
22
+ from rapidata.api_client.models.capped_selection_selections_inner import CappedSelectionSelectionsInner
22
23
  from rapidata.api_client.models.create_order_model_referee import CreateOrderModelReferee
23
- from rapidata.api_client.models.create_order_model_selections_inner import CreateOrderModelSelectionsInner
24
24
  from rapidata.api_client.models.create_order_model_user_filters_inner import CreateOrderModelUserFiltersInner
25
25
  from rapidata.api_client.models.create_order_model_workflow import CreateOrderModelWorkflow
26
26
  from rapidata.api_client.models.feature_flag_model import FeatureFlagModel
@@ -40,7 +40,7 @@ class CreateOrderModel(BaseModel):
40
40
  priority: Optional[StrictInt] = Field(default=None, description="The priority is used to prioritize over other orders.")
41
41
  user_filters: List[CreateOrderModelUserFiltersInner] = Field(description="The user filters are used to restrict the order to only collect votes from a specific demographic.", alias="userFilters")
42
42
  validation_set_id: Optional[StrictStr] = Field(default=None, description="The validation set id can be changed to point to a specific validation set. if not provided a sane default will be used.", alias="validationSetId")
43
- selections: Optional[List[CreateOrderModelSelectionsInner]] = Field(default=None, description="The selections are used to determine which tasks are shown to a user.")
43
+ selections: Optional[List[CappedSelectionSelectionsInner]] = Field(default=None, description="The selections are used to determine which tasks are shown to a user.")
44
44
  __properties: ClassVar[List[str]] = ["_t", "orderName", "workflow", "referee", "aggregator", "featureFlags", "priority", "userFilters", "validationSetId", "selections"]
45
45
 
46
46
  @field_validator('t')
@@ -172,7 +172,7 @@ class CreateOrderModel(BaseModel):
172
172
  "priority": obj.get("priority"),
173
173
  "userFilters": [CreateOrderModelUserFiltersInner.from_dict(_item) for _item in obj["userFilters"]] if obj.get("userFilters") is not None else None,
174
174
  "validationSetId": obj.get("validationSetId"),
175
- "selections": [CreateOrderModelSelectionsInner.from_dict(_item) for _item in obj["selections"]] if obj.get("selections") is not None else None
175
+ "selections": [CappedSelectionSelectionsInner.from_dict(_item) for _item in obj["selections"]] if obj.get("selections") is not None else None
176
176
  })
177
177
  return _obj
178
178
 
@@ -133,6 +133,7 @@ Class | Method | HTTP request | Description
133
133
  *RapidApi* | [**rapid_validate_current_rapid_bag_get**](rapidata/api_client/docs/RapidApi.md#rapid_validate_current_rapid_bag_get) | **GET** /Rapid/ValidateCurrentRapidBag | Validates that the rapids associated with the current user are active.
134
134
  *SimpleWorkflowApi* | [**simple_workflow_get_result_overview_get**](rapidata/api_client/docs/SimpleWorkflowApi.md#simple_workflow_get_result_overview_get) | **GET** /SimpleWorkflow/GetResultOverview | Get the result overview for a simple workflow.
135
135
  *ValidationApi* | [**validation_add_validation_rapid_post**](rapidata/api_client/docs/ValidationApi.md#validation_add_validation_rapid_post) | **POST** /Validation/AddValidationRapid | Adds a new validation rapid to the specified validation set.
136
+ *ValidationApi* | [**validation_add_validation_text_rapid_post**](rapidata/api_client/docs/ValidationApi.md#validation_add_validation_text_rapid_post) | **POST** /Validation/AddValidationTextRapid | Adds a new validation rapid to the specified validation set.
136
137
  *ValidationApi* | [**validation_create_validation_set_post**](rapidata/api_client/docs/ValidationApi.md#validation_create_validation_set_post) | **POST** /Validation/CreateValidationSet | Creates a new empty validation set.
137
138
  *ValidationApi* | [**validation_get_available_validation_sets_get**](rapidata/api_client/docs/ValidationApi.md#validation_get_available_validation_sets_get) | **GET** /Validation/GetAvailableValidationSets | Gets the available validation sets for the current user.
138
139
  *ValidationApi* | [**validation_import_compare_post**](rapidata/api_client/docs/ValidationApi.md#validation_import_compare_post) | **POST** /Validation/ImportCompare | Imports a compare validation set from a zip file.
@@ -157,6 +158,7 @@ Class | Method | HTTP request | Description
157
158
  - [AddValidationRapidModelPayload](rapidata/api_client/docs/AddValidationRapidModelPayload.md)
158
159
  - [AddValidationRapidModelTruth](rapidata/api_client/docs/AddValidationRapidModelTruth.md)
159
160
  - [AddValidationRapidResult](rapidata/api_client/docs/AddValidationRapidResult.md)
161
+ - [AddValidationTextRapidModel](rapidata/api_client/docs/AddValidationTextRapidModel.md)
160
162
  - [AdminOrderModel](rapidata/api_client/docs/AdminOrderModel.md)
161
163
  - [AdminOrderModelPagedResult](rapidata/api_client/docs/AdminOrderModelPagedResult.md)
162
164
  - [AgeGroup](rapidata/api_client/docs/AgeGroup.md)
@@ -176,6 +178,8 @@ Class | Method | HTTP request | Description
176
178
  - [CampaignQueryModelPagedResult](rapidata/api_client/docs/CampaignQueryModelPagedResult.md)
177
179
  - [CampaignStatus](rapidata/api_client/docs/CampaignStatus.md)
178
180
  - [CampaignUserFilterModel](rapidata/api_client/docs/CampaignUserFilterModel.md)
181
+ - [CappedSelection](rapidata/api_client/docs/CappedSelection.md)
182
+ - [CappedSelectionSelectionsInner](rapidata/api_client/docs/CappedSelectionSelectionsInner.md)
179
183
  - [ClassificationMetadata](rapidata/api_client/docs/ClassificationMetadata.md)
180
184
  - [ClassificationMetadataFilterConfig](rapidata/api_client/docs/ClassificationMetadataFilterConfig.md)
181
185
  - [ClassifyPayload](rapidata/api_client/docs/ClassifyPayload.md)
@@ -211,7 +215,6 @@ Class | Method | HTTP request | Description
211
215
  - [CreateIndependentWorkflowResult](rapidata/api_client/docs/CreateIndependentWorkflowResult.md)
212
216
  - [CreateOrderModel](rapidata/api_client/docs/CreateOrderModel.md)
213
217
  - [CreateOrderModelReferee](rapidata/api_client/docs/CreateOrderModelReferee.md)
214
- - [CreateOrderModelSelectionsInner](rapidata/api_client/docs/CreateOrderModelSelectionsInner.md)
215
218
  - [CreateOrderModelUserFiltersInner](rapidata/api_client/docs/CreateOrderModelUserFiltersInner.md)
216
219
  - [CreateOrderModelWorkflow](rapidata/api_client/docs/CreateOrderModelWorkflow.md)
217
220
  - [CreateOrderResult](rapidata/api_client/docs/CreateOrderResult.md)
@@ -10,3 +10,4 @@ from .referee import NaiveReferee, ClassifyEarlyStoppingReferee
10
10
  from .metadata import PrivateTextMetadata, PublicTextMetadata, PromptMetadata, TranscriptionMetadata
11
11
  from .feature_flags import FeatureFlags
12
12
  from .country_codes import CountryCodes
13
+ from .assets import MediaAsset, TextAsset, MultiAsset
@@ -0,0 +1,8 @@
1
+ """Assets Package
2
+
3
+ This package provides classes for different types of assets, including MediaAsset, TextAsset, and MultiAsset.
4
+ """
5
+
6
+ from .media_asset import MediaAsset
7
+ from .text_asset import TextAsset
8
+ from .multi_asset import MultiAsset
@@ -0,0 +1,11 @@
1
+ """Base Asset Module
2
+
3
+ Defines the BaseAsset class, which serves as the abstract base class for all asset types.
4
+ """
5
+
6
class BaseAsset:
    """Common parent type shared by all asset classes.

    It declares no attributes or methods of its own and exists so that concrete
    asset types can be subclassed from, and checked against, a single base.
    Nothing here enforces abstractness: the class can be instantiated directly.
    """
@@ -0,0 +1,33 @@
1
+ """Media Asset Module
2
+
3
+ Defines the MediaAsset class for handling media file paths within assets.
4
+ """
5
+
6
+ import os
7
+ from rapidata.rapidata_client.assets.base_asset import BaseAsset
8
+
9
class MediaAsset(BaseAsset):
    """Asset backed by a media file on the local file system.

    The path is checked for existence once, at construction time, and then
    stored unchanged on the instance as ``path``.

    Args:
        path (str): The file system path to the media asset.

    Raises:
        FileNotFoundError: If nothing exists at the given path.
    """

    def __init__(self, path: str):
        """
        Create a MediaAsset for an existing path.

        Args:
            path (str): The file system path to the media asset.

        Raises:
            FileNotFoundError: If nothing exists at the given path.
        """
        path_is_present = os.path.exists(path)
        if path_is_present:
            self.path = path
        else:
            raise FileNotFoundError(f"File not found: {path}")
@@ -0,0 +1,44 @@
1
+ """Multi Asset Module
2
+
3
+ Defines the MultiAsset class for handling multiple BaseAsset instances.
4
+ """
5
+
6
+ from rapidata.rapidata_client.assets.base_asset import BaseAsset
7
+ from typing import Iterator, List
8
+
9
+
10
class MultiAsset(BaseAsset):
    """Asset that groups several other assets.

    The given list is stored as-is on ``assets`` (no copy, no validation);
    ``len()`` and iteration are delegated to that list.

    Args:
        assets (List[BaseAsset]): The assets that belong together.
    """

    def __init__(self, assets: list[BaseAsset]):
        """
        Create a MultiAsset wrapping the given assets.

        Args:
            assets (List[BaseAsset]): The assets that belong together.
        """
        self.assets = assets

    def __len__(self) -> int:
        """
        Report how many assets are contained.

        Returns:
            int: The length of the wrapped ``assets`` list.
        """
        contained = self.assets
        return len(contained)

    def __iter__(self) -> Iterator[BaseAsset]:
        """
        Iterate over the contained assets in list order.

        Returns:
            Iterator[BaseAsset]: An iterator over the wrapped ``assets`` list.
        """
        contained = self.assets
        return iter(contained)
@@ -0,0 +1,25 @@
1
+ """Text Asset Module
2
+
3
+ Defines the TextAsset class for handling textual data within assets.
4
+ """
5
+
6
+ from rapidata.rapidata_client.assets.base_asset import BaseAsset
7
+
8
+
9
class TextAsset(BaseAsset):
    """Asset whose content is a piece of text.

    The text is stored unchanged on the instance as ``text``.

    Args:
        text (str): The text content of the asset.
    """

    def __init__(self, text: str):
        """
        Create a TextAsset holding the given text.

        Args:
            text (str): The text content of the asset.
        """
        content = text
        self.text = content
@@ -32,11 +32,11 @@ class RapidataDataset:
32
32
 
33
33
  def add_media_from_paths(
34
34
  self,
35
- image_paths: list[str | list[str]],
35
+ media_paths: list[str | list[str]],
36
36
  metadata: list[Metadata] | None = None,
37
37
  max_workers: int = 10,
38
38
  ):
39
- if metadata is not None and len(metadata) != len(image_paths):
39
+ if metadata is not None and len(metadata) != len(media_paths):
40
40
  raise ValueError(
41
41
  "metadata must be None or have the same length as image_paths"
42
42
  )
@@ -66,11 +66,11 @@ class RapidataDataset:
66
66
  files=media_paths_rapid if isinstance(media_paths_rapid, list) else [media_paths_rapid] # type: ignore
67
67
  )
68
68
 
69
- total_uploads = len(image_paths)
69
+ total_uploads = len(media_paths)
70
70
  with ThreadPoolExecutor(max_workers=max_workers) as executor:
71
71
  futures = [
72
72
  executor.submit(upload_datapoint, media_paths, meta)
73
- for media_paths, meta in zip_longest(image_paths, metadata or [])
73
+ for media_paths, meta in zip_longest(media_paths, metadata or [])
74
74
  ]
75
75
 
76
76
  with tqdm(total=total_uploads, desc="Uploading datapoints") as pbar:
@@ -3,6 +3,9 @@ from typing import Any
3
3
  from rapidata.api_client.models.add_validation_rapid_model import (
4
4
  AddValidationRapidModel,
5
5
  )
6
+ from rapidata.api_client.models.add_validation_text_rapid_model import (
7
+ AddValidationTextRapidModel,
8
+ )
6
9
  from rapidata.api_client.models.add_validation_rapid_model_payload import (
7
10
  AddValidationRapidModelPayload,
8
11
  )
@@ -31,6 +34,9 @@ from rapidata.api_client.models.polygon_truth import PolygonTruth
31
34
  from rapidata.api_client.models.transcription_payload import TranscriptionPayload
32
35
  from rapidata.api_client.models.transcription_truth import TranscriptionTruth
33
36
  from rapidata.api_client.models.transcription_word import TranscriptionWord
37
+ from rapidata.rapidata_client.assets.media_asset import MediaAsset
38
+ from rapidata.rapidata_client.assets.multi_asset import MultiAsset
39
+ from rapidata.rapidata_client.assets.text_asset import TextAsset
34
40
  from rapidata.rapidata_client.metadata.base_metadata import Metadata
35
41
  from rapidata.service.openapi_service import OpenAPIService
36
42
 
@@ -38,7 +44,7 @@ from rapidata.service.openapi_service import OpenAPIService
38
44
  class RapidataValidationSet:
39
45
  """A class for interacting with a Rapidata validation set.
40
46
 
41
- Get a `ValidationSet` either by using `rapi.get_validation_set(id)` to get an exisitng validation set or by using `rapi.new_validation_set(name)` to create a new validation set.
47
+ Get a `ValidationSet` either by using `rapi.get_validation_set(id)` to get an existing validation set or by using `rapi.new_validation_set(name)` to create a new validation set.
42
48
  """
43
49
 
44
50
  def __init__(self, validation_set_id, openapi_service: OpenAPIService):
@@ -70,21 +76,25 @@ class RapidataValidationSet:
70
76
  | TranscriptionTruth
71
77
  ),
72
78
  metadata: list[Metadata],
73
- media_paths: str | list[str],
79
+ asset: MediaAsset | TextAsset | MultiAsset,
74
80
  randomCorrectProbability: float,
75
81
  ) -> None:
76
82
  """Add a validation rapid to the validation set.
77
83
 
78
84
  Args:
79
- payload (Union[BoundingBoxPayload, ClassifyPayload, ComparePayload, FreeTextPayload, LinePayload, LocatePayload, NamedEntityPayload, PolygonPayload, TranscriptionPayload]): The payload for the rapid.
80
- truths (Union[AttachCategoryTruth, BoundingBoxTruth, CompareTruth, EmptyValidationTruth, LineTruth, LocateBoxTruth, NamedEntityTruth, PolygonTruth, TranscriptionTruth]): The truths for the rapid.
85
+ payload: The payload for the rapid.
86
+ truths: The truths for the rapid.
81
87
  metadata (list[Metadata]): The metadata for the rapid.
82
- media_paths (Union[str, list[str]]): The media paths for the rapid.
88
+ asset: The asset(s) for the rapid.
83
89
  randomCorrectProbability (float): The random correct probability for the rapid.
84
90
 
85
91
  Returns:
86
92
  None
93
+
94
+ Raises:
95
+ ValueError: If an invalid asset type is provided.
87
96
  """
97
+
88
98
  model = AddValidationRapidModel(
89
99
  validationSetId=self.id,
90
100
  payload=AddValidationRapidModelPayload(payload),
@@ -95,14 +105,37 @@ class RapidataValidationSet:
95
105
  ],
96
106
  randomCorrectProbability=randomCorrectProbability,
97
107
  )
108
+ if isinstance(asset, MediaAsset):
109
+ self.openapi_service.validation_api.validation_add_validation_rapid_post(
110
+ model=model, files=[asset.path]
111
+ )
98
112
 
99
- self.openapi_service.validation_api.validation_add_validation_rapid_post(
100
- model=model, files=media_paths if isinstance(media_paths, list) else [media_paths] # type: ignore
101
- )
113
+ elif isinstance(asset, MultiAsset):
114
+ self.openapi_service.validation_api.validation_add_validation_rapid_post(
115
+ model=model, files=[a.path for a in asset if isinstance(a, MediaAsset)]
116
+ )
117
+
118
+ elif isinstance(asset, TextAsset):
119
+ model = AddValidationTextRapidModel(
120
+ validationSetId=self.id,
121
+ payload=AddValidationRapidModelPayload(payload),
122
+ truth=AddValidationRapidModelTruth(truths),
123
+ metadata=[
124
+ DatapointMetadataModelMetadataInner(meta.to_model())
125
+ for meta in metadata
126
+ ],
127
+ randomCorrectProbability=randomCorrectProbability,
128
+ text=asset.text,
129
+ )
130
+ self.openapi_service.validation_api.validation_add_validation_text_rapid_post(
131
+ add_validation_text_rapid_model=model
132
+ )
133
+ else:
134
+ raise ValueError("Invalid asset type")
102
135
 
103
136
  def add_classify_validation_rapid(
104
137
  self,
105
- media_path: str,
138
+ asset: MediaAsset | TextAsset,
106
139
  question: str,
107
140
  categories: list[str],
108
141
  truths: list[str],
@@ -111,7 +144,7 @@ class RapidataValidationSet:
111
144
  """Add a classify rapid to the validation set.
112
145
 
113
146
  Args:
114
- media_path (str): The path to the media file.
147
+ asset (MediaAsset | TextAsset): The asset for the rapid.
115
148
  question (str): The question for the rapid.
116
149
  categories (list[str]): The list of categories for the rapid.
117
150
  truths (list[str]): The list of truths for the rapid.
@@ -131,13 +164,13 @@ class RapidataValidationSet:
131
164
  payload=payload,
132
165
  truths=model_truth,
133
166
  metadata=metadata,
134
- media_paths=media_path,
167
+ asset=asset,
135
168
  randomCorrectProbability=len(truths) / len(categories),
136
169
  )
137
170
 
138
171
  def add_compare_validation_rapid(
139
172
  self,
140
- media_paths: list[str],
173
+ asset: MultiAsset,
141
174
  question: str,
142
175
  truth: str,
143
176
  metadata: list[Metadata] = [],
@@ -145,7 +178,7 @@ class RapidataValidationSet:
145
178
  """Add a compare rapid to the validation set.
146
179
 
147
180
  Args:
148
- media_paths (list[str]): The list of media paths for the rapid.
181
+ asset (MultiAsset): The assets for the rapid.
149
182
  question (str): The question for the rapid.
150
183
  truth (str): The path to the truth file.
151
184
  metadata (list[Metadata], optional): The metadata for the rapid. Defaults to an empty list.
@@ -154,33 +187,27 @@ class RapidataValidationSet:
154
187
  None
155
188
 
156
189
  Raises:
157
- ValueError: If the number of media paths is not exactly two.
158
- FileNotFoundError: If any of the specified files are not found.
190
+ ValueError: If the number of assets is not exactly two.
159
191
  """
160
192
  payload = ComparePayload(_t="ComparePayload", criteria=question)
161
193
  # take only last part of truth path
162
194
  truth = os.path.basename(truth)
163
195
  model_truth = CompareTruth(_t="CompareTruth", winnerId=truth)
164
196
 
165
- if len(media_paths) != 2:
197
+ if len(asset) != 2:
166
198
  raise ValueError("Compare rapid requires exactly two media paths")
167
199
 
168
- # check that files exist
169
- for media_path in media_paths:
170
- if not os.path.exists(media_path):
171
- raise FileNotFoundError(f"File not found: {media_path}")
172
-
173
200
  self.add_general_validation_rapid(
174
201
  payload=payload,
175
202
  truths=model_truth,
176
203
  metadata=metadata,
177
- media_paths=media_paths,
178
- randomCorrectProbability=1 / len(media_paths),
204
+ asset=asset,
205
+ randomCorrectProbability=1 / len(asset),
179
206
  )
180
207
 
181
208
  def add_transcription_validation_rapid(
182
209
  self,
183
- media_path: str,
210
+ asset: MediaAsset | TextAsset,
184
211
  question: str,
185
212
  transcription: list[str],
186
213
  correct_words: list[str],
@@ -190,11 +217,11 @@ class RapidataValidationSet:
190
217
  """Add a transcription rapid to the validation set.
191
218
 
192
219
  Args:
193
- media_path (str): The path to the media file.
220
+ asset (MediaAsset | TextAsset): The asset for the rapid.
194
221
  question (str): The question for the rapid.
195
222
  transcription (list[str]): The transcription for the rapid.
196
223
  correct_words (list[str]): The list of correct words for the rapid.
197
- strict_grading (Optional[bool], optional): The strict grading for the rapid. Defaults to None.
224
+ strict_grading (bool | None, optional): The strict grading for the rapid. Defaults to None.
198
225
  metadata (list[Metadata], optional): The metadata for the rapid. Defaults to an empty list.
199
226
 
200
227
  Returns:
@@ -230,6 +257,6 @@ class RapidataValidationSet:
230
257
  payload=payload,
231
258
  truths=model_truth,
232
259
  metadata=metadata,
233
- media_paths=media_path,
260
+ asset=asset,
234
261
  randomCorrectProbability=len(correct_words) / len(transcription),
235
262
  )
@@ -19,13 +19,16 @@ from rapidata.api_client.models.polygon_payload import PolygonPayload
19
19
  from rapidata.api_client.models.polygon_truth import PolygonTruth
20
20
  from rapidata.api_client.models.transcription_payload import TranscriptionPayload
21
21
  from rapidata.api_client.models.transcription_truth import TranscriptionTruth
22
+ from rapidata.rapidata_client.assets.media_asset import MediaAsset
23
+ from rapidata.rapidata_client.assets.multi_asset import MultiAsset
24
+ from rapidata.rapidata_client.assets.text_asset import TextAsset
22
25
  from rapidata.rapidata_client.metadata.base_metadata import Metadata
23
26
 
24
27
 
25
28
  @dataclass
26
29
  class ValidatioRapidParts:
27
30
  question: str
28
- media_paths: str | list[str]
31
+ asset: MediaAsset | TextAsset | MultiAsset
29
32
  payload: (
30
33
  BoundingBoxPayload
31
34
  | ClassifyPayload