rapidata 1.0.0__py3-none-any.whl → 1.2.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- rapidata/__init__.py +1 -0
- rapidata/api_client/__init__.py +5 -3
- rapidata/api_client/api/__init__.py +2 -0
- rapidata/api_client/api/campaign_api.py +8 -4
- rapidata/api_client/api/coco_api.py +4 -2
- rapidata/api_client/api/compare_workflow_api.py +2 -1
- rapidata/api_client/api/datapoint_api.py +6 -3
- rapidata/api_client/api/dataset_api.py +16 -8
- rapidata/api_client/api/identity_api.py +329 -50
- rapidata/api_client/api/newsletter_api.py +4 -2
- rapidata/api_client/api/order_api.py +40 -20
- rapidata/api_client/api/pipeline_api.py +6 -3
- rapidata/api_client/api/rapid_api.py +10 -5
- rapidata/api_client/api/rapidata_identity_api_api.py +272 -0
- rapidata/api_client/api/simple_workflow_api.py +2 -1
- rapidata/api_client/api/user_info_api.py +272 -0
- rapidata/api_client/api/validation_api.py +14 -7
- rapidata/api_client/api/workflow_api.py +18 -9
- rapidata/api_client/models/__init__.py +3 -3
- rapidata/api_client/models/issue_auth_token_result.py +1 -1
- rapidata/api_client/models/legacy_issue_client_auth_token_result.py +87 -0
- rapidata/api_client/models/legacy_request_password_reset_command.py +98 -0
- rapidata/api_client/models/legacy_submit_password_reset_command.py +102 -0
- rapidata/api_client_README.md +10 -3
- rapidata/rapidata_client/__init__.py +13 -2
- rapidata/rapidata_client/assets/multi_asset.py +2 -0
- rapidata/rapidata_client/dataset/rapidata_dataset.py +19 -15
- rapidata/rapidata_client/dataset/validation_set_builder.py +13 -15
- rapidata/rapidata_client/order/rapidata_order.py +49 -18
- rapidata/rapidata_client/order/rapidata_order_builder.py +59 -43
- rapidata/rapidata_client/selection/__init__.py +1 -0
- rapidata/rapidata_client/selection/capped_selection.py +25 -0
- rapidata/rapidata_client/simple_builders/__init__.py +0 -0
- rapidata/rapidata_client/simple_builders/simple_classification_builders.py +14 -9
- rapidata/rapidata_client/simple_builders/simple_compare_builders.py +6 -3
- rapidata/service/openapi_service.py +15 -0
- {rapidata-1.0.0.dist-info → rapidata-1.2.0.dist-info}/METADATA +2 -1
- {rapidata-1.0.0.dist-info → rapidata-1.2.0.dist-info}/RECORD +40 -33
- {rapidata-1.0.0.dist-info → rapidata-1.2.0.dist-info}/WHEEL +1 -1
- {rapidata-1.0.0.dist-info → rapidata-1.2.0.dist-info}/LICENSE +0 -0
|
@@ -0,0 +1,102 @@
|
|
|
1
|
+
# coding: utf-8
|
|
2
|
+
|
|
3
|
+
"""
|
|
4
|
+
Rapidata.Dataset
|
|
5
|
+
|
|
6
|
+
No description provided (generated by Openapi Generator https://github.com/openapitools/openapi-generator)
|
|
7
|
+
|
|
8
|
+
The version of the OpenAPI document: v1
|
|
9
|
+
Generated by OpenAPI Generator (https://openapi-generator.tech)
|
|
10
|
+
|
|
11
|
+
Do not edit the class manually.
|
|
12
|
+
""" # noqa: E501
|
|
13
|
+
|
|
14
|
+
|
|
15
|
+
from __future__ import annotations
|
|
16
|
+
import pprint
|
|
17
|
+
import re # noqa: F401
|
|
18
|
+
import json
|
|
19
|
+
|
|
20
|
+
from pydantic import BaseModel, ConfigDict, Field, StrictStr, field_validator
|
|
21
|
+
from typing import Any, ClassVar, Dict, List
|
|
22
|
+
from typing import Optional, Set
|
|
23
|
+
from typing_extensions import Self
|
|
24
|
+
|
|
25
|
+
class LegacySubmitPasswordResetCommand(BaseModel):
    """
    Generated pydantic model for the LegacySubmitPasswordResetCommand payload.

    Fields are exposed under snake_case attribute names and serialised under
    their JSON aliases (`_t`, `userId`, `password`, `passwordRepeated`,
    `resetToken`).
    """ # noqa: E501
    t: StrictStr = Field(description="Discriminator value for LegacySubmitPasswordResetCommand", alias="_t")
    user_id: StrictStr = Field(alias="userId")
    password: StrictStr
    password_repeated: StrictStr = Field(alias="passwordRepeated")
    reset_token: StrictStr = Field(alias="resetToken")
    # JSON (alias) names of every property on this model.
    __properties: ClassVar[List[str]] = ["_t", "userId", "password", "passwordRepeated", "resetToken"]

    @field_validator('t')
    def t_validate_enum(cls, value):
        """Accept only the single discriminator value permitted for this model."""
        allowed = {'LegacySubmitPasswordResetCommand'}
        if value in allowed:
            return value
        raise ValueError("must be one of enum values ('LegacySubmitPasswordResetCommand')")

    model_config = ConfigDict(
        populate_by_name=True,
        validate_assignment=True,
        protected_namespaces=(),
    )

    def to_str(self) -> str:
        """Return a pretty-printed string of the model, keyed by alias."""
        aliased = self.model_dump(by_alias=True)
        return pprint.pformat(aliased)

    def to_json(self) -> str:
        """Return the model serialised as a JSON string, keyed by alias."""
        # TODO: pydantic v2: use .model_dump_json(by_alias=True, exclude_unset=True) instead
        as_dict = self.to_dict()
        return json.dumps(as_dict)

    @classmethod
    def from_json(cls, json_str: str) -> Optional[Self]:
        """Build a LegacySubmitPasswordResetCommand from a JSON string."""
        payload = json.loads(json_str)
        return cls.from_dict(payload)

    def to_dict(self) -> Dict[str, Any]:
        """Return the model as a dictionary keyed by alias.

        Unlike a plain `self.model_dump(by_alias=True)`, fields whose value
        is `None` are left out of the result (`exclude_none=True`).
        """
        # No fields are excluded for this model; the empty set is kept to
        # mirror the generator's template.
        excluded_fields: Set[str] = set()
        return self.model_dump(
            by_alias=True,
            exclude=excluded_fields,
            exclude_none=True,
        )

    @classmethod
    def from_dict(cls, obj: Optional[Dict[str, Any]]) -> Optional[Self]:
        """Build a LegacySubmitPasswordResetCommand from a dict.

        `None` is passed through unchanged; any non-dict input is handed
        directly to `model_validate`.
        """
        if obj is None:
            return None
        if not isinstance(obj, dict):
            return cls.model_validate(obj)

        # Fall back to the model's own name when the discriminator is absent.
        discriminator = obj.get("_t")
        if discriminator is None:
            discriminator = 'LegacySubmitPasswordResetCommand'

        return cls.model_validate({
            "_t": discriminator,
            "userId": obj.get("userId"),
            "password": obj.get("password"),
            "passwordRepeated": obj.get("passwordRepeated"),
            "resetToken": obj.get("resetToken")
        })
