rapidata 1.0.0__py3-none-any.whl → 1.2.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (40) hide show
  1. rapidata/__init__.py +1 -0
  2. rapidata/api_client/__init__.py +5 -3
  3. rapidata/api_client/api/__init__.py +2 -0
  4. rapidata/api_client/api/campaign_api.py +8 -4
  5. rapidata/api_client/api/coco_api.py +4 -2
  6. rapidata/api_client/api/compare_workflow_api.py +2 -1
  7. rapidata/api_client/api/datapoint_api.py +6 -3
  8. rapidata/api_client/api/dataset_api.py +16 -8
  9. rapidata/api_client/api/identity_api.py +329 -50
  10. rapidata/api_client/api/newsletter_api.py +4 -2
  11. rapidata/api_client/api/order_api.py +40 -20
  12. rapidata/api_client/api/pipeline_api.py +6 -3
  13. rapidata/api_client/api/rapid_api.py +10 -5
  14. rapidata/api_client/api/rapidata_identity_api_api.py +272 -0
  15. rapidata/api_client/api/simple_workflow_api.py +2 -1
  16. rapidata/api_client/api/user_info_api.py +272 -0
  17. rapidata/api_client/api/validation_api.py +14 -7
  18. rapidata/api_client/api/workflow_api.py +18 -9
  19. rapidata/api_client/models/__init__.py +3 -3
  20. rapidata/api_client/models/issue_auth_token_result.py +1 -1
  21. rapidata/api_client/models/legacy_issue_client_auth_token_result.py +87 -0
  22. rapidata/api_client/models/legacy_request_password_reset_command.py +98 -0
  23. rapidata/api_client/models/legacy_submit_password_reset_command.py +102 -0
  24. rapidata/api_client_README.md +10 -3
  25. rapidata/rapidata_client/__init__.py +13 -2
  26. rapidata/rapidata_client/assets/multi_asset.py +2 -0
  27. rapidata/rapidata_client/dataset/rapidata_dataset.py +19 -15
  28. rapidata/rapidata_client/dataset/validation_set_builder.py +13 -15
  29. rapidata/rapidata_client/order/rapidata_order.py +49 -18
  30. rapidata/rapidata_client/order/rapidata_order_builder.py +59 -43
  31. rapidata/rapidata_client/selection/__init__.py +1 -0
  32. rapidata/rapidata_client/selection/capped_selection.py +25 -0
  33. rapidata/rapidata_client/simple_builders/__init__.py +0 -0
  34. rapidata/rapidata_client/simple_builders/simple_classification_builders.py +14 -9
  35. rapidata/rapidata_client/simple_builders/simple_compare_builders.py +6 -3
  36. rapidata/service/openapi_service.py +15 -0
  37. {rapidata-1.0.0.dist-info → rapidata-1.2.0.dist-info}/METADATA +2 -1
  38. {rapidata-1.0.0.dist-info → rapidata-1.2.0.dist-info}/RECORD +40 -33
  39. {rapidata-1.0.0.dist-info → rapidata-1.2.0.dist-info}/WHEEL +1 -1
  40. {rapidata-1.0.0.dist-info → rapidata-1.2.0.dist-info}/LICENSE +0 -0
