rapidata 1.1.0__py3-none-any.whl → 1.2.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (40)
  1. rapidata/__init__.py +1 -0
  2. rapidata/api_client/__init__.py +5 -3
  3. rapidata/api_client/api/__init__.py +2 -0
  4. rapidata/api_client/api/campaign_api.py +8 -4
  5. rapidata/api_client/api/coco_api.py +4 -2
  6. rapidata/api_client/api/compare_workflow_api.py +2 -1
  7. rapidata/api_client/api/datapoint_api.py +6 -3
  8. rapidata/api_client/api/dataset_api.py +16 -8
  9. rapidata/api_client/api/identity_api.py +329 -50
  10. rapidata/api_client/api/newsletter_api.py +4 -2
  11. rapidata/api_client/api/order_api.py +40 -20
  12. rapidata/api_client/api/pipeline_api.py +6 -3
  13. rapidata/api_client/api/rapid_api.py +10 -5
  14. rapidata/api_client/api/rapidata_identity_api_api.py +272 -0
  15. rapidata/api_client/api/simple_workflow_api.py +2 -1
  16. rapidata/api_client/api/user_info_api.py +272 -0
  17. rapidata/api_client/api/validation_api.py +14 -7
  18. rapidata/api_client/api/workflow_api.py +18 -9
  19. rapidata/api_client/models/__init__.py +3 -3
  20. rapidata/api_client/models/issue_auth_token_result.py +1 -1
  21. rapidata/api_client/models/legacy_issue_client_auth_token_result.py +87 -0
  22. rapidata/api_client/models/legacy_request_password_reset_command.py +98 -0
  23. rapidata/api_client/models/legacy_submit_password_reset_command.py +102 -0
  24. rapidata/api_client_README.md +10 -3
  25. rapidata/rapidata_client/__init__.py +13 -2
  26. rapidata/rapidata_client/assets/multi_asset.py +2 -0
  27. rapidata/rapidata_client/dataset/rapidata_dataset.py +19 -15
  28. rapidata/rapidata_client/dataset/validation_set_builder.py +1 -1
  29. rapidata/rapidata_client/order/rapidata_order.py +49 -18
  30. rapidata/rapidata_client/order/rapidata_order_builder.py +23 -34
  31. rapidata/rapidata_client/selection/__init__.py +1 -0
  32. rapidata/rapidata_client/selection/capped_selection.py +25 -0
  33. rapidata/rapidata_client/simple_builders/__init__.py +0 -0
  34. rapidata/rapidata_client/simple_builders/simple_classification_builders.py +14 -9
  35. rapidata/rapidata_client/simple_builders/simple_compare_builders.py +6 -3
  36. rapidata/service/openapi_service.py +15 -0
  37. {rapidata-1.1.0.dist-info → rapidata-1.2.0.dist-info}/METADATA +1 -1
  38. {rapidata-1.1.0.dist-info → rapidata-1.2.0.dist-info}/RECORD +40 -33
  39. {rapidata-1.1.0.dist-info → rapidata-1.2.0.dist-info}/LICENSE +0 -0
  40. {rapidata-1.1.0.dist-info → rapidata-1.2.0.dist-info}/WHEEL +0 -0