|
|
101
|
+
|
|
102
|
+
|
rapidata/api_client_README.md
CHANGED
|
@@ -97,6 +97,7 @@ Class | Method | HTTP request | Description
|
|
|
97
97
|
*IdentityApi* | [**identity_get_client_auth_token_post**](rapidata/api_client/docs/IdentityApi.md#identity_get_client_auth_token_post) | **POST** /Identity/GetClientAuthToken | Issues a new auth token using the client credentials.
|
|
98
98
|
*IdentityApi* | [**identity_index_post**](rapidata/api_client/docs/IdentityApi.md#identity_index_post) | **POST** /Identity/Index | Logs in a user by username or email and password.
|
|
99
99
|
*IdentityApi* | [**identity_logout_post**](rapidata/api_client/docs/IdentityApi.md#identity_logout_post) | **POST** /Identity/Logout | Logs out the current user by deleting the refresh token cookie.
|
|
100
|
+
*IdentityApi* | [**identity_register_temporary_post**](rapidata/api_client/docs/IdentityApi.md#identity_register_temporary_post) | **POST** /Identity/RegisterTemporary | Registers and logs in a temporary customer.
|
|
100
101
|
*IdentityApi* | [**identity_request_reset_post**](rapidata/api_client/docs/IdentityApi.md#identity_request_reset_post) | **POST** /Identity/RequestReset | Request a password reset for a user.
|
|
101
102
|
*IdentityApi* | [**identity_signup_post**](rapidata/api_client/docs/IdentityApi.md#identity_signup_post) | **POST** /Identity/Signup | Signs up a new user.
|
|
102
103
|
*IdentityApi* | [**identity_submit_reset_post**](rapidata/api_client/docs/IdentityApi.md#identity_submit_reset_post) | **POST** /Identity/SubmitReset | Updates the password of a user after a password reset request.
|
|
@@ -131,7 +132,9 @@ Class | Method | HTTP request | Description
|
|
|
131
132
|
*RapidApi* | [**rapid_query_validation_rapids_get**](rapidata/api_client/docs/RapidApi.md#rapid_query_validation_rapids_get) | **GET** /Rapid/QueryValidationRapids | Queries the validation rapids for a specific validation set.
|
|
132
133
|
*RapidApi* | [**rapid_skip_user_guess_post**](rapidata/api_client/docs/RapidApi.md#rapid_skip_user_guess_post) | **POST** /Rapid/SkipUserGuess | Skips a Rapid for the user.
|
|
133
134
|
*RapidApi* | [**rapid_validate_current_rapid_bag_get**](rapidata/api_client/docs/RapidApi.md#rapid_validate_current_rapid_bag_get) | **GET** /Rapid/ValidateCurrentRapidBag | Validates that the rapids associated with the current user are active.
|
|
135
|
+
*RapidataIdentityAPIApi* | [**root_get**](rapidata/api_client/docs/RapidataIdentityAPIApi.md#root_get) | **GET** / |
|
|
134
136
|
*SimpleWorkflowApi* | [**simple_workflow_get_result_overview_get**](rapidata/api_client/docs/SimpleWorkflowApi.md#simple_workflow_get_result_overview_get) | **GET** /SimpleWorkflow/GetResultOverview | Get the result overview for a simple workflow.
|
|
137
|
+
*UserInfoApi* | [**connect_userinfo_get**](rapidata/api_client/docs/UserInfoApi.md#connect_userinfo_get) | **GET** /connect/userinfo | Retrieves information about the authenticated user.
|
|
135
138
|
*ValidationApi* | [**validation_add_validation_rapid_post**](rapidata/api_client/docs/ValidationApi.md#validation_add_validation_rapid_post) | **POST** /Validation/AddValidationRapid | Adds a new validation rapid to the specified validation set.
|
|
136
139
|
*ValidationApi* | [**validation_add_validation_text_rapid_post**](rapidata/api_client/docs/ValidationApi.md#validation_add_validation_text_rapid_post) | **POST** /Validation/AddValidationTextRapid | Adds a new validation rapid to the specified validation set.
|
|
137
140
|
*ValidationApi* | [**validation_create_validation_set_post**](rapidata/api_client/docs/ValidationApi.md#validation_create_validation_set_post) | **POST** /Validation/CreateValidationSet | Creates a new empty validation set.
|
|
@@ -273,9 +276,11 @@ Class | Method | HTTP request | Description
|
|
|
273
276
|
- [ImportValidationSetFromFileResult](rapidata/api_client/docs/ImportValidationSetFromFileResult.md)
|
|
274
277
|
- [InProgressRapidModel](rapidata/api_client/docs/InProgressRapidModel.md)
|
|
275
278
|
- [IssueAuthTokenResult](rapidata/api_client/docs/IssueAuthTokenResult.md)
|
|
276
|
-
- [IssueClientAuthTokenResult](rapidata/api_client/docs/IssueClientAuthTokenResult.md)
|
|
277
279
|
- [LabelingSelection](rapidata/api_client/docs/LabelingSelection.md)
|
|
278
280
|
- [LanguageUserFilterModel](rapidata/api_client/docs/LanguageUserFilterModel.md)
|
|
281
|
+
- [LegacyIssueClientAuthTokenResult](rapidata/api_client/docs/LegacyIssueClientAuthTokenResult.md)
|
|
282
|
+
- [LegacyRequestPasswordResetCommand](rapidata/api_client/docs/LegacyRequestPasswordResetCommand.md)
|
|
283
|
+
- [LegacySubmitPasswordResetCommand](rapidata/api_client/docs/LegacySubmitPasswordResetCommand.md)
|
|
279
284
|
- [Line](rapidata/api_client/docs/Line.md)
|
|
280
285
|
- [LinePayload](rapidata/api_client/docs/LinePayload.md)
|
|
281
286
|
- [LinePoint](rapidata/api_client/docs/LinePoint.md)
|
|
@@ -334,7 +339,6 @@ Class | Method | HTTP request | Description
|
|
|
334
339
|
- [RapidResultModel](rapidata/api_client/docs/RapidResultModel.md)
|
|
335
340
|
- [RapidResultModelResult](rapidata/api_client/docs/RapidResultModelResult.md)
|
|
336
341
|
- [RapidSkippedModel](rapidata/api_client/docs/RapidSkippedModel.md)
|
|
337
|
-
- [RequestPasswordResetCommand](rapidata/api_client/docs/RequestPasswordResetCommand.md)
|
|
338
342
|
- [RootFilter](rapidata/api_client/docs/RootFilter.md)
|
|
339
343
|
- [SendCompletionMailStepModel](rapidata/api_client/docs/SendCompletionMailStepModel.md)
|
|
340
344
|
- [Shape](rapidata/api_client/docs/Shape.md)
|
|
@@ -354,7 +358,6 @@ Class | Method | HTTP request | Description
|
|
|
354
358
|
- [StaticSelection](rapidata/api_client/docs/StaticSelection.md)
|
|
355
359
|
- [SubmitCocoModel](rapidata/api_client/docs/SubmitCocoModel.md)
|
|
356
360
|
- [SubmitCocoResult](rapidata/api_client/docs/SubmitCocoResult.md)
|
|
357
|
-
- [SubmitPasswordResetCommand](rapidata/api_client/docs/SubmitPasswordResetCommand.md)
|
|
358
361
|
- [TextAsset](rapidata/api_client/docs/TextAsset.md)
|
|
359
362
|
- [TextAssetModel](rapidata/api_client/docs/TextAssetModel.md)
|
|
360
363
|
- [TextMetadata](rapidata/api_client/docs/TextMetadata.md)
|
|
@@ -403,6 +406,10 @@ Authentication schemes defined for the API:
|
|
|
403
406
|
- **API key parameter name**: Authorization
|
|
404
407
|
- **Location**: HTTP header
|
|
405
408
|
|
|
409
|
+
<a id="oauth2"></a>
|
|
410
|
+
### oauth2
|
|
411
|
+
|
|
412
|
+
|
|
406
413
|
|
|
407
414
|
## Author
|
|
408
415
|
|
|
@@ -1,13 +1,24 @@
|
|
|
1
1
|
from .rapidata_client import RapidataClient
|
|
2
|
-
from .workflow import
|
|
2
|
+
from .workflow import (
|
|
3
|
+
ClassifyWorkflow,
|
|
4
|
+
TranscriptionWorkflow,
|
|
5
|
+
CompareWorkflow,
|
|
6
|
+
FreeTextWorkflow,
|
|
7
|
+
)
|
|
3
8
|
from .selection import (
|
|
4
9
|
DemographicSelection,
|
|
5
10
|
LabelingSelection,
|
|
6
11
|
ValidationSelection,
|
|
7
12
|
ConditionalValidationSelection,
|
|
13
|
+
CappedSelection,
|
|
8
14
|
)
|
|
9
15
|
from .referee import NaiveReferee, ClassifyEarlyStoppingReferee
|
|
10
|
-
from .metadata import
|
|
16
|
+
from .metadata import (
|
|
17
|
+
PrivateTextMetadata,
|
|
18
|
+
PublicTextMetadata,
|
|
19
|
+
PromptMetadata,
|
|
20
|
+
TranscriptionMetadata,
|
|
21
|
+
)
|
|
11
22
|
from .feature_flags import FeatureFlags
|
|
12
23
|
from .country_codes import CountryCodes
|
|
13
24
|
from .assets import MediaAsset, TextAsset, MultiAsset
|
|
@@ -23,6 +23,8 @@ class MultiAsset(BaseAsset):
|
|
|
23
23
|
Args:
|
|
24
24
|
assets (List[BaseAsset]): A list of BaseAsset instances to be managed together.
|
|
25
25
|
"""
|
|
26
|
+
if len(assets) != 2:
|
|
27
|
+
raise ValueError("Assets must come in pairs for comparison tasks.")
|
|
26
28
|
self.assets = assets
|
|
27
29
|
|
|
28
30
|
def __len__(self) -> int:
|
|
@@ -9,6 +9,7 @@ from rapidata.api_client.models.upload_text_sources_to_dataset_model import (
|
|
|
9
9
|
UploadTextSourcesToDatasetModel,
|
|
10
10
|
)
|
|
11
11
|
from rapidata.rapidata_client.metadata.base_metadata import Metadata
|
|
12
|
+
from rapidata.rapidata_client.assets import TextAsset, MediaAsset, MultiAsset
|
|
12
13
|
from rapidata.service import LocalFileService
|
|
13
14
|
from rapidata.service.openapi_service import OpenAPIService
|
|
14
15
|
from concurrent.futures import ThreadPoolExecutor, as_completed
|
|
@@ -22,7 +23,8 @@ class RapidataDataset:
|
|
|
22
23
|
self.openapi_service = openapi_service
|
|
23
24
|
self.local_file_service = LocalFileService()
|
|
24
25
|
|
|
25
|
-
def add_texts(self,
|
|
26
|
+
def add_texts(self, text_assets: list[TextAsset]):
|
|
27
|
+
texts = [text.text for text in text_assets]
|
|
26
28
|
model = UploadTextSourcesToDatasetModel(
|
|
27
29
|
datasetId=self.dataset_id, textSources=texts
|
|
28
30
|
)
|
|
@@ -32,24 +34,26 @@ class RapidataDataset:
|
|
|
32
34
|
|
|
33
35
|
def add_media_from_paths(
|
|
34
36
|
self,
|
|
35
|
-
media_paths: list[
|
|
37
|
+
media_paths: list[MediaAsset | MultiAsset],
|
|
36
38
|
metadata: list[Metadata] | None = None,
|
|
37
39
|
max_workers: int = 10,
|
|
38
40
|
):
|
|
39
41
|
if metadata is not None and len(metadata) != len(media_paths):
|
|
40
42
|
raise ValueError(
|
|
41
|
-
"metadata must be None or have the same length as
|
|
43
|
+
"metadata must be None or have the same length as media_paths"
|
|
42
44
|
)
|
|
43
45
|
|
|
44
|
-
def upload_datapoint(
|
|
45
|
-
if isinstance(
|
|
46
|
-
|
|
47
|
-
):
|
|
48
|
-
|
|
49
|
-
|
|
50
|
-
|
|
51
|
-
|
|
52
|
-
|
|
46
|
+
def upload_datapoint(media_asset: MediaAsset | MultiAsset, meta: Metadata | None) -> None:
|
|
47
|
+
if isinstance(media_asset, MediaAsset):
|
|
48
|
+
paths = [media_asset.path]
|
|
49
|
+
elif isinstance(media_asset, MultiAsset):
|
|
50
|
+
paths = [asset.path for asset in media_asset.assets if isinstance(asset, MediaAsset)]
|
|
51
|
+
else:
|
|
52
|
+
raise ValueError(f"Unsupported asset type: {type(media_asset)}")
|
|
53
|
+
|
|
54
|
+
assert all(
|
|
55
|
+
os.path.exists(media_path) for media_path in paths
|
|
56
|
+
), "All media paths must exist on the local filesystem."
|
|
53
57
|
|
|
54
58
|
meta_model = meta.to_model() if meta else None
|
|
55
59
|
model = DatapointMetadataModel(
|
|
@@ -63,14 +67,14 @@ class RapidataDataset:
|
|
|
63
67
|
|
|
64
68
|
self.openapi_service.dataset_api.dataset_create_datapoint_post(
|
|
65
69
|
model=model,
|
|
66
|
-
files=
|
|
70
|
+
files=paths # type: ignore
|
|
67
71
|
)
|
|
68
72
|
|
|
69
73
|
total_uploads = len(media_paths)
|
|
70
74
|
with ThreadPoolExecutor(max_workers=max_workers) as executor:
|
|
71
75
|
futures = [
|
|
72
|
-
executor.submit(upload_datapoint,
|
|
73
|
-
for
|
|
76
|
+
executor.submit(upload_datapoint, media_asset, meta)
|
|
77
|
+
for media_asset, meta in zip_longest(media_paths, metadata or [])
|
|
74
78
|
]
|
|
75
79
|
|
|
76
80
|
with tqdm(total=total_uploads, desc="Uploading datapoints") as pbar:
|
|
@@ -160,8 +160,8 @@ class ValidationSetBuilder:
|
|
|
160
160
|
self,
|
|
161
161
|
asset: MediaAsset | TextAsset,
|
|
162
162
|
question: str,
|
|
163
|
-
transcription:
|
|
164
|
-
|
|
163
|
+
transcription: str,
|
|
164
|
+
truths: list[int],
|
|
165
165
|
strict_grading: bool | None = None,
|
|
166
166
|
metadata: list[Metadata] = [],
|
|
167
167
|
):
|
|
@@ -171,9 +171,9 @@ class ValidationSetBuilder:
|
|
|
171
171
|
asset (MediaAsset | TextAsset): The asset for the rapid.
|
|
172
172
|
question (str): The question for the rapid.
|
|
173
173
|
transcription (list[str]): The transcription for the rapid.
|
|
174
|
-
|
|
175
|
-
strict_grading (bool | None, optional): The strict grading
|
|
176
|
-
metadata (list[Metadata], optional): The metadata for the rapid.
|
|
174
|
+
truths (list[int]): The list of indices of the true word selections.
|
|
175
|
+
strict_grading (bool | None, optional): The strict grading for the rapid. Defaults to None.
|
|
176
|
+
metadata (list[Metadata], optional): The metadata for the rapid.
|
|
177
177
|
|
|
178
178
|
Returns:
|
|
179
179
|
ValidationSetBuilder: The ValidationSetBuilder instance.
|
|
@@ -183,16 +183,14 @@ class ValidationSetBuilder:
|
|
|
183
183
|
"""
|
|
184
184
|
transcription_words = [
|
|
185
185
|
TranscriptionWord(word=word, wordIndex=i)
|
|
186
|
-
for i, word in enumerate(transcription)
|
|
186
|
+
for i, word in enumerate(transcription.split())
|
|
187
187
|
]
|
|
188
188
|
|
|
189
|
-
|
|
190
|
-
for
|
|
191
|
-
if
|
|
192
|
-
raise ValueError(f"
|
|
193
|
-
|
|
194
|
-
TranscriptionWord(word=word, wordIndex=transcription.index(word))
|
|
195
|
-
)
|
|
189
|
+
true_words = []
|
|
190
|
+
for idx in truths:
|
|
191
|
+
if idx > len(transcription_words) - 1:
|
|
192
|
+
raise ValueError(f"Index {idx} is out of bounds")
|
|
193
|
+
true_words.append(transcription_words[idx])
|
|
196
194
|
|
|
197
195
|
payload = TranscriptionPayload(
|
|
198
196
|
_t="TranscriptionPayload", title=question, transcription=transcription_words
|
|
@@ -200,7 +198,7 @@ class ValidationSetBuilder:
|
|
|
200
198
|
|
|
201
199
|
model_truth = TranscriptionTruth(
|
|
202
200
|
_t="TranscriptionTruth",
|
|
203
|
-
correctWords=
|
|
201
|
+
correctWords=true_words,
|
|
204
202
|
strictGrading=strict_grading,
|
|
205
203
|
)
|
|
206
204
|
|
|
@@ -211,7 +209,7 @@ class ValidationSetBuilder:
|
|
|
211
209
|
payload=payload,
|
|
212
210
|
truths=model_truth,
|
|
213
211
|
metadata=metadata,
|
|
214
|
-
randomCorrectProbability=1 / len(
|
|
212
|
+
randomCorrectProbability = 1 / len(transcription_words),
|
|
215
213
|
)
|
|
216
214
|
)
|
|
217
215
|
|
|
@@ -3,7 +3,9 @@ from rapidata.rapidata_client.dataset.rapidata_dataset import RapidataDataset
|
|
|
3
3
|
from rapidata.service.openapi_service import OpenAPIService
|
|
4
4
|
import json
|
|
5
5
|
from rapidata.api_client.exceptions import ApiException
|
|
6
|
-
|
|
6
|
+
from typing import cast
|
|
7
|
+
from rapidata.api_client.models.workflow_artifact_model import WorkflowArtifactModel
|
|
8
|
+
from tqdm import tqdm
|
|
7
9
|
|
|
8
10
|
class RapidataOrder:
|
|
9
11
|
"""
|
|
@@ -26,6 +28,7 @@ class RapidataOrder:
|
|
|
26
28
|
self.openapi_service = openapi_service
|
|
27
29
|
self.order_id = order_id
|
|
28
30
|
self._dataset = dataset
|
|
31
|
+
self._workflow_id = None
|
|
29
32
|
|
|
30
33
|
def submit(self):
|
|
31
34
|
"""
|
|
@@ -49,27 +52,55 @@ class RapidataOrder:
|
|
|
49
52
|
"""
|
|
50
53
|
return self.openapi_service.order_api.order_get_by_id_get(self.order_id)
|
|
51
54
|
|
|
52
|
-
def
|
|
55
|
+
def display_progress_bar(self, refresh_rate=5):
    """
    Show a tqdm progress bar that tracks the order until all rapids are completed.

    The total is fetched once; the completed count is then polled in a loop
    and the bar is advanced by the difference since the previous poll.

    :param refresh_rate: Seconds to wait between two polls.
    :type refresh_rate: float
    """
    target = self._get_total_rapids()
    with tqdm(total=target, desc="Processing order", unit="rapids") as pbar:
        done = 0
        while True:
            latest = self._get_completed_rapids()
            advance = latest - done
            if advance > 0:
                pbar.update(advance)
                done = latest

            # Check before sleeping so a finished order returns immediately.
            if done >= target:
                break

            time.sleep(refresh_rate)
|
|
67
75
|
|
|
68
|
-
|
|
76
|
+
def _get_workflow_id(self):
    """
    Return the workflow id of this order, resolving and caching it on first use.

    The id is read from the "workflow-artifact" entry of the order's pipeline.
    Up to two lookups are attempted, with a 2 second pause between them.

    :return: The workflow id stored in ``self._workflow_id``.
    :raises Exception: If no workflow id could be resolved; the last lookup
        error (if any) is attached as the exception's cause.
    """
    if self._workflow_id:
        return self._workflow_id

    max_attempts = 2
    last_error = None
    for attempt in range(max_attempts):
        try:
            order_result = self.openapi_service.order_api.order_get_by_id_get(self.order_id)
            pipeline = self.openapi_service.pipeline_api.pipeline_id_get(order_result.pipeline_id)
            self._workflow_id = cast(
                WorkflowArtifactModel,
                pipeline.artifacts["workflow-artifact"].actual_instance,
            ).workflow_id
            break
        except Exception as error:
            # Keep the failure so it can be reported as the cause below
            # instead of being silently discarded.
            last_error = error
            # Only wait when another attempt follows; sleeping after the
            # final attempt would just delay the error.
            if attempt < max_attempts - 1:
                time.sleep(2)

    if not self._workflow_id:
        raise Exception(
            "Order has not started yet. Please wait for a few seconds and try again."
        ) from last_error
    return self._workflow_id
|
|
91
|
+
|
|
92
|
+
def _get_total_rapids(self):
    """Return the ``total`` value reported by the workflow progress endpoint."""
    current_workflow = self._get_workflow_id()
    progress = self.openapi_service.workflow_api.workflow_get_progress_get(current_workflow)
    return progress.total
|
|
95
|
+
|
|
96
|
+
def _get_completed_rapids(self):
    """Return the ``completed`` value reported by the workflow progress endpoint."""
    current_workflow = self._get_workflow_id()
    progress = self.openapi_service.workflow_api.workflow_get_progress_get(current_workflow)
    return progress.completed
|
|
99
|
+
|
|
100
|
+
def get_progress_percentage(self):
    """Return the ``completion_percentage`` reported by the workflow progress endpoint."""
    current_workflow = self._get_workflow_id()
    workflow_api = self.openapi_service.workflow_api
    return workflow_api.workflow_get_progress_get(current_workflow).completion_percentage
|
|
73
104
|
|
|
74
105
|
def get_results(self):
|
|
75
106
|
"""
|
|
@@ -13,6 +13,9 @@ from rapidata.api_client.models.create_order_model_workflow import (
|
|
|
13
13
|
CreateOrderModelWorkflow,
|
|
14
14
|
)
|
|
15
15
|
from rapidata.api_client.models.country_user_filter_model import CountryUserFilterModel
|
|
16
|
+
from rapidata.api_client.models.language_user_filter_model import (
|
|
17
|
+
LanguageUserFilterModel,
|
|
18
|
+
)
|
|
16
19
|
from rapidata.rapidata_client.feature_flags import FeatureFlags
|
|
17
20
|
from rapidata.rapidata_client.metadata.base_metadata import Metadata
|
|
18
21
|
from rapidata.rapidata_client.dataset.rapidata_dataset import RapidataDataset
|
|
@@ -25,6 +28,10 @@ from rapidata.service.openapi_service import OpenAPIService
|
|
|
25
28
|
|
|
26
29
|
from rapidata.rapidata_client.workflow.compare_workflow import CompareWorkflow
|
|
27
30
|
|
|
31
|
+
from rapidata.rapidata_client.assets import MediaAsset, TextAsset, MultiAsset
|
|
32
|
+
|
|
33
|
+
from typing import cast, Sequence
|
|
34
|
+
|
|
28
35
|
|
|
29
36
|
class RapidataOrderBuilder:
|
|
30
37
|
"""Builder object for creating Rapidata orders.
|
|
@@ -52,17 +59,16 @@ class RapidataOrderBuilder:
|
|
|
52
59
|
self._openapi_service = openapi_service
|
|
53
60
|
self._workflow: Workflow | None = None
|
|
54
61
|
self._referee: Referee | None = None
|
|
55
|
-
self._media_paths: list[str | list[str]] = []
|
|
56
62
|
self._metadata: list[Metadata] | None = None
|
|
57
63
|
self._aggregator: AggregatorType | None = None
|
|
58
64
|
self._validation_set_id: str | None = None
|
|
59
65
|
self._feature_flags: FeatureFlags | None = None
|
|
60
66
|
self._country_codes: list[str] | None = None
|
|
67
|
+
self._language_codes: list[str] | None = None
|
|
61
68
|
self._selections: list[Selection] = []
|
|
62
69
|
self._rapids_per_bag: int = 2
|
|
63
70
|
self._priority: int = 50
|
|
64
|
-
self.
|
|
65
|
-
self._media_paths: list[str | list[str]] = []
|
|
71
|
+
self._assets: list[MediaAsset] | list[TextAsset] | list[MultiAsset] = []
|
|
66
72
|
|
|
67
73
|
def _to_model(self) -> CreateOrderModel:
|
|
68
74
|
"""
|
|
@@ -80,22 +86,32 @@ class RapidataOrderBuilder:
|
|
|
80
86
|
if self._referee is None:
|
|
81
87
|
print("No referee provided, using default NaiveReferee.")
|
|
82
88
|
self._referee = NaiveReferee()
|
|
83
|
-
|
|
84
|
-
|
|
85
|
-
|
|
86
|
-
|
|
87
|
-
|
|
89
|
+
|
|
90
|
+
user_filters = []
|
|
91
|
+
|
|
92
|
+
if self._country_codes is not None:
|
|
93
|
+
user_filters.append(
|
|
94
|
+
CreateOrderModelUserFiltersInner(
|
|
95
|
+
CountryUserFilterModel(
|
|
96
|
+
_t="CountryFilter", countries=self._country_codes
|
|
97
|
+
)
|
|
98
|
+
)
|
|
99
|
+
)
|
|
100
|
+
|
|
101
|
+
if self._language_codes is not None:
|
|
102
|
+
user_filters.append(
|
|
103
|
+
CreateOrderModelUserFiltersInner(
|
|
104
|
+
LanguageUserFilterModel(
|
|
105
|
+
_t="LanguageFilter", languages=self._language_codes
|
|
106
|
+
)
|
|
107
|
+
)
|
|
88
108
|
)
|
|
89
109
|
|
|
90
110
|
return CreateOrderModel(
|
|
91
111
|
_t="CreateOrderModel",
|
|
92
112
|
orderName=self._name,
|
|
93
113
|
workflow=CreateOrderModelWorkflow(self._workflow.to_model()),
|
|
94
|
-
userFilters=
|
|
95
|
-
[CreateOrderModelUserFiltersInner(country_filter)]
|
|
96
|
-
if country_filter
|
|
97
|
-
else []
|
|
98
|
-
),
|
|
114
|
+
userFilters=user_filters,
|
|
99
115
|
referee=CreateOrderModelReferee(self._referee.to_model()),
|
|
100
116
|
validationSetId=self._validation_set_id,
|
|
101
117
|
featureFlags=(
|
|
@@ -129,8 +145,12 @@ class RapidataOrderBuilder:
|
|
|
129
145
|
if isinstance(
|
|
130
146
|
self._workflow, CompareWorkflow
|
|
131
147
|
): # Temporary fix; will be handled by backend in the future
|
|
148
|
+
assert all(isinstance(item, MultiAsset) for item in self._assets), (
|
|
149
|
+
"The media paths must be of type MultiAsset for comparison tasks."
|
|
150
|
+
)
|
|
151
|
+
media_paths = cast(list[MultiAsset], self._assets)
|
|
132
152
|
assert all(
|
|
133
|
-
[len(path) == 2 for path in
|
|
153
|
+
[len(path) == 2 for path in media_paths]
|
|
134
154
|
), "The media paths must come in pairs for comparison tasks."
|
|
135
155
|
|
|
136
156
|
result = self._openapi_service.order_api.order_create_post(
|
|
@@ -145,22 +165,18 @@ class RapidataOrderBuilder:
|
|
|
145
165
|
openapi_service=self._openapi_service,
|
|
146
166
|
)
|
|
147
167
|
|
|
148
|
-
if
|
|
168
|
+
if not self._assets:
|
|
149
169
|
raise ValueError(
|
|
150
|
-
"You
|
|
170
|
+
"You must provide assets to start the order."
|
|
151
171
|
)
|
|
172
|
+
if all(isinstance(item, TextAsset) for item in self._assets):
|
|
173
|
+
assets = cast(list[TextAsset], self._assets)
|
|
174
|
+
order.dataset.add_texts(assets)
|
|
152
175
|
|
|
153
|
-
|
|
154
|
-
|
|
155
|
-
"You must provide either media paths or texts to the order."
|
|
156
|
-
)
|
|
157
|
-
|
|
158
|
-
if self._texts:
|
|
159
|
-
order.dataset.add_texts(self._texts)
|
|
160
|
-
|
|
161
|
-
if self._media_paths:
|
|
176
|
+
elif all(isinstance(item, (MediaAsset, MultiAsset)) for item in self._assets):
|
|
177
|
+
assets = cast(list[MediaAsset | MultiAsset], self._assets)
|
|
162
178
|
order.dataset.add_media_from_paths(
|
|
163
|
-
|
|
179
|
+
assets, self._metadata, max_workers
|
|
164
180
|
)
|
|
165
181
|
|
|
166
182
|
if submit:
|
|
@@ -196,60 +212,60 @@ class RapidataOrderBuilder:
|
|
|
196
212
|
|
|
197
213
|
def media(
    self,
    asset: list[MediaAsset] | list[TextAsset] | list[MultiAsset],
    metadata: Sequence[Metadata] | None = None,
) -> "RapidataOrderBuilder":
    """
    Set the assets for the order.

    Args:
        asset (list[MediaAsset] | list[TextAsset] | list[MultiAsset]): The assets to be set on the order.
        metadata (Sequence[Metadata] | None, optional): Metadata for the assets. Defaults to None.

    Returns:
        RapidataOrderBuilder: The updated RapidataOrderBuilder instance.
    """
    self._assets = asset
    # Any sequence is accepted, but the builder keeps its metadata as a
    # list, so normalise here instead of suppressing the type mismatch.
    self._metadata = list(metadata) if metadata is not None else None
    return self
|
|
215
231
|
|
|
216
|
-
def
|
|
232
|
+
def feature_flags(self, feature_flags: FeatureFlags) -> "RapidataOrderBuilder":
    """
    Store the feature flags to use for the order.

    Args:
        feature_flags (FeatureFlags): The feature flags to store on the builder.

    Returns:
        RapidataOrderBuilder: This builder, to allow call chaining.
    """
    self._feature_flags = feature_flags
    return self
|
|
228
244
|
|
|
229
|
-
def
|
|
245
|
+
def country_filter(self, country_codes: list[str]) -> "RapidataOrderBuilder":
    """
    Store the target country codes for the order, e.g. `country_codes=["DE", "CH", "AT"]` for Germany, Switzerland, and Austria.

    Args:
        country_codes (list[str]): The country codes to store on the builder.

    Returns:
        RapidataOrderBuilder: This builder, to allow call chaining.
    """
    self._country_codes = country_codes
    return self
|
|
241
257
|
|
|
242
|
-
def
|
|
258
|
+
def language_filter(self, language_codes: list[str]) -> "RapidataOrderBuilder":
    """
    Store the target language codes for the order, e.g. `language_codes=["de", "fr", "it"]` for German, French, and Italian.

    Args:
        language_codes (list[str]): The language codes to store on the builder.

    Returns:
        RapidataOrderBuilder: This builder, to allow call chaining.
    """
    self._language_codes = language_codes
    return self
|
|
254
270
|
|
|
255
271
|
def aggregator(self, aggregator: AggregatorType) -> "RapidataOrderBuilder":
|
|
@@ -3,3 +3,4 @@ from .demographic_selection import DemographicSelection
|
|
|
3
3
|
from .labeling_selection import LabelingSelection
|
|
4
4
|
from .validation_selection import ValidationSelection
|
|
5
5
|
from .conditional_validation_selection import ConditionalValidationSelection
|
|
6
|
+
from .capped_selection import CappedSelection
|