@@ -0,0 +1,102 @@
1
+ # coding: utf-8
2
+
3
+ """
4
+ Rapidata.Dataset
5
+
6
+ No description provided (generated by Openapi Generator https://github.com/openapitools/openapi-generator)
7
+
8
+ The version of the OpenAPI document: v1
9
+ Generated by OpenAPI Generator (https://openapi-generator.tech)
10
+
11
+ Do not edit the class manually.
12
+ """ # noqa: E501
13
+
14
+
15
+ from __future__ import annotations
16
+ import pprint
17
+ import re # noqa: F401
18
+ import json
19
+
20
+ from pydantic import BaseModel, ConfigDict, Field, StrictStr, field_validator
21
+ from typing import Any, ClassVar, Dict, List
22
+ from typing import Optional, Set
23
+ from typing_extensions import Self
24
+
25
+ class LegacySubmitPasswordResetCommand(BaseModel):
26
+ """
27
+ LegacySubmitPasswordResetCommand
28
+ """ # noqa: E501
29
+ t: StrictStr = Field(description="Discriminator value for LegacySubmitPasswordResetCommand", alias="_t")
30
+ user_id: StrictStr = Field(alias="userId")
31
+ password: StrictStr
32
+ password_repeated: StrictStr = Field(alias="passwordRepeated")
33
+ reset_token: StrictStr = Field(alias="resetToken")
34
+ __properties: ClassVar[List[str]] = ["_t", "userId", "password", "passwordRepeated", "resetToken"]
35
+
36
+ @field_validator('t')
37
+ def t_validate_enum(cls, value):
38
+ """Validates the enum"""
39
+ if value not in set(['LegacySubmitPasswordResetCommand']):
40
+ raise ValueError("must be one of enum values ('LegacySubmitPasswordResetCommand')")
41
+ return value
42
+
43
+ model_config = ConfigDict(
44
+ populate_by_name=True,
45
+ validate_assignment=True,
46
+ protected_namespaces=(),
47
+ )
48
+
49
+
50
+ def to_str(self) -> str:
51
+ """Returns the string representation of the model using alias"""
52
+ return pprint.pformat(self.model_dump(by_alias=True))
53
+
54
+ def to_json(self) -> str:
55
+ """Returns the JSON representation of the model using alias"""
56
+ # TODO: pydantic v2: use .model_dump_json(by_alias=True, exclude_unset=True) instead
57
+ return json.dumps(self.to_dict())
58
+
59
+ @classmethod
60
+ def from_json(cls, json_str: str) -> Optional[Self]:
61
+ """Create an instance of LegacySubmitPasswordResetCommand from a JSON string"""
62
+ return cls.from_dict(json.loads(json_str))
63
+
64
+ def to_dict(self) -> Dict[str, Any]:
65
+ """Return the dictionary representation of the model using alias.
66
+
67
+ This has the following differences from calling pydantic's
68
+ `self.model_dump(by_alias=True)`:
69
+
70
+ * `None` is only added to the output dict for nullable fields that
71
+ were set at model initialization. Other fields with value `None`
72
+ are ignored.
73
+ """
74
+ excluded_fields: Set[str] = set([
75
+ ])
76
+
77
+ _dict = self.model_dump(
78
+ by_alias=True,
79
+ exclude=excluded_fields,
80
+ exclude_none=True,
81
+ )
82
+ return _dict
83
+
84
+ @classmethod
85
+ def from_dict(cls, obj: Optional[Dict[str, Any]]) -> Optional[Self]:
86
+ """Create an instance of LegacySubmitPasswordResetCommand from a dict"""
87
+ if obj is None:
88
+ return None
89
+
90
+ if not isinstance(obj, dict):
91
+ return cls.model_validate(obj)
92
+
93
+ _obj = cls.model_validate({
94
+ "_t": obj.get("_t") if obj.get("_t") is not None else 'LegacySubmitPasswordResetCommand',
95
+ "userId": obj.get("userId"),
96
+ "password": obj.get("password"),
97
+ "passwordRepeated": obj.get("passwordRepeated"),
98
+ "resetToken": obj.get("resetToken")
99
+ })
100
+ return _obj
101
+
102
+
@@ -97,6 +97,7 @@ Class | Method | HTTP request | Description
97
97
  *IdentityApi* | [**identity_get_client_auth_token_post**](rapidata/api_client/docs/IdentityApi.md#identity_get_client_auth_token_post) | **POST** /Identity/GetClientAuthToken | Issues a new auth token using the client credentials.
98
98
  *IdentityApi* | [**identity_index_post**](rapidata/api_client/docs/IdentityApi.md#identity_index_post) | **POST** /Identity/Index | Logs in a user by username or email and password.
99
99
  *IdentityApi* | [**identity_logout_post**](rapidata/api_client/docs/IdentityApi.md#identity_logout_post) | **POST** /Identity/Logout | Logs out the current user by deleting the refresh token cookie.
100
+ *IdentityApi* | [**identity_register_temporary_post**](rapidata/api_client/docs/IdentityApi.md#identity_register_temporary_post) | **POST** /Identity/RegisterTemporary | Registers and logs in a temporary customer.
100
101
  *IdentityApi* | [**identity_request_reset_post**](rapidata/api_client/docs/IdentityApi.md#identity_request_reset_post) | **POST** /Identity/RequestReset | Request a password reset for a user.
101
102
  *IdentityApi* | [**identity_signup_post**](rapidata/api_client/docs/IdentityApi.md#identity_signup_post) | **POST** /Identity/Signup | Signs up a new user.
102
103
  *IdentityApi* | [**identity_submit_reset_post**](rapidata/api_client/docs/IdentityApi.md#identity_submit_reset_post) | **POST** /Identity/SubmitReset | Updates the password of a user after a password reset request.
@@ -131,7 +132,9 @@ Class | Method | HTTP request | Description
131
132
  *RapidApi* | [**rapid_query_validation_rapids_get**](rapidata/api_client/docs/RapidApi.md#rapid_query_validation_rapids_get) | **GET** /Rapid/QueryValidationRapids | Queries the validation rapids for a specific validation set.
132
133
  *RapidApi* | [**rapid_skip_user_guess_post**](rapidata/api_client/docs/RapidApi.md#rapid_skip_user_guess_post) | **POST** /Rapid/SkipUserGuess | Skips a Rapid for the user.
133
134
  *RapidApi* | [**rapid_validate_current_rapid_bag_get**](rapidata/api_client/docs/RapidApi.md#rapid_validate_current_rapid_bag_get) | **GET** /Rapid/ValidateCurrentRapidBag | Validates that the rapids associated with the current user are active.
135
+ *RapidataIdentityAPIApi* | [**root_get**](rapidata/api_client/docs/RapidataIdentityAPIApi.md#root_get) | **GET** / |
134
136
  *SimpleWorkflowApi* | [**simple_workflow_get_result_overview_get**](rapidata/api_client/docs/SimpleWorkflowApi.md#simple_workflow_get_result_overview_get) | **GET** /SimpleWorkflow/GetResultOverview | Get the result overview for a simple workflow.
137
+ *UserInfoApi* | [**connect_userinfo_get**](rapidata/api_client/docs/UserInfoApi.md#connect_userinfo_get) | **GET** /connect/userinfo | Retrieves information about the authenticated user.
135
138
  *ValidationApi* | [**validation_add_validation_rapid_post**](rapidata/api_client/docs/ValidationApi.md#validation_add_validation_rapid_post) | **POST** /Validation/AddValidationRapid | Adds a new validation rapid to the specified validation set.
136
139
  *ValidationApi* | [**validation_add_validation_text_rapid_post**](rapidata/api_client/docs/ValidationApi.md#validation_add_validation_text_rapid_post) | **POST** /Validation/AddValidationTextRapid | Adds a new validation rapid to the specified validation set.
137
140
  *ValidationApi* | [**validation_create_validation_set_post**](rapidata/api_client/docs/ValidationApi.md#validation_create_validation_set_post) | **POST** /Validation/CreateValidationSet | Creates a new empty validation set.
@@ -273,9 +276,11 @@ Class | Method | HTTP request | Description
273
276
  - [ImportValidationSetFromFileResult](rapidata/api_client/docs/ImportValidationSetFromFileResult.md)
274
277
  - [InProgressRapidModel](rapidata/api_client/docs/InProgressRapidModel.md)
275
278
  - [IssueAuthTokenResult](rapidata/api_client/docs/IssueAuthTokenResult.md)
276
- - [IssueClientAuthTokenResult](rapidata/api_client/docs/IssueClientAuthTokenResult.md)
277
279
  - [LabelingSelection](rapidata/api_client/docs/LabelingSelection.md)
278
280
  - [LanguageUserFilterModel](rapidata/api_client/docs/LanguageUserFilterModel.md)
281
+ - [LegacyIssueClientAuthTokenResult](rapidata/api_client/docs/LegacyIssueClientAuthTokenResult.md)
282
+ - [LegacyRequestPasswordResetCommand](rapidata/api_client/docs/LegacyRequestPasswordResetCommand.md)
283
+ - [LegacySubmitPasswordResetCommand](rapidata/api_client/docs/LegacySubmitPasswordResetCommand.md)
279
284
  - [Line](rapidata/api_client/docs/Line.md)
280
285
  - [LinePayload](rapidata/api_client/docs/LinePayload.md)
281
286
  - [LinePoint](rapidata/api_client/docs/LinePoint.md)
@@ -334,7 +339,6 @@ Class | Method | HTTP request | Description
334
339
  - [RapidResultModel](rapidata/api_client/docs/RapidResultModel.md)
335
340
  - [RapidResultModelResult](rapidata/api_client/docs/RapidResultModelResult.md)
336
341
  - [RapidSkippedModel](rapidata/api_client/docs/RapidSkippedModel.md)
337
- - [RequestPasswordResetCommand](rapidata/api_client/docs/RequestPasswordResetCommand.md)
338
342
  - [RootFilter](rapidata/api_client/docs/RootFilter.md)
339
343
  - [SendCompletionMailStepModel](rapidata/api_client/docs/SendCompletionMailStepModel.md)
340
344
  - [Shape](rapidata/api_client/docs/Shape.md)
@@ -354,7 +358,6 @@ Class | Method | HTTP request | Description
354
358
  - [StaticSelection](rapidata/api_client/docs/StaticSelection.md)
355
359
  - [SubmitCocoModel](rapidata/api_client/docs/SubmitCocoModel.md)
356
360
  - [SubmitCocoResult](rapidata/api_client/docs/SubmitCocoResult.md)
357
- - [SubmitPasswordResetCommand](rapidata/api_client/docs/SubmitPasswordResetCommand.md)
358
361
  - [TextAsset](rapidata/api_client/docs/TextAsset.md)
359
362
  - [TextAssetModel](rapidata/api_client/docs/TextAssetModel.md)
360
363
  - [TextMetadata](rapidata/api_client/docs/TextMetadata.md)
@@ -403,6 +406,10 @@ Authentication schemes defined for the API:
403
406
  - **API key parameter name**: Authorization
404
407
  - **Location**: HTTP header
405
408
 
409
+ <a id="oauth2"></a>
410
+ ### oauth2
411
+
412
+
406
413
 
407
414
  ## Author
408
415
 
@@ -1,13 +1,24 @@
1
1
  from .rapidata_client import RapidataClient
2
- from .workflow import ClassifyWorkflow, TranscriptionWorkflow, CompareWorkflow, FreeTextWorkflow
2
+ from .workflow import (
3
+ ClassifyWorkflow,
4
+ TranscriptionWorkflow,
5
+ CompareWorkflow,
6
+ FreeTextWorkflow,
7
+ )
3
8
  from .selection import (
4
9
  DemographicSelection,
5
10
  LabelingSelection,
6
11
  ValidationSelection,
7
12
  ConditionalValidationSelection,
13
+ CappedSelection,
8
14
  )
9
15
  from .referee import NaiveReferee, ClassifyEarlyStoppingReferee
10
- from .metadata import PrivateTextMetadata, PublicTextMetadata, PromptMetadata, TranscriptionMetadata
16
+ from .metadata import (
17
+ PrivateTextMetadata,
18
+ PublicTextMetadata,
19
+ PromptMetadata,
20
+ TranscriptionMetadata,
21
+ )
11
22
  from .feature_flags import FeatureFlags
12
23
  from .country_codes import CountryCodes
13
24
  from .assets import MediaAsset, TextAsset, MultiAsset
@@ -23,6 +23,8 @@ class MultiAsset(BaseAsset):
23
23
  Args:
24
24
  assets (List[BaseAsset]): A list of BaseAsset instances to be managed together.
25
25
  """
26
+ if len(assets) != 2:
27
+ raise ValueError("Assets must come in pairs for comparison tasks.")
26
28
  self.assets = assets
27
29
 
28
30
  def __len__(self) -> int:
@@ -9,6 +9,7 @@ from rapidata.api_client.models.upload_text_sources_to_dataset_model import (
9
9
  UploadTextSourcesToDatasetModel,
10
10
  )
11
11
  from rapidata.rapidata_client.metadata.base_metadata import Metadata
12
+ from rapidata.rapidata_client.assets import TextAsset, MediaAsset, MultiAsset
12
13
  from rapidata.service import LocalFileService
13
14
  from rapidata.service.openapi_service import OpenAPIService
14
15
  from concurrent.futures import ThreadPoolExecutor, as_completed
@@ -22,7 +23,8 @@ class RapidataDataset:
22
23
  self.openapi_service = openapi_service
23
24
  self.local_file_service = LocalFileService()
24
25
 
25
- def add_texts(self, texts: list[str]):
26
+ def add_texts(self, text_assets: list[TextAsset]):
27
+ texts = [text.text for text in text_assets]
26
28
  model = UploadTextSourcesToDatasetModel(
27
29
  datasetId=self.dataset_id, textSources=texts
28
30
  )
@@ -32,24 +34,26 @@ class RapidataDataset:
32
34
 
33
35
  def add_media_from_paths(
34
36
  self,
35
- media_paths: list[str | list[str]],
37
+ media_paths: list[MediaAsset | MultiAsset],
36
38
  metadata: list[Metadata] | None = None,
37
39
  max_workers: int = 10,
38
40
  ):
39
41
  if metadata is not None and len(metadata) != len(media_paths):
40
42
  raise ValueError(
41
- "metadata must be None or have the same length as image_paths"
43
+ "metadata must be None or have the same length as media_paths"
42
44
  )
43
45
 
44
- def upload_datapoint(media_paths_rapid: str | list[str], meta: Metadata | None) -> None:
45
- if isinstance(media_paths_rapid, list) and not all(
46
- os.path.exists(media_path) for media_path in media_paths_rapid
47
- ):
48
- raise FileNotFoundError(f"File not found: {media_paths_rapid}")
49
- elif isinstance(media_paths_rapid, str) and not os.path.exists(
50
- media_paths_rapid
51
- ):
52
- raise FileNotFoundError(f"File not found: {media_paths_rapid}")
46
+ def upload_datapoint(media_asset: MediaAsset | MultiAsset, meta: Metadata | None) -> None:
47
+ if isinstance(media_asset, MediaAsset):
48
+ paths = [media_asset.path]
49
+ elif isinstance(media_asset, MultiAsset):
50
+ paths = [asset.path for asset in media_asset.assets if isinstance(asset, MediaAsset)]
51
+ else:
52
+ raise ValueError(f"Unsupported asset type: {type(media_asset)}")
53
+
54
+ assert all(
55
+ os.path.exists(media_path) for media_path in paths
56
+ ), "All media paths must exist on the local filesystem."
53
57
 
54
58
  meta_model = meta.to_model() if meta else None
55
59
  model = DatapointMetadataModel(
@@ -63,14 +67,14 @@ class RapidataDataset:
63
67
 
64
68
  self.openapi_service.dataset_api.dataset_create_datapoint_post(
65
69
  model=model,
66
- files=media_paths_rapid if isinstance(media_paths_rapid, list) else [media_paths_rapid] # type: ignore
70
+ files=paths # type: ignore
67
71
  )
68
72
 
69
73
  total_uploads = len(media_paths)
70
74
  with ThreadPoolExecutor(max_workers=max_workers) as executor:
71
75
  futures = [
72
- executor.submit(upload_datapoint, media_paths, meta)
73
- for media_paths, meta in zip_longest(media_paths, metadata or [])
76
+ executor.submit(upload_datapoint, media_asset, meta)
77
+ for media_asset, meta in zip_longest(media_paths, metadata or [])
74
78
  ]
75
79
 
76
80
  with tqdm(total=total_uploads, desc="Uploading datapoints") as pbar:
@@ -160,8 +160,8 @@ class ValidationSetBuilder:
160
160
  self,
161
161
  asset: MediaAsset | TextAsset,
162
162
  question: str,
163
- transcription: list[str],
164
- correct_words: list[str],
163
+ transcription: str,
164
+ truths: list[int],
165
165
  strict_grading: bool | None = None,
166
166
  metadata: list[Metadata] = [],
167
167
  ):
@@ -171,9 +171,9 @@ class ValidationSetBuilder:
171
171
  asset (MediaAsset | TextAsset): The asset for the rapid.
172
172
  question (str): The question for the rapid.
173
173
  transcription (list[str]): The transcription for the rapid.
174
- correct_words (list[str]): The list of correct words for the rapid.
175
- strict_grading (bool | None, optional): The strict grading flag for the rapid. Defaults to None.
176
- metadata (list[Metadata], optional): The metadata for the rapid. Defaults to an empty list.
174
+ truths (list[int]): The list of indices of the true word selections.
175
+ strict_grading (bool | None, optional): The strict grading for the rapid. Defaults to None.
176
+ metadata (list[Metadata], optional): The metadata for the rapid.
177
177
 
178
178
  Returns:
179
179
  ValidationSetBuilder: The ValidationSetBuilder instance.
@@ -183,16 +183,14 @@ class ValidationSetBuilder:
183
183
  """
184
184
  transcription_words = [
185
185
  TranscriptionWord(word=word, wordIndex=i)
186
- for i, word in enumerate(transcription)
186
+ for i, word in enumerate(transcription.split())
187
187
  ]
188
188
 
189
- correct_transcription_words = []
190
- for word in correct_words:
191
- if word not in transcription:
192
- raise ValueError(f"Correct word '{word}' not found in transcription")
193
- correct_transcription_words.append(
194
- TranscriptionWord(word=word, wordIndex=transcription.index(word))
195
- )
189
+ true_words = []
190
+ for idx in truths:
191
+ if idx > len(transcription_words) - 1:
192
+ raise ValueError(f"Index {idx} is out of bounds")
193
+ true_words.append(transcription_words[idx])
196
194
 
197
195
  payload = TranscriptionPayload(
198
196
  _t="TranscriptionPayload", title=question, transcription=transcription_words
@@ -200,7 +198,7 @@ class ValidationSetBuilder:
200
198
 
201
199
  model_truth = TranscriptionTruth(
202
200
  _t="TranscriptionTruth",
203
- correctWords=correct_transcription_words,
201
+ correctWords=true_words,
204
202
  strictGrading=strict_grading,
205
203
  )
206
204
 
@@ -211,7 +209,7 @@ class ValidationSetBuilder:
211
209
  payload=payload,
212
210
  truths=model_truth,
213
211
  metadata=metadata,
214
- randomCorrectProbability=1 / len(transcription),
212
+ randomCorrectProbability = 1 / len(transcription_words),
215
213
  )
216
214
  )
217
215
 
@@ -3,7 +3,9 @@ from rapidata.rapidata_client.dataset.rapidata_dataset import RapidataDataset
3
3
  from rapidata.service.openapi_service import OpenAPIService
4
4
  import json
5
5
  from rapidata.api_client.exceptions import ApiException
6
-
6
+ from typing import cast
7
+ from rapidata.api_client.models.workflow_artifact_model import WorkflowArtifactModel
8
+ from tqdm import tqdm
7
9
 
8
10
  class RapidataOrder:
9
11
  """
@@ -26,6 +28,7 @@ class RapidataOrder:
26
28
  self.openapi_service = openapi_service
27
29
  self.order_id = order_id
28
30
  self._dataset = dataset
31
+ self._workflow_id = None
29
32
 
30
33
  def submit(self):
31
34
  """
@@ -49,27 +52,55 @@ class RapidataOrder:
49
52
  """
50
53
  return self.openapi_service.order_api.order_get_by_id_get(self.order_id)
51
54
 
52
- def wait_for_done(self):
55
+ def display_progress_bar(self, refresh_rate=5):
53
56
  """
54
- Blocking call that waits for the order to be done. Exponential backoff is used to check the status of the order.
57
+ Displays a progress bar for the order processing using tqdm.
58
+
59
+ :param refresh_rate: How often to refresh the progress bar, in seconds.
60
+ :type refresh_rate: float
55
61
  """
56
- wait_time = 1
57
- back_off_factor = 1.1
58
- minimum_poll_interval = 60 # 1 minute
59
-
60
- while True:
61
- time.sleep(wait_time)
62
- result = self.get_status()
63
- if result.state == "ManualReview":
64
- print(
65
- "Order is in manual review. Please contact support for approval. Will continue polling."
66
- )
62
+ total_rapids = self._get_total_rapids()
63
+ with tqdm(total=total_rapids, desc="Processing order", unit="rapids") as pbar:
64
+ completed_rapids = 0
65
+ while True:
66
+ current_completed = self._get_completed_rapids()
67
+ if current_completed > completed_rapids:
68
+ pbar.update(current_completed - completed_rapids)
69
+ completed_rapids = current_completed
70
+
71
+ if completed_rapids >= total_rapids:
72
+ break
73
+
74
+ time.sleep(refresh_rate)
67
75
 
68
- if result.state == "Completed" or result.state == "Failed":
76
+ def _get_workflow_id(self):
77
+ if self._workflow_id:
78
+ return self._workflow_id
79
+
80
+ for _ in range(2):
81
+ try:
82
+ order_result = self.openapi_service.order_api.order_get_by_id_get(self.order_id)
83
+ pipeline = self.openapi_service.pipeline_api.pipeline_id_get(order_result.pipeline_id)
84
+ self._workflow_id = cast(WorkflowArtifactModel, pipeline.artifacts["workflow-artifact"].actual_instance).workflow_id
69
85
  break
70
- wait_time = max(
71
- minimum_poll_interval, wait_time * back_off_factor
72
- ) # poll at least every 10 minutes
86
+ except Exception:
87
+ time.sleep(2)
88
+ if not self._workflow_id:
89
+ raise Exception("Order has not started yet. Please wait for a few seconds and try again.")
90
+ return self._workflow_id
91
+
92
+ def _get_total_rapids(self):
93
+ workflow_id = self._get_workflow_id()
94
+ return self.openapi_service.workflow_api.workflow_get_progress_get(workflow_id).total
95
+
96
+ def _get_completed_rapids(self):
97
+ workflow_id = self._get_workflow_id()
98
+ return self.openapi_service.workflow_api.workflow_get_progress_get(workflow_id).completed
99
+
100
+ def get_progress_percentage(self):
101
+ workflow_id = self._get_workflow_id()
102
+ progress = self.openapi_service.workflow_api.workflow_get_progress_get(workflow_id)
103
+ return progress.completion_percentage
73
104
 
74
105
  def get_results(self):
75
106
  """
@@ -13,6 +13,9 @@ from rapidata.api_client.models.create_order_model_workflow import (
13
13
  CreateOrderModelWorkflow,
14
14
  )
15
15
  from rapidata.api_client.models.country_user_filter_model import CountryUserFilterModel
16
+ from rapidata.api_client.models.language_user_filter_model import (
17
+ LanguageUserFilterModel,
18
+ )
16
19
  from rapidata.rapidata_client.feature_flags import FeatureFlags
17
20
  from rapidata.rapidata_client.metadata.base_metadata import Metadata
18
21
  from rapidata.rapidata_client.dataset.rapidata_dataset import RapidataDataset
@@ -25,6 +28,10 @@ from rapidata.service.openapi_service import OpenAPIService
25
28
 
26
29
  from rapidata.rapidata_client.workflow.compare_workflow import CompareWorkflow
27
30
 
31
+ from rapidata.rapidata_client.assets import MediaAsset, TextAsset, MultiAsset
32
+
33
+ from typing import cast, Sequence
34
+
28
35
 
29
36
  class RapidataOrderBuilder:
30
37
  """Builder object for creating Rapidata orders.
@@ -52,17 +59,16 @@ class RapidataOrderBuilder:
52
59
  self._openapi_service = openapi_service
53
60
  self._workflow: Workflow | None = None
54
61
  self._referee: Referee | None = None
55
- self._media_paths: list[str | list[str]] = []
56
62
  self._metadata: list[Metadata] | None = None
57
63
  self._aggregator: AggregatorType | None = None
58
64
  self._validation_set_id: str | None = None
59
65
  self._feature_flags: FeatureFlags | None = None
60
66
  self._country_codes: list[str] | None = None
67
+ self._language_codes: list[str] | None = None
61
68
  self._selections: list[Selection] = []
62
69
  self._rapids_per_bag: int = 2
63
70
  self._priority: int = 50
64
- self._texts: list[str] | None = None
65
- self._media_paths: list[str | list[str]] = []
71
+ self._assets: list[MediaAsset] | list[TextAsset] | list[MultiAsset] = []
66
72
 
67
73
  def _to_model(self) -> CreateOrderModel:
68
74
  """
@@ -80,22 +86,32 @@ class RapidataOrderBuilder:
80
86
  if self._referee is None:
81
87
  print("No referee provided, using default NaiveReferee.")
82
88
  self._referee = NaiveReferee()
83
- if self._country_codes is None:
84
- country_filter = None
85
- else:
86
- country_filter = CountryUserFilterModel(
87
- _t="CountryFilter", countries=self._country_codes
89
+
90
+ user_filters = []
91
+
92
+ if self._country_codes is not None:
93
+ user_filters.append(
94
+ CreateOrderModelUserFiltersInner(
95
+ CountryUserFilterModel(
96
+ _t="CountryFilter", countries=self._country_codes
97
+ )
98
+ )
99
+ )
100
+
101
+ if self._language_codes is not None:
102
+ user_filters.append(
103
+ CreateOrderModelUserFiltersInner(
104
+ LanguageUserFilterModel(
105
+ _t="LanguageFilter", languages=self._language_codes
106
+ )
107
+ )
88
108
  )
89
109
 
90
110
  return CreateOrderModel(
91
111
  _t="CreateOrderModel",
92
112
  orderName=self._name,
93
113
  workflow=CreateOrderModelWorkflow(self._workflow.to_model()),
94
- userFilters=(
95
- [CreateOrderModelUserFiltersInner(country_filter)]
96
- if country_filter
97
- else []
98
- ),
114
+ userFilters=user_filters,
99
115
  referee=CreateOrderModelReferee(self._referee.to_model()),
100
116
  validationSetId=self._validation_set_id,
101
117
  featureFlags=(
@@ -129,8 +145,12 @@ class RapidataOrderBuilder:
129
145
  if isinstance(
130
146
  self._workflow, CompareWorkflow
131
147
  ): # Temporary fix; will be handled by backend in the future
148
+ assert all(isinstance(item, MultiAsset) for item in self._assets), (
149
+ "The media paths must be of type MultiAsset for comparison tasks."
150
+ )
151
+ media_paths = cast(list[MultiAsset], self._assets)
132
152
  assert all(
133
- [len(path) == 2 for path in self._media_paths]
153
+ [len(path) == 2 for path in media_paths]
134
154
  ), "The media paths must come in pairs for comparison tasks."
135
155
 
136
156
  result = self._openapi_service.order_api.order_create_post(
@@ -145,22 +165,18 @@ class RapidataOrderBuilder:
145
165
  openapi_service=self._openapi_service,
146
166
  )
147
167
 
148
- if self._media_paths and self._texts:
168
+ if not self._assets:
149
169
  raise ValueError(
150
- "You cannot provide both media paths and texts to the same order."
170
+ "You must provide assets to start the order."
151
171
  )
172
+ if all(isinstance(item, TextAsset) for item in self._assets):
173
+ assets = cast(list[TextAsset], self._assets)
174
+ order.dataset.add_texts(assets)
152
175
 
153
- if not self._media_paths and not self._texts:
154
- raise ValueError(
155
- "You must provide either media paths or texts to the order."
156
- )
157
-
158
- if self._texts:
159
- order.dataset.add_texts(self._texts)
160
-
161
- if self._media_paths:
176
+ elif all(isinstance(item, (MediaAsset, MultiAsset)) for item in self._assets):
177
+ assets = cast(list[MediaAsset | MultiAsset], self._assets)
162
178
  order.dataset.add_media_from_paths(
163
- self._media_paths, self._metadata, max_workers
179
+ assets, self._metadata, max_workers
164
180
  )
165
181
 
166
182
  if submit:
@@ -196,60 +212,60 @@ class RapidataOrderBuilder:
196
212
 
197
213
  def media(
198
214
  self,
199
- media_paths: list[str | list[str]],
200
- metadata: list[Metadata] | None = None,
215
+ asset: list[MediaAsset] | list[TextAsset] | list[MultiAsset],
216
+ metadata: Sequence[Metadata] | None = None,
201
217
  ) -> "RapidataOrderBuilder":
202
218
  """
203
219
  Set the media assets for the order.
204
220
 
205
221
  Args:
206
- media_paths (list[str | list[str]]): The paths of the media assets to be set.
222
+ media_paths (list[MediaAsset] | list[TextAsset] | list[MultiAsset]): The paths of the media assets to be set.
207
223
  metadata (list[Metadata] | None, optional): Metadata for the media assets. Defaults to None.
208
224
 
209
225
  Returns:
210
226
  RapidataOrderBuilder: The updated RapidataOrderBuilder instance.
211
227
  """
212
- self._media_paths = media_paths
213
- self._metadata = metadata
228
+ self._assets = asset
229
+ self._metadata = metadata # type: ignore
214
230
  return self
215
231
 
216
- def texts(self, texts: list[str]) -> "RapidataOrderBuilder":
232
+ def feature_flags(self, feature_flags: FeatureFlags) -> "RapidataOrderBuilder":
217
233
  """
218
- Set the TextAssets for the order.
234
+ Set the feature flags for the order.
219
235
 
220
236
  Args:
221
- texts (list[str]): The texts to be set.
237
+ feature_flags (FeatureFlags): The feature flags to be set.
222
238
 
223
239
  Returns:
224
240
  RapidataOrderBuilder: The updated RapidataOrderBuilder instance.
225
241
  """
226
- self._texts = texts
242
+ self._feature_flags = feature_flags
227
243
  return self
228
244
 
229
- def feature_flags(self, feature_flags: FeatureFlags) -> "RapidataOrderBuilder":
245
+ def country_filter(self, country_codes: list[str]) -> "RapidataOrderBuilder":
230
246
  """
231
- Set the feature flags for the order.
247
+ Set the target country codes for the order. E.g. `country_codes=["DE", "CH", "AT"]` for Germany, Switzerland, and Austria.
232
248
 
233
249
  Args:
234
- feature_flags (FeatureFlags): The feature flags to be set.
250
+ country_codes (list[str]): The country codes to be set.
235
251
 
236
252
  Returns:
237
253
  RapidataOrderBuilder: The updated RapidataOrderBuilder instance.
238
254
  """
239
- self._feature_flags = feature_flags
255
+ self._country_codes = country_codes
240
256
  return self
241
257
 
242
- def country_filter(self, country_codes: list[str]) -> "RapidataOrderBuilder":
258
+ def language_filter(self, language_codes: list[str]) -> "RapidataOrderBuilder":
243
259
  """
244
- Set the target country codes for the order.
260
+ Set the target language codes for the order. E.g. `language_codes=["de", "fr", "it"]` for German, French, and Italian.
245
261
 
246
262
  Args:
247
- country_codes (list[str]): The country codes to be set.
263
+ language_codes (list[str]): The language codes to be set.
248
264
 
249
265
  Returns:
250
266
  RapidataOrderBuilder: The updated RapidataOrderBuilder instance.
251
267
  """
252
- self._country_codes = country_codes
268
+ self._language_codes = language_codes
253
269
  return self
254
270
 
255
271
  def aggregator(self, aggregator: AggregatorType) -> "RapidataOrderBuilder":
@@ -3,3 +3,4 @@ from .demographic_selection import DemographicSelection
3
3
  from .labeling_selection import LabelingSelection
4
4
  from .validation_selection import ValidationSelection
5
5
  from .conditional_validation_selection import ConditionalValidationSelection
6
+ from .capped_selection import CappedSelection