@@ -0,0 +1,102 @@
1
+ # coding: utf-8
2
+
3
+ """
4
+ Rapidata.Dataset
5
+
6
+ No description provided (generated by Openapi Generator https://github.com/openapitools/openapi-generator)
7
+
8
+ The version of the OpenAPI document: v1
9
+ Generated by OpenAPI Generator (https://openapi-generator.tech)
10
+
11
+ Do not edit the class manually.
12
+ """ # noqa: E501
13
+
14
+
15
+ from __future__ import annotations
16
+ import pprint
17
+ import re # noqa: F401
18
+ import json
19
+
20
+ from pydantic import BaseModel, ConfigDict, Field, StrictStr, field_validator
21
+ from typing import Any, ClassVar, Dict, List
22
+ from typing import Optional, Set
23
+ from typing_extensions import Self
24
+
25
class LegacySubmitPasswordResetCommand(BaseModel):
    """
    LegacySubmitPasswordResetCommand

    Pydantic model for a password-reset submission payload. Fields are
    serialized under their camelCase aliases ("userId", "passwordRepeated",
    "resetToken") together with a "_t" discriminator whose only accepted
    value is 'LegacySubmitPasswordResetCommand'.

    NOTE: the module header marks this class as generated by OpenAPI
    Generator ("Do not edit the class manually"), so hand edits -- including
    these comments -- may be overwritten when the client is regenerated.
    """ # noqa: E501
    # Type discriminator, serialized as "_t"; restricted by t_validate_enum below.
    t: StrictStr = Field(description="Discriminator value for LegacySubmitPasswordResetCommand", alias="_t")
    user_id: StrictStr = Field(alias="userId")
    password: StrictStr
    # NOTE(review): nothing in this model checks that this equals `password`;
    # presumably the server enforces the match -- confirm.
    password_repeated: StrictStr = Field(alias="passwordRepeated")
    reset_token: StrictStr = Field(alias="resetToken")
    # Alias (wire) names of every property, in declaration order.
    __properties: ClassVar[List[str]] = ["_t", "userId", "password", "passwordRepeated", "resetToken"]

    @field_validator('t')
    def t_validate_enum(cls, value):
        """Reject any discriminator other than 'LegacySubmitPasswordResetCommand'.

        Raises:
            ValueError: If ``value`` is not the single allowed discriminator.
        """
        if value not in set(['LegacySubmitPasswordResetCommand']):
            raise ValueError("must be one of enum values ('LegacySubmitPasswordResetCommand')")
        return value

    # populate_by_name: inputs may use the Python field names as well as the aliases.
    # validate_assignment: attribute assignment is validated like construction.
    # protected_namespaces=(): disables pydantic's "model_" prefix name protection.
    model_config = ConfigDict(
        populate_by_name=True,
        validate_assignment=True,
        protected_namespaces=(),
    )


    def to_str(self) -> str:
        """Returns the string representation of the model using alias"""
        return pprint.pformat(self.model_dump(by_alias=True))

    def to_json(self) -> str:
        """Returns the JSON representation of the model using alias"""
        # TODO: pydantic v2: use .model_dump_json(by_alias=True, exclude_unset=True) instead
        return json.dumps(self.to_dict())

    @classmethod
    def from_json(cls, json_str: str) -> Optional[Self]:
        """Create an instance of LegacySubmitPasswordResetCommand from a JSON string"""
        return cls.from_dict(json.loads(json_str))

    def to_dict(self) -> Dict[str, Any]:
        """Return the dictionary representation of the model using alias.

        Keys are the alias names ("_t", "userId", ...). Because the dump uses
        ``exclude_none=True``, any field whose value is ``None`` is omitted;
        no field in this model is declared nullable. ``excluded_fields`` is
        an empty placeholder emitted by the generator.
        """
        excluded_fields: Set[str] = set([
        ])

        _dict = self.model_dump(
            by_alias=True,
            exclude=excluded_fields,
            exclude_none=True,
        )
        return _dict

    @classmethod
    def from_dict(cls, obj: Optional[Dict[str, Any]]) -> Optional[Self]:
        """Create an instance of LegacySubmitPasswordResetCommand from a dict.

        Returns ``None`` when ``obj`` is ``None``. A non-dict ``obj`` is handed
        straight to ``model_validate``. For a dict, only the five known alias
        keys are forwarded (other keys are dropped), and a missing or ``None``
        "_t" defaults to 'LegacySubmitPasswordResetCommand'.
        """
        if obj is None:
            return None

        if not isinstance(obj, dict):
            return cls.model_validate(obj)

        _obj = cls.model_validate({
            "_t": obj.get("_t") if obj.get("_t") is not None else 'LegacySubmitPasswordResetCommand',
            "userId": obj.get("userId"),
            "password": obj.get("password"),
            "passwordRepeated": obj.get("passwordRepeated"),
            "resetToken": obj.get("resetToken")
        })
        return _obj
101
+
102
+
@@ -97,6 +97,7 @@ Class | Method | HTTP request | Description
97
97
  *IdentityApi* | [**identity_get_client_auth_token_post**](rapidata/api_client/docs/IdentityApi.md#identity_get_client_auth_token_post) | **POST** /Identity/GetClientAuthToken | Issues a new auth token using the client credentials.
98
98
  *IdentityApi* | [**identity_index_post**](rapidata/api_client/docs/IdentityApi.md#identity_index_post) | **POST** /Identity/Index | Logs in a user by username or email and password.
99
99
  *IdentityApi* | [**identity_logout_post**](rapidata/api_client/docs/IdentityApi.md#identity_logout_post) | **POST** /Identity/Logout | Logs out the current user by deleting the refresh token cookie.
100
+ *IdentityApi* | [**identity_register_temporary_post**](rapidata/api_client/docs/IdentityApi.md#identity_register_temporary_post) | **POST** /Identity/RegisterTemporary | Registers and logs in a temporary customer.
100
101
  *IdentityApi* | [**identity_request_reset_post**](rapidata/api_client/docs/IdentityApi.md#identity_request_reset_post) | **POST** /Identity/RequestReset | Request a password reset for a user.
101
102
  *IdentityApi* | [**identity_signup_post**](rapidata/api_client/docs/IdentityApi.md#identity_signup_post) | **POST** /Identity/Signup | Signs up a new user.
102
103
  *IdentityApi* | [**identity_submit_reset_post**](rapidata/api_client/docs/IdentityApi.md#identity_submit_reset_post) | **POST** /Identity/SubmitReset | Updates the password of a user after a password reset request.
@@ -131,7 +132,9 @@ Class | Method | HTTP request | Description
131
132
  *RapidApi* | [**rapid_query_validation_rapids_get**](rapidata/api_client/docs/RapidApi.md#rapid_query_validation_rapids_get) | **GET** /Rapid/QueryValidationRapids | Queries the validation rapids for a specific validation set.
132
133
  *RapidApi* | [**rapid_skip_user_guess_post**](rapidata/api_client/docs/RapidApi.md#rapid_skip_user_guess_post) | **POST** /Rapid/SkipUserGuess | Skips a Rapid for the user.
133
134
  *RapidApi* | [**rapid_validate_current_rapid_bag_get**](rapidata/api_client/docs/RapidApi.md#rapid_validate_current_rapid_bag_get) | **GET** /Rapid/ValidateCurrentRapidBag | Validates that the rapids associated with the current user are active.
135
+ *RapidataIdentityAPIApi* | [**root_get**](rapidata/api_client/docs/RapidataIdentityAPIApi.md#root_get) | **GET** / |
134
136
  *SimpleWorkflowApi* | [**simple_workflow_get_result_overview_get**](rapidata/api_client/docs/SimpleWorkflowApi.md#simple_workflow_get_result_overview_get) | **GET** /SimpleWorkflow/GetResultOverview | Get the result overview for a simple workflow.
137
+ *UserInfoApi* | [**connect_userinfo_get**](rapidata/api_client/docs/UserInfoApi.md#connect_userinfo_get) | **GET** /connect/userinfo | Retrieves information about the authenticated user.
135
138
  *ValidationApi* | [**validation_add_validation_rapid_post**](rapidata/api_client/docs/ValidationApi.md#validation_add_validation_rapid_post) | **POST** /Validation/AddValidationRapid | Adds a new validation rapid to the specified validation set.
136
139
  *ValidationApi* | [**validation_add_validation_text_rapid_post**](rapidata/api_client/docs/ValidationApi.md#validation_add_validation_text_rapid_post) | **POST** /Validation/AddValidationTextRapid | Adds a new validation rapid to the specified validation set.
137
140
  *ValidationApi* | [**validation_create_validation_set_post**](rapidata/api_client/docs/ValidationApi.md#validation_create_validation_set_post) | **POST** /Validation/CreateValidationSet | Creates a new empty validation set.
@@ -273,9 +276,11 @@ Class | Method | HTTP request | Description
273
276
  - [ImportValidationSetFromFileResult](rapidata/api_client/docs/ImportValidationSetFromFileResult.md)
274
277
  - [InProgressRapidModel](rapidata/api_client/docs/InProgressRapidModel.md)
275
278
  - [IssueAuthTokenResult](rapidata/api_client/docs/IssueAuthTokenResult.md)
276
- - [IssueClientAuthTokenResult](rapidata/api_client/docs/IssueClientAuthTokenResult.md)
277
279
  - [LabelingSelection](rapidata/api_client/docs/LabelingSelection.md)
278
280
  - [LanguageUserFilterModel](rapidata/api_client/docs/LanguageUserFilterModel.md)
281
+ - [LegacyIssueClientAuthTokenResult](rapidata/api_client/docs/LegacyIssueClientAuthTokenResult.md)
282
+ - [LegacyRequestPasswordResetCommand](rapidata/api_client/docs/LegacyRequestPasswordResetCommand.md)
283
+ - [LegacySubmitPasswordResetCommand](rapidata/api_client/docs/LegacySubmitPasswordResetCommand.md)
279
284
  - [Line](rapidata/api_client/docs/Line.md)
280
285
  - [LinePayload](rapidata/api_client/docs/LinePayload.md)
281
286
  - [LinePoint](rapidata/api_client/docs/LinePoint.md)
@@ -334,7 +339,6 @@ Class | Method | HTTP request | Description
334
339
  - [RapidResultModel](rapidata/api_client/docs/RapidResultModel.md)
335
340
  - [RapidResultModelResult](rapidata/api_client/docs/RapidResultModelResult.md)
336
341
  - [RapidSkippedModel](rapidata/api_client/docs/RapidSkippedModel.md)
337
- - [RequestPasswordResetCommand](rapidata/api_client/docs/RequestPasswordResetCommand.md)
338
342
  - [RootFilter](rapidata/api_client/docs/RootFilter.md)
339
343
  - [SendCompletionMailStepModel](rapidata/api_client/docs/SendCompletionMailStepModel.md)
340
344
  - [Shape](rapidata/api_client/docs/Shape.md)
@@ -354,7 +358,6 @@ Class | Method | HTTP request | Description
354
358
  - [StaticSelection](rapidata/api_client/docs/StaticSelection.md)
355
359
  - [SubmitCocoModel](rapidata/api_client/docs/SubmitCocoModel.md)
356
360
  - [SubmitCocoResult](rapidata/api_client/docs/SubmitCocoResult.md)
357
- - [SubmitPasswordResetCommand](rapidata/api_client/docs/SubmitPasswordResetCommand.md)
358
361
  - [TextAsset](rapidata/api_client/docs/TextAsset.md)
359
362
  - [TextAssetModel](rapidata/api_client/docs/TextAssetModel.md)
360
363
  - [TextMetadata](rapidata/api_client/docs/TextMetadata.md)
@@ -403,6 +406,10 @@ Authentication schemes defined for the API:
403
406
  - **API key parameter name**: Authorization
404
407
  - **Location**: HTTP header
405
408
 
409
+ <a id="oauth2"></a>
410
+ ### oauth2
411
+
412
+
406
413
 
407
414
  ## Author
408
415
 
@@ -1,13 +1,24 @@
1
1
  from .rapidata_client import RapidataClient
2
- from .workflow import ClassifyWorkflow, TranscriptionWorkflow, CompareWorkflow, FreeTextWorkflow
2
+ from .workflow import (
3
+ ClassifyWorkflow,
4
+ TranscriptionWorkflow,
5
+ CompareWorkflow,
6
+ FreeTextWorkflow,
7
+ )
3
8
  from .selection import (
4
9
  DemographicSelection,
5
10
  LabelingSelection,
6
11
  ValidationSelection,
7
12
  ConditionalValidationSelection,
13
+ CappedSelection,
8
14
  )
9
15
  from .referee import NaiveReferee, ClassifyEarlyStoppingReferee
10
- from .metadata import PrivateTextMetadata, PublicTextMetadata, PromptMetadata, TranscriptionMetadata
16
+ from .metadata import (
17
+ PrivateTextMetadata,
18
+ PublicTextMetadata,
19
+ PromptMetadata,
20
+ TranscriptionMetadata,
21
+ )
11
22
  from .feature_flags import FeatureFlags
12
23
  from .country_codes import CountryCodes
13
24
  from .assets import MediaAsset, TextAsset, MultiAsset
@@ -23,6 +23,8 @@ class MultiAsset(BaseAsset):
23
23
  Args:
24
24
  assets (List[BaseAsset]): A list of BaseAsset instances to be managed together.
25
25
  """
26
+ if len(assets) != 2:
27
+ raise ValueError("Assets must come in pairs for comparison tasks.")
26
28
  self.assets = assets
27
29
 
28
30
  def __len__(self) -> int:
@@ -9,6 +9,7 @@ from rapidata.api_client.models.upload_text_sources_to_dataset_model import (
9
9
  UploadTextSourcesToDatasetModel,
10
10
  )
11
11
  from rapidata.rapidata_client.metadata.base_metadata import Metadata
12
+ from rapidata.rapidata_client.assets import TextAsset, MediaAsset, MultiAsset
12
13
  from rapidata.service import LocalFileService
13
14
  from rapidata.service.openapi_service import OpenAPIService
14
15
  from concurrent.futures import ThreadPoolExecutor, as_completed
@@ -22,7 +23,8 @@ class RapidataDataset:
22
23
  self.openapi_service = openapi_service
23
24
  self.local_file_service = LocalFileService()
24
25
 
25
- def add_texts(self, texts: list[str]):
26
+ def add_texts(self, text_assets: list[TextAsset]):
27
+ texts = [text.text for text in text_assets]
26
28
  model = UploadTextSourcesToDatasetModel(
27
29
  datasetId=self.dataset_id, textSources=texts
28
30
  )
@@ -32,24 +34,26 @@ class RapidataDataset:
32
34
 
33
35
  def add_media_from_paths(
34
36
  self,
35
- media_paths: list[str | list[str]],
37
+ media_paths: list[MediaAsset | MultiAsset],
36
38
  metadata: list[Metadata] | None = None,
37
39
  max_workers: int = 10,
38
40
  ):
39
41
  if metadata is not None and len(metadata) != len(media_paths):
40
42
  raise ValueError(
41
- "metadata must be None or have the same length as image_paths"
43
+ "metadata must be None or have the same length as media_paths"
42
44
  )
43
45
 
44
- def upload_datapoint(media_paths_rapid: str | list[str], meta: Metadata | None) -> None:
45
- if isinstance(media_paths_rapid, list) and not all(
46
- os.path.exists(media_path) for media_path in media_paths_rapid
47
- ):
48
- raise FileNotFoundError(f"File not found: {media_paths_rapid}")
49
- elif isinstance(media_paths_rapid, str) and not os.path.exists(
50
- media_paths_rapid
51
- ):
52
- raise FileNotFoundError(f"File not found: {media_paths_rapid}")
46
+ def upload_datapoint(media_asset: MediaAsset | MultiAsset, meta: Metadata | None) -> None:
47
+ if isinstance(media_asset, MediaAsset):
48
+ paths = [media_asset.path]
49
+ elif isinstance(media_asset, MultiAsset):
50
+ paths = [asset.path for asset in media_asset.assets if isinstance(asset, MediaAsset)]
51
+ else:
52
+ raise ValueError(f"Unsupported asset type: {type(media_asset)}")
53
+
54
+ assert all(
55
+ os.path.exists(media_path) for media_path in paths
56
+ ), "All media paths must exist on the local filesystem."
53
57
 
54
58
  meta_model = meta.to_model() if meta else None
55
59
  model = DatapointMetadataModel(
@@ -63,14 +67,14 @@ class RapidataDataset:
63
67
 
64
68
  self.openapi_service.dataset_api.dataset_create_datapoint_post(
65
69
  model=model,
66
- files=media_paths_rapid if isinstance(media_paths_rapid, list) else [media_paths_rapid] # type: ignore
70
+ files=paths # type: ignore
67
71
  )
68
72
 
69
73
  total_uploads = len(media_paths)
70
74
  with ThreadPoolExecutor(max_workers=max_workers) as executor:
71
75
  futures = [
72
- executor.submit(upload_datapoint, media_paths, meta)
73
- for media_paths, meta in zip_longest(media_paths, metadata or [])
76
+ executor.submit(upload_datapoint, media_asset, meta)
77
+ for media_asset, meta in zip_longest(media_paths, metadata or [])
74
78
  ]
75
79
 
76
80
  with tqdm(total=total_uploads, desc="Uploading datapoints") as pbar:
@@ -209,7 +209,7 @@ class ValidationSetBuilder:
209
209
  payload=payload,
210
210
  truths=model_truth,
211
211
  metadata=metadata,
212
- randomCorrectProbability=1 / len(transcription),
212
+ randomCorrectProbability = 1 / len(transcription_words),
213
213
  )
214
214
  )
215
215
 
@@ -3,7 +3,9 @@ from rapidata.rapidata_client.dataset.rapidata_dataset import RapidataDataset
3
3
  from rapidata.service.openapi_service import OpenAPIService
4
4
  import json
5
5
  from rapidata.api_client.exceptions import ApiException
6
-
6
+ from typing import cast
7
+ from rapidata.api_client.models.workflow_artifact_model import WorkflowArtifactModel
8
+ from tqdm import tqdm
7
9
 
8
10
  class RapidataOrder:
9
11
  """
@@ -26,6 +28,7 @@ class RapidataOrder:
26
28
  self.openapi_service = openapi_service
27
29
  self.order_id = order_id
28
30
  self._dataset = dataset
31
+ self._workflow_id = None
29
32
 
30
33
  def submit(self):
31
34
  """
@@ -49,27 +52,55 @@ class RapidataOrder:
49
52
  """
50
53
  return self.openapi_service.order_api.order_get_by_id_get(self.order_id)
51
54
 
52
- def wait_for_done(self):
55
+ def display_progress_bar(self, refresh_rate=5):
53
56
  """
54
- Blocking call that waits for the order to be done. Exponential backoff is used to check the status of the order.
57
+ Displays a progress bar for the order processing using tqdm.
58
+
59
+ :param refresh_rate: How often to refresh the progress bar, in seconds.
60
+ :type refresh_rate: float
55
61
  """
56
- wait_time = 1
57
- back_off_factor = 1.1
58
- minimum_poll_interval = 60 # 1 minute
59
-
60
- while True:
61
- time.sleep(wait_time)
62
- result = self.get_status()
63
- if result.state == "ManualReview":
64
- print(
65
- "Order is in manual review. Please contact support for approval. Will continue polling."
66
- )
62
+ total_rapids = self._get_total_rapids()
63
+ with tqdm(total=total_rapids, desc="Processing order", unit="rapids") as pbar:
64
+ completed_rapids = 0
65
+ while True:
66
+ current_completed = self._get_completed_rapids()
67
+ if current_completed > completed_rapids:
68
+ pbar.update(current_completed - completed_rapids)
69
+ completed_rapids = current_completed
70
+
71
+ if completed_rapids >= total_rapids:
72
+ break
73
+
74
+ time.sleep(refresh_rate)
67
75
 
68
- if result.state == "Completed" or result.state == "Failed":
76
+ def _get_workflow_id(self):
77
+ if self._workflow_id:
78
+ return self._workflow_id
79
+
80
+ for _ in range(2):
81
+ try:
82
+ order_result = self.openapi_service.order_api.order_get_by_id_get(self.order_id)
83
+ pipeline = self.openapi_service.pipeline_api.pipeline_id_get(order_result.pipeline_id)
84
+ self._workflow_id = cast(WorkflowArtifactModel, pipeline.artifacts["workflow-artifact"].actual_instance).workflow_id
69
85
  break
70
- wait_time = max(
71
- minimum_poll_interval, wait_time * back_off_factor
72
- ) # poll at least every 10 minutes
86
+ except Exception:
87
+ time.sleep(2)
88
+ if not self._workflow_id:
89
+ raise Exception("Order has not started yet. Please wait for a few seconds and try again.")
90
+ return self._workflow_id
91
+
92
+ def _get_total_rapids(self):
93
+ workflow_id = self._get_workflow_id()
94
+ return self.openapi_service.workflow_api.workflow_get_progress_get(workflow_id).total
95
+
96
+ def _get_completed_rapids(self):
97
+ workflow_id = self._get_workflow_id()
98
+ return self.openapi_service.workflow_api.workflow_get_progress_get(workflow_id).completed
99
+
100
+ def get_progress_percentage(self):
101
+ workflow_id = self._get_workflow_id()
102
+ progress = self.openapi_service.workflow_api.workflow_get_progress_get(workflow_id)
103
+ return progress.completion_percentage
73
104
 
74
105
  def get_results(self):
75
106
  """
@@ -28,6 +28,10 @@ from rapidata.service.openapi_service import OpenAPIService
28
28
 
29
29
  from rapidata.rapidata_client.workflow.compare_workflow import CompareWorkflow
30
30
 
31
+ from rapidata.rapidata_client.assets import MediaAsset, TextAsset, MultiAsset
32
+
33
+ from typing import cast, Sequence
34
+
31
35
 
32
36
  class RapidataOrderBuilder:
33
37
  """Builder object for creating Rapidata orders.
@@ -55,7 +59,6 @@ class RapidataOrderBuilder:
55
59
  self._openapi_service = openapi_service
56
60
  self._workflow: Workflow | None = None
57
61
  self._referee: Referee | None = None
58
- self._media_paths: list[str | list[str]] = []
59
62
  self._metadata: list[Metadata] | None = None
60
63
  self._aggregator: AggregatorType | None = None
61
64
  self._validation_set_id: str | None = None
@@ -65,8 +68,7 @@ class RapidataOrderBuilder:
65
68
  self._selections: list[Selection] = []
66
69
  self._rapids_per_bag: int = 2
67
70
  self._priority: int = 50
68
- self._texts: list[str] | None = None
69
- self._media_paths: list[str | list[str]] = []
71
+ self._assets: list[MediaAsset] | list[TextAsset] | list[MultiAsset] = []
70
72
 
71
73
  def _to_model(self) -> CreateOrderModel:
72
74
  """
@@ -143,8 +145,12 @@ class RapidataOrderBuilder:
143
145
  if isinstance(
144
146
  self._workflow, CompareWorkflow
145
147
  ): # Temporary fix; will be handled by backend in the future
148
+ assert all(isinstance(item, MultiAsset) for item in self._assets), (
149
+ "The media paths must be of type MultiAsset for comparison tasks."
150
+ )
151
+ media_paths = cast(list[MultiAsset], self._assets)
146
152
  assert all(
147
- [len(path) == 2 for path in self._media_paths]
153
+ [len(path) == 2 for path in media_paths]
148
154
  ), "The media paths must come in pairs for comparison tasks."
149
155
 
150
156
  result = self._openapi_service.order_api.order_create_post(
@@ -159,22 +165,18 @@ class RapidataOrderBuilder:
159
165
  openapi_service=self._openapi_service,
160
166
  )
161
167
 
162
- if self._media_paths and self._texts:
163
- raise ValueError(
164
- "You cannot provide both media paths and texts to the same order."
165
- )
166
-
167
- if not self._media_paths and not self._texts:
168
+ if not self._assets:
168
169
  raise ValueError(
169
- "You must provide either media paths or texts to the order."
170
+ "You must provide assets to start the order."
170
171
  )
172
+ if all(isinstance(item, TextAsset) for item in self._assets):
173
+ assets = cast(list[TextAsset], self._assets)
174
+ order.dataset.add_texts(assets)
171
175
 
172
- if self._texts:
173
- order.dataset.add_texts(self._texts)
174
-
175
- if self._media_paths:
176
+ elif all(isinstance(item, (MediaAsset, MultiAsset)) for item in self._assets):
177
+ assets = cast(list[MediaAsset | MultiAsset], self._assets)
176
178
  order.dataset.add_media_from_paths(
177
- self._media_paths, self._metadata, max_workers
179
+ assets, self._metadata, max_workers
178
180
  )
179
181
 
180
182
  if submit:
@@ -210,34 +212,21 @@ class RapidataOrderBuilder:
210
212
 
211
213
  def media(
212
214
  self,
213
- media_paths: list[str | list[str]],
214
- metadata: list[Metadata] | None = None,
215
+ asset: list[MediaAsset] | list[TextAsset] | list[MultiAsset],
216
+ metadata: Sequence[Metadata] | None = None,
215
217
  ) -> "RapidataOrderBuilder":
216
218
  """
217
219
  Set the media assets for the order.
218
220
 
219
221
  Args:
220
- media_paths (list[str | list[str]]): The paths of the media assets to be set.
222
+ media_paths (list[MediaAsset] | list[TextAsset] | list[MultiAsset]): The paths of the media assets to be set.
221
223
  metadata (list[Metadata] | None, optional): Metadata for the media assets. Defaults to None.
222
224
 
223
225
  Returns:
224
226
  RapidataOrderBuilder: The updated RapidataOrderBuilder instance.
225
227
  """
226
- self._media_paths = media_paths
227
- self._metadata = metadata
228
- return self
229
-
230
- def texts(self, texts: list[str]) -> "RapidataOrderBuilder":
231
- """
232
- Set the TextAssets for the order.
233
-
234
- Args:
235
- texts (list[str]): The texts to be set.
236
-
237
- Returns:
238
- RapidataOrderBuilder: The updated RapidataOrderBuilder instance.
239
- """
240
- self._texts = texts
228
+ self._assets = asset
229
+ self._metadata = metadata # type: ignore
241
230
  return self
242
231
 
243
232
  def feature_flags(self, feature_flags: FeatureFlags) -> "RapidataOrderBuilder":
@@ -3,3 +3,4 @@ from .demographic_selection import DemographicSelection
3
3
  from .labeling_selection import LabelingSelection
4
4
  from .validation_selection import ValidationSelection
5
5
  from .conditional_validation_selection import ConditionalValidationSelection
6
+ from .capped_selection import CappedSelection
@@ -0,0 +1,25 @@
1
+ from rapidata.api_client.models.capped_selection import (
2
+ CappedSelection as CappedSelectionModel,
3
+ )
4
+ from rapidata.api_client.models.capped_selection_selections_inner import (
5
+ CappedSelectionSelectionsInner,
6
+ )
7
+ from rapidata.rapidata_client.selection.base_selection import Selection
8
+ from typing import Sequence
9
+
10
+
11
class CappedSelection(Selection):
    """Selection that groups several child selections under a ``max_rapids`` limit.

    Args:
        selections: Child selections; each is serialized with its own
            ``to_model()`` when this selection is converted.
        max_rapids: Forwarded unchanged as ``max_rapids`` on the API model.
            Presumably an upper bound on rapids served by the wrapped
            selections -- confirm against the backend contract.
    """

    def __init__(self, selections: Sequence[Selection], max_rapids: int):
        self.max_rapids = max_rapids
        self.selections = selections

    def to_model(self):
        """Build the API-level ``CappedSelection`` model for this selection."""
        wrapped_children = [
            CappedSelectionSelectionsInner(child.to_model()) for child in self.selections
        ]
        return CappedSelectionModel(
            _t="CappedSelection",
            selections=wrapped_children,
            max_rapids=self.max_rapids,
        )
File without changes
@@ -7,6 +7,8 @@ from rapidata.rapidata_client.workflow.classify_workflow import ClassifyWorkflow
7
7
  from rapidata.rapidata_client.selection.validation_selection import ValidationSelection
8
8
  from rapidata.rapidata_client.selection.labeling_selection import LabelingSelection
9
9
  from rapidata.service.openapi_service import OpenAPIService
10
+ from rapidata.rapidata_client.assets import MediaAsset
11
+ from typing import Sequence
10
12
 
11
13
  class ClassificationOrderBuilder:
12
14
  def __init__(self, name: str, question: str, options: list[str], media_paths: list[str], openapi_service: OpenAPIService):
@@ -19,7 +21,7 @@ class ClassificationOrderBuilder:
19
21
  self._metadata = None
20
22
  self._validation_set_id = None
21
23
 
22
- def metadata(self, metadata: list[Metadata]):
24
+ def metadata(self, metadata: Sequence[Metadata]):
23
25
  """Set the metadata for the classification order. Has to be the same lenght as the media paths."""
24
26
  self._metadata = metadata
25
27
  return self
@@ -28,27 +30,29 @@ class ClassificationOrderBuilder:
28
30
  """Set the number of responses required for the classification order."""
29
31
  self._responses_required = responses_required
30
32
  return self
31
-
33
+
32
34
  def probability_threshold(self, probability_threshold: float):
33
35
  """Set the probability threshold for early stopping."""
34
36
  self._probability_threshold = probability_threshold
35
37
  return self
36
-
38
+
37
39
  def validation_set_id(self, validation_set_id: str):
38
40
  """Set the validation set ID for the classification order."""
39
41
  self._validation_set_id = validation_set_id
40
42
  return self
41
-
43
+
42
44
  def create(self, submit: bool = True, max_upload_workers: int = 10):
43
45
  if self._probability_threshold and self._responses_required:
44
46
  referee = ClassifyEarlyStoppingReferee(
45
47
  max_vote_count=self._responses_required,
46
48
  threshold=self._probability_threshold
47
49
  )
48
-
50
+
49
51
  else:
50
52
  referee = NaiveReferee(required_guesses=self._responses_required)
51
-
53
+
54
+ assets = [MediaAsset(path=media_path) for media_path in self._media_paths]
55
+
52
56
  selection: list[Selection] = ([ValidationSelection(amount=1, validation_set_id=self._validation_set_id), LabelingSelection(amount=2)]
53
57
  if self._validation_set_id
54
58
  else [LabelingSelection(amount=3)])
@@ -61,14 +65,15 @@ class ClassificationOrderBuilder:
61
65
  )
62
66
  )
63
67
  .referee(referee)
64
- .media(self._media_paths, metadata=self._metadata) # type: ignore
68
+ .media(assets, metadata=self._metadata) # type: ignore
65
69
  .selections(selection)
66
70
  .create(submit=submit, max_workers=max_upload_workers))
67
71
 
68
72
  return order
69
-
73
+
70
74
 
71
75
  class ClassificationMediaBuilder:
76
+ "test"
72
77
  def __init__(self, name: str, question: str, options: list[str], openapi_service: OpenAPIService):
73
78
  self._openapi_service = openapi_service
74
79
  self._name = name
@@ -85,7 +90,7 @@ class ClassificationMediaBuilder:
85
90
  if self._media_paths is None:
86
91
  raise ValueError("Media paths are required")
87
92
  return ClassificationOrderBuilder(self._name, self._question, self._options, self._media_paths, openapi_service=self._openapi_service)
88
-
93
+
89
94
 
90
95
  class ClassificationOptionsBuilder:
91
96
  def __init__(self, name: str, question: str, openapi_service: OpenAPIService):
@@ -1,11 +1,13 @@
1
1
  from rapidata.service.openapi_service import OpenAPIService
2
- from rapidata.rapidata_client.metadata.base_metadata import Metadata
2
+ from rapidata.rapidata_client.metadata import Metadata
3
3
  from rapidata.rapidata_client.order.rapidata_order_builder import RapidataOrderBuilder
4
4
  from rapidata.rapidata_client.workflow.compare_workflow import CompareWorkflow
5
5
  from rapidata.rapidata_client.referee.naive_referee import NaiveReferee
6
6
  from rapidata.rapidata_client.selection.validation_selection import ValidationSelection
7
7
  from rapidata.rapidata_client.selection.labeling_selection import LabelingSelection
8
8
  from rapidata.rapidata_client.selection.base_selection import Selection
9
+ from rapidata.rapidata_client.assets import MultiAsset, MediaAsset
10
+ from typing import Sequence
9
11
 
10
12
  class CompareOrderBuilder:
11
13
  def __init__(self, name:str, criteria: str, media_paths: list[list[str]], openapi_service: OpenAPIService):
@@ -22,7 +24,7 @@ class CompareOrderBuilder:
22
24
  self._responses_required = responses_required
23
25
  return self
24
26
 
25
- def metadata(self, metadata: list[Metadata]) -> 'CompareOrderBuilder':
27
+ def metadata(self, metadata: Sequence[Metadata]) -> 'CompareOrderBuilder':
26
28
  """Set the metadata for the comparison order. Has to be the same shape as the media paths."""
27
29
  self._metadata = metadata
28
30
  return self
@@ -37,6 +39,7 @@ class CompareOrderBuilder:
37
39
  if self._validation_set_id
38
40
  else [LabelingSelection(amount=3)])
39
41
 
42
+ media_paths = [MultiAsset([MediaAsset(path=path) for path in paths]) for paths in self._media_paths]
40
43
  order = (self._order_builder
41
44
  .workflow(
42
45
  CompareWorkflow(
@@ -44,7 +47,7 @@ class CompareOrderBuilder:
44
47
  )
45
48
  )
46
49
  .referee(NaiveReferee(required_guesses=self._responses_required))
47
- .media(self._media_paths, metadata=self._metadata) # type: ignore
50
+ .media(media_paths, metadata=self._metadata) # type: ignore
48
51
  .selections(selection)
49
52
  .create(submit=submit, max_workers=max_upload_workers))
50
53