rapidata 1.1.0__py3-none-any.whl → 1.2.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- rapidata/__init__.py +1 -0
- rapidata/api_client/__init__.py +5 -3
- rapidata/api_client/api/__init__.py +2 -0
- rapidata/api_client/api/campaign_api.py +8 -4
- rapidata/api_client/api/coco_api.py +4 -2
- rapidata/api_client/api/compare_workflow_api.py +2 -1
- rapidata/api_client/api/datapoint_api.py +6 -3
- rapidata/api_client/api/dataset_api.py +16 -8
- rapidata/api_client/api/identity_api.py +329 -50
- rapidata/api_client/api/newsletter_api.py +4 -2
- rapidata/api_client/api/order_api.py +40 -20
- rapidata/api_client/api/pipeline_api.py +6 -3
- rapidata/api_client/api/rapid_api.py +10 -5
- rapidata/api_client/api/rapidata_identity_api_api.py +272 -0
- rapidata/api_client/api/simple_workflow_api.py +2 -1
- rapidata/api_client/api/user_info_api.py +272 -0
- rapidata/api_client/api/validation_api.py +14 -7
- rapidata/api_client/api/workflow_api.py +18 -9
- rapidata/api_client/models/__init__.py +3 -3
- rapidata/api_client/models/issue_auth_token_result.py +1 -1
- rapidata/api_client/models/legacy_issue_client_auth_token_result.py +87 -0
- rapidata/api_client/models/legacy_request_password_reset_command.py +98 -0
- rapidata/api_client/models/legacy_submit_password_reset_command.py +102 -0
- rapidata/api_client_README.md +10 -3
- rapidata/rapidata_client/__init__.py +13 -2
- rapidata/rapidata_client/assets/multi_asset.py +2 -0
- rapidata/rapidata_client/dataset/rapidata_dataset.py +19 -15
- rapidata/rapidata_client/dataset/validation_set_builder.py +1 -1
- rapidata/rapidata_client/order/rapidata_order.py +49 -18
- rapidata/rapidata_client/order/rapidata_order_builder.py +23 -34
- rapidata/rapidata_client/selection/__init__.py +1 -0
- rapidata/rapidata_client/selection/capped_selection.py +25 -0
- rapidata/rapidata_client/simple_builders/__init__.py +0 -0
- rapidata/rapidata_client/simple_builders/simple_classification_builders.py +14 -9
- rapidata/rapidata_client/simple_builders/simple_compare_builders.py +6 -3
- rapidata/service/openapi_service.py +15 -0
- {rapidata-1.1.0.dist-info → rapidata-1.2.0.dist-info}/METADATA +1 -1
- {rapidata-1.1.0.dist-info → rapidata-1.2.0.dist-info}/RECORD +40 -33
- {rapidata-1.1.0.dist-info → rapidata-1.2.0.dist-info}/LICENSE +0 -0
- {rapidata-1.1.0.dist-info → rapidata-1.2.0.dist-info}/WHEEL +0 -0
|
@@ -0,0 +1,102 @@
|
|
|
1
|
+
# coding: utf-8
|
|
2
|
+
|
|
3
|
+
"""
|
|
4
|
+
Rapidata.Dataset
|
|
5
|
+
|
|
6
|
+
No description provided (generated by Openapi Generator https://github.com/openapitools/openapi-generator)
|
|
7
|
+
|
|
8
|
+
The version of the OpenAPI document: v1
|
|
9
|
+
Generated by OpenAPI Generator (https://openapi-generator.tech)
|
|
10
|
+
|
|
11
|
+
Do not edit the class manually.
|
|
12
|
+
""" # noqa: E501
|
|
13
|
+
|
|
14
|
+
|
|
15
|
+
from __future__ import annotations
|
|
16
|
+
import pprint
|
|
17
|
+
import re # noqa: F401
|
|
18
|
+
import json
|
|
19
|
+
|
|
20
|
+
from pydantic import BaseModel, ConfigDict, Field, StrictStr, field_validator
|
|
21
|
+
from typing import Any, ClassVar, Dict, List
|
|
22
|
+
from typing import Optional, Set
|
|
23
|
+
from typing_extensions import Self
|
|
24
|
+
|
|
25
|
+
class LegacySubmitPasswordResetCommand(BaseModel):
|
|
26
|
+
"""
|
|
27
|
+
LegacySubmitPasswordResetCommand
|
|
28
|
+
""" # noqa: E501
|
|
29
|
+
t: StrictStr = Field(description="Discriminator value for LegacySubmitPasswordResetCommand", alias="_t")
|
|
30
|
+
user_id: StrictStr = Field(alias="userId")
|
|
31
|
+
password: StrictStr
|
|
32
|
+
password_repeated: StrictStr = Field(alias="passwordRepeated")
|
|
33
|
+
reset_token: StrictStr = Field(alias="resetToken")
|
|
34
|
+
__properties: ClassVar[List[str]] = ["_t", "userId", "password", "passwordRepeated", "resetToken"]
|
|
35
|
+
|
|
36
|
+
@field_validator('t')
|
|
37
|
+
def t_validate_enum(cls, value):
|
|
38
|
+
"""Validates the enum"""
|
|
39
|
+
if value not in set(['LegacySubmitPasswordResetCommand']):
|
|
40
|
+
raise ValueError("must be one of enum values ('LegacySubmitPasswordResetCommand')")
|
|
41
|
+
return value
|
|
42
|
+
|
|
43
|
+
model_config = ConfigDict(
|
|
44
|
+
populate_by_name=True,
|
|
45
|
+
validate_assignment=True,
|
|
46
|
+
protected_namespaces=(),
|
|
47
|
+
)
|
|
48
|
+
|
|
49
|
+
|
|
50
|
+
def to_str(self) -> str:
|
|
51
|
+
"""Returns the string representation of the model using alias"""
|
|
52
|
+
return pprint.pformat(self.model_dump(by_alias=True))
|
|
53
|
+
|
|
54
|
+
def to_json(self) -> str:
|
|
55
|
+
"""Returns the JSON representation of the model using alias"""
|
|
56
|
+
# TODO: pydantic v2: use .model_dump_json(by_alias=True, exclude_unset=True) instead
|
|
57
|
+
return json.dumps(self.to_dict())
|
|
58
|
+
|
|
59
|
+
@classmethod
|
|
60
|
+
def from_json(cls, json_str: str) -> Optional[Self]:
|
|
61
|
+
"""Create an instance of LegacySubmitPasswordResetCommand from a JSON string"""
|
|
62
|
+
return cls.from_dict(json.loads(json_str))
|
|
63
|
+
|
|
64
|
+
def to_dict(self) -> Dict[str, Any]:
|
|
65
|
+
"""Return the dictionary representation of the model using alias.
|
|
66
|
+
|
|
67
|
+
This has the following differences from calling pydantic's
|
|
68
|
+
`self.model_dump(by_alias=True)`:
|
|
69
|
+
|
|
70
|
+
* `None` is only added to the output dict for nullable fields that
|
|
71
|
+
were set at model initialization. Other fields with value `None`
|
|
72
|
+
are ignored.
|
|
73
|
+
"""
|
|
74
|
+
excluded_fields: Set[str] = set([
|
|
75
|
+
])
|
|
76
|
+
|
|
77
|
+
_dict = self.model_dump(
|
|
78
|
+
by_alias=True,
|
|
79
|
+
exclude=excluded_fields,
|
|
80
|
+
exclude_none=True,
|
|
81
|
+
)
|
|
82
|
+
return _dict
|
|
83
|
+
|
|
84
|
+
@classmethod
|
|
85
|
+
def from_dict(cls, obj: Optional[Dict[str, Any]]) -> Optional[Self]:
|
|
86
|
+
"""Create an instance of LegacySubmitPasswordResetCommand from a dict"""
|
|
87
|
+
if obj is None:
|
|
88
|
+
return None
|
|
89
|
+
|
|
90
|
+
if not isinstance(obj, dict):
|
|
91
|
+
return cls.model_validate(obj)
|
|
92
|
+
|
|
93
|
+
_obj = cls.model_validate({
|
|
94
|
+
"_t": obj.get("_t") if obj.get("_t") is not None else 'LegacySubmitPasswordResetCommand',
|
|
95
|
+
"userId": obj.get("userId"),
|
|
96
|
+
"password": obj.get("password"),
|
|
97
|
+
"passwordRepeated": obj.get("passwordRepeated"),
|
|
98
|
+
"resetToken": obj.get("resetToken")
|
|
99
|
+
})
|
|
100
|
+
return _obj
|
|
101
|
+
|
|
102
|
+
|
rapidata/api_client_README.md
CHANGED
|
@@ -97,6 +97,7 @@ Class | Method | HTTP request | Description
|
|
|
97
97
|
*IdentityApi* | [**identity_get_client_auth_token_post**](rapidata/api_client/docs/IdentityApi.md#identity_get_client_auth_token_post) | **POST** /Identity/GetClientAuthToken | Issues a new auth token using the client credentials.
|
|
98
98
|
*IdentityApi* | [**identity_index_post**](rapidata/api_client/docs/IdentityApi.md#identity_index_post) | **POST** /Identity/Index | Logs in a user by username or email and password.
|
|
99
99
|
*IdentityApi* | [**identity_logout_post**](rapidata/api_client/docs/IdentityApi.md#identity_logout_post) | **POST** /Identity/Logout | Logs out the current user by deleting the refresh token cookie.
|
|
100
|
+
*IdentityApi* | [**identity_register_temporary_post**](rapidata/api_client/docs/IdentityApi.md#identity_register_temporary_post) | **POST** /Identity/RegisterTemporary | Registers and logs in a temporary customer.
|
|
100
101
|
*IdentityApi* | [**identity_request_reset_post**](rapidata/api_client/docs/IdentityApi.md#identity_request_reset_post) | **POST** /Identity/RequestReset | Request a password reset for a user.
|
|
101
102
|
*IdentityApi* | [**identity_signup_post**](rapidata/api_client/docs/IdentityApi.md#identity_signup_post) | **POST** /Identity/Signup | Signs up a new user.
|
|
102
103
|
*IdentityApi* | [**identity_submit_reset_post**](rapidata/api_client/docs/IdentityApi.md#identity_submit_reset_post) | **POST** /Identity/SubmitReset | Updates the password of a user after a password reset request.
|
|
@@ -131,7 +132,9 @@ Class | Method | HTTP request | Description
|
|
|
131
132
|
*RapidApi* | [**rapid_query_validation_rapids_get**](rapidata/api_client/docs/RapidApi.md#rapid_query_validation_rapids_get) | **GET** /Rapid/QueryValidationRapids | Queries the validation rapids for a specific validation set.
|
|
132
133
|
*RapidApi* | [**rapid_skip_user_guess_post**](rapidata/api_client/docs/RapidApi.md#rapid_skip_user_guess_post) | **POST** /Rapid/SkipUserGuess | Skips a Rapid for the user.
|
|
133
134
|
*RapidApi* | [**rapid_validate_current_rapid_bag_get**](rapidata/api_client/docs/RapidApi.md#rapid_validate_current_rapid_bag_get) | **GET** /Rapid/ValidateCurrentRapidBag | Validates that the rapids associated with the current user are active.
|
|
135
|
+
*RapidataIdentityAPIApi* | [**root_get**](rapidata/api_client/docs/RapidataIdentityAPIApi.md#root_get) | **GET** / |
|
|
134
136
|
*SimpleWorkflowApi* | [**simple_workflow_get_result_overview_get**](rapidata/api_client/docs/SimpleWorkflowApi.md#simple_workflow_get_result_overview_get) | **GET** /SimpleWorkflow/GetResultOverview | Get the result overview for a simple workflow.
|
|
137
|
+
*UserInfoApi* | [**connect_userinfo_get**](rapidata/api_client/docs/UserInfoApi.md#connect_userinfo_get) | **GET** /connect/userinfo | Retrieves information about the authenticated user.
|
|
135
138
|
*ValidationApi* | [**validation_add_validation_rapid_post**](rapidata/api_client/docs/ValidationApi.md#validation_add_validation_rapid_post) | **POST** /Validation/AddValidationRapid | Adds a new validation rapid to the specified validation set.
|
|
136
139
|
*ValidationApi* | [**validation_add_validation_text_rapid_post**](rapidata/api_client/docs/ValidationApi.md#validation_add_validation_text_rapid_post) | **POST** /Validation/AddValidationTextRapid | Adds a new validation rapid to the specified validation set.
|
|
137
140
|
*ValidationApi* | [**validation_create_validation_set_post**](rapidata/api_client/docs/ValidationApi.md#validation_create_validation_set_post) | **POST** /Validation/CreateValidationSet | Creates a new empty validation set.
|
|
@@ -273,9 +276,11 @@ Class | Method | HTTP request | Description
|
|
|
273
276
|
- [ImportValidationSetFromFileResult](rapidata/api_client/docs/ImportValidationSetFromFileResult.md)
|
|
274
277
|
- [InProgressRapidModel](rapidata/api_client/docs/InProgressRapidModel.md)
|
|
275
278
|
- [IssueAuthTokenResult](rapidata/api_client/docs/IssueAuthTokenResult.md)
|
|
276
|
-
- [IssueClientAuthTokenResult](rapidata/api_client/docs/IssueClientAuthTokenResult.md)
|
|
277
279
|
- [LabelingSelection](rapidata/api_client/docs/LabelingSelection.md)
|
|
278
280
|
- [LanguageUserFilterModel](rapidata/api_client/docs/LanguageUserFilterModel.md)
|
|
281
|
+
- [LegacyIssueClientAuthTokenResult](rapidata/api_client/docs/LegacyIssueClientAuthTokenResult.md)
|
|
282
|
+
- [LegacyRequestPasswordResetCommand](rapidata/api_client/docs/LegacyRequestPasswordResetCommand.md)
|
|
283
|
+
- [LegacySubmitPasswordResetCommand](rapidata/api_client/docs/LegacySubmitPasswordResetCommand.md)
|
|
279
284
|
- [Line](rapidata/api_client/docs/Line.md)
|
|
280
285
|
- [LinePayload](rapidata/api_client/docs/LinePayload.md)
|
|
281
286
|
- [LinePoint](rapidata/api_client/docs/LinePoint.md)
|
|
@@ -334,7 +339,6 @@ Class | Method | HTTP request | Description
|
|
|
334
339
|
- [RapidResultModel](rapidata/api_client/docs/RapidResultModel.md)
|
|
335
340
|
- [RapidResultModelResult](rapidata/api_client/docs/RapidResultModelResult.md)
|
|
336
341
|
- [RapidSkippedModel](rapidata/api_client/docs/RapidSkippedModel.md)
|
|
337
|
-
- [RequestPasswordResetCommand](rapidata/api_client/docs/RequestPasswordResetCommand.md)
|
|
338
342
|
- [RootFilter](rapidata/api_client/docs/RootFilter.md)
|
|
339
343
|
- [SendCompletionMailStepModel](rapidata/api_client/docs/SendCompletionMailStepModel.md)
|
|
340
344
|
- [Shape](rapidata/api_client/docs/Shape.md)
|
|
@@ -354,7 +358,6 @@ Class | Method | HTTP request | Description
|
|
|
354
358
|
- [StaticSelection](rapidata/api_client/docs/StaticSelection.md)
|
|
355
359
|
- [SubmitCocoModel](rapidata/api_client/docs/SubmitCocoModel.md)
|
|
356
360
|
- [SubmitCocoResult](rapidata/api_client/docs/SubmitCocoResult.md)
|
|
357
|
-
- [SubmitPasswordResetCommand](rapidata/api_client/docs/SubmitPasswordResetCommand.md)
|
|
358
361
|
- [TextAsset](rapidata/api_client/docs/TextAsset.md)
|
|
359
362
|
- [TextAssetModel](rapidata/api_client/docs/TextAssetModel.md)
|
|
360
363
|
- [TextMetadata](rapidata/api_client/docs/TextMetadata.md)
|
|
@@ -403,6 +406,10 @@ Authentication schemes defined for the API:
|
|
|
403
406
|
- **API key parameter name**: Authorization
|
|
404
407
|
- **Location**: HTTP header
|
|
405
408
|
|
|
409
|
+
<a id="oauth2"></a>
|
|
410
|
+
### oauth2
|
|
411
|
+
|
|
412
|
+
|
|
406
413
|
|
|
407
414
|
## Author
|
|
408
415
|
|
|
@@ -1,13 +1,24 @@
|
|
|
1
1
|
from .rapidata_client import RapidataClient
|
|
2
|
-
from .workflow import
|
|
2
|
+
from .workflow import (
|
|
3
|
+
ClassifyWorkflow,
|
|
4
|
+
TranscriptionWorkflow,
|
|
5
|
+
CompareWorkflow,
|
|
6
|
+
FreeTextWorkflow,
|
|
7
|
+
)
|
|
3
8
|
from .selection import (
|
|
4
9
|
DemographicSelection,
|
|
5
10
|
LabelingSelection,
|
|
6
11
|
ValidationSelection,
|
|
7
12
|
ConditionalValidationSelection,
|
|
13
|
+
CappedSelection,
|
|
8
14
|
)
|
|
9
15
|
from .referee import NaiveReferee, ClassifyEarlyStoppingReferee
|
|
10
|
-
from .metadata import
|
|
16
|
+
from .metadata import (
|
|
17
|
+
PrivateTextMetadata,
|
|
18
|
+
PublicTextMetadata,
|
|
19
|
+
PromptMetadata,
|
|
20
|
+
TranscriptionMetadata,
|
|
21
|
+
)
|
|
11
22
|
from .feature_flags import FeatureFlags
|
|
12
23
|
from .country_codes import CountryCodes
|
|
13
24
|
from .assets import MediaAsset, TextAsset, MultiAsset
|
|
@@ -23,6 +23,8 @@ class MultiAsset(BaseAsset):
|
|
|
23
23
|
Args:
|
|
24
24
|
assets (List[BaseAsset]): A list of BaseAsset instances to be managed together.
|
|
25
25
|
"""
|
|
26
|
+
if len(assets) != 2:
|
|
27
|
+
raise ValueError("Assets must come in pairs for comparison tasks.")
|
|
26
28
|
self.assets = assets
|
|
27
29
|
|
|
28
30
|
def __len__(self) -> int:
|
|
@@ -9,6 +9,7 @@ from rapidata.api_client.models.upload_text_sources_to_dataset_model import (
|
|
|
9
9
|
UploadTextSourcesToDatasetModel,
|
|
10
10
|
)
|
|
11
11
|
from rapidata.rapidata_client.metadata.base_metadata import Metadata
|
|
12
|
+
from rapidata.rapidata_client.assets import TextAsset, MediaAsset, MultiAsset
|
|
12
13
|
from rapidata.service import LocalFileService
|
|
13
14
|
from rapidata.service.openapi_service import OpenAPIService
|
|
14
15
|
from concurrent.futures import ThreadPoolExecutor, as_completed
|
|
@@ -22,7 +23,8 @@ class RapidataDataset:
|
|
|
22
23
|
self.openapi_service = openapi_service
|
|
23
24
|
self.local_file_service = LocalFileService()
|
|
24
25
|
|
|
25
|
-
def add_texts(self,
|
|
26
|
+
def add_texts(self, text_assets: list[TextAsset]):
|
|
27
|
+
texts = [text.text for text in text_assets]
|
|
26
28
|
model = UploadTextSourcesToDatasetModel(
|
|
27
29
|
datasetId=self.dataset_id, textSources=texts
|
|
28
30
|
)
|
|
@@ -32,24 +34,26 @@ class RapidataDataset:
|
|
|
32
34
|
|
|
33
35
|
def add_media_from_paths(
|
|
34
36
|
self,
|
|
35
|
-
media_paths: list[
|
|
37
|
+
media_paths: list[MediaAsset | MultiAsset],
|
|
36
38
|
metadata: list[Metadata] | None = None,
|
|
37
39
|
max_workers: int = 10,
|
|
38
40
|
):
|
|
39
41
|
if metadata is not None and len(metadata) != len(media_paths):
|
|
40
42
|
raise ValueError(
|
|
41
|
-
"metadata must be None or have the same length as
|
|
43
|
+
"metadata must be None or have the same length as media_paths"
|
|
42
44
|
)
|
|
43
45
|
|
|
44
|
-
def upload_datapoint(
|
|
45
|
-
if isinstance(
|
|
46
|
-
|
|
47
|
-
):
|
|
48
|
-
|
|
49
|
-
|
|
50
|
-
|
|
51
|
-
|
|
52
|
-
|
|
46
|
+
def upload_datapoint(media_asset: MediaAsset | MultiAsset, meta: Metadata | None) -> None:
|
|
47
|
+
if isinstance(media_asset, MediaAsset):
|
|
48
|
+
paths = [media_asset.path]
|
|
49
|
+
elif isinstance(media_asset, MultiAsset):
|
|
50
|
+
paths = [asset.path for asset in media_asset.assets if isinstance(asset, MediaAsset)]
|
|
51
|
+
else:
|
|
52
|
+
raise ValueError(f"Unsupported asset type: {type(media_asset)}")
|
|
53
|
+
|
|
54
|
+
assert all(
|
|
55
|
+
os.path.exists(media_path) for media_path in paths
|
|
56
|
+
), "All media paths must exist on the local filesystem."
|
|
53
57
|
|
|
54
58
|
meta_model = meta.to_model() if meta else None
|
|
55
59
|
model = DatapointMetadataModel(
|
|
@@ -63,14 +67,14 @@ class RapidataDataset:
|
|
|
63
67
|
|
|
64
68
|
self.openapi_service.dataset_api.dataset_create_datapoint_post(
|
|
65
69
|
model=model,
|
|
66
|
-
files=
|
|
70
|
+
files=paths # type: ignore
|
|
67
71
|
)
|
|
68
72
|
|
|
69
73
|
total_uploads = len(media_paths)
|
|
70
74
|
with ThreadPoolExecutor(max_workers=max_workers) as executor:
|
|
71
75
|
futures = [
|
|
72
|
-
executor.submit(upload_datapoint,
|
|
73
|
-
for
|
|
76
|
+
executor.submit(upload_datapoint, media_asset, meta)
|
|
77
|
+
for media_asset, meta in zip_longest(media_paths, metadata or [])
|
|
74
78
|
]
|
|
75
79
|
|
|
76
80
|
with tqdm(total=total_uploads, desc="Uploading datapoints") as pbar:
|
|
@@ -3,7 +3,9 @@ from rapidata.rapidata_client.dataset.rapidata_dataset import RapidataDataset
|
|
|
3
3
|
from rapidata.service.openapi_service import OpenAPIService
|
|
4
4
|
import json
|
|
5
5
|
from rapidata.api_client.exceptions import ApiException
|
|
6
|
-
|
|
6
|
+
from typing import cast
|
|
7
|
+
from rapidata.api_client.models.workflow_artifact_model import WorkflowArtifactModel
|
|
8
|
+
from tqdm import tqdm
|
|
7
9
|
|
|
8
10
|
class RapidataOrder:
|
|
9
11
|
"""
|
|
@@ -26,6 +28,7 @@ class RapidataOrder:
|
|
|
26
28
|
self.openapi_service = openapi_service
|
|
27
29
|
self.order_id = order_id
|
|
28
30
|
self._dataset = dataset
|
|
31
|
+
self._workflow_id = None
|
|
29
32
|
|
|
30
33
|
def submit(self):
|
|
31
34
|
"""
|
|
@@ -49,27 +52,55 @@ class RapidataOrder:
|
|
|
49
52
|
"""
|
|
50
53
|
return self.openapi_service.order_api.order_get_by_id_get(self.order_id)
|
|
51
54
|
|
|
52
|
-
def
|
|
55
|
+
def display_progress_bar(self, refresh_rate=5):
|
|
53
56
|
"""
|
|
54
|
-
|
|
57
|
+
Displays a progress bar for the order processing using tqdm.
|
|
58
|
+
|
|
59
|
+
:param refresh_rate: How often to refresh the progress bar, in seconds.
|
|
60
|
+
:type refresh_rate: float
|
|
55
61
|
"""
|
|
56
|
-
|
|
57
|
-
|
|
58
|
-
|
|
59
|
-
|
|
60
|
-
|
|
61
|
-
|
|
62
|
-
|
|
63
|
-
|
|
64
|
-
|
|
65
|
-
|
|
66
|
-
|
|
62
|
+
total_rapids = self._get_total_rapids()
|
|
63
|
+
with tqdm(total=total_rapids, desc="Processing order", unit="rapids") as pbar:
|
|
64
|
+
completed_rapids = 0
|
|
65
|
+
while True:
|
|
66
|
+
current_completed = self._get_completed_rapids()
|
|
67
|
+
if current_completed > completed_rapids:
|
|
68
|
+
pbar.update(current_completed - completed_rapids)
|
|
69
|
+
completed_rapids = current_completed
|
|
70
|
+
|
|
71
|
+
if completed_rapids >= total_rapids:
|
|
72
|
+
break
|
|
73
|
+
|
|
74
|
+
time.sleep(refresh_rate)
|
|
67
75
|
|
|
68
|
-
|
|
76
|
+
def _get_workflow_id(self):
|
|
77
|
+
if self._workflow_id:
|
|
78
|
+
return self._workflow_id
|
|
79
|
+
|
|
80
|
+
for _ in range(2):
|
|
81
|
+
try:
|
|
82
|
+
order_result = self.openapi_service.order_api.order_get_by_id_get(self.order_id)
|
|
83
|
+
pipeline = self.openapi_service.pipeline_api.pipeline_id_get(order_result.pipeline_id)
|
|
84
|
+
self._workflow_id = cast(WorkflowArtifactModel, pipeline.artifacts["workflow-artifact"].actual_instance).workflow_id
|
|
69
85
|
break
|
|
70
|
-
|
|
71
|
-
|
|
72
|
-
|
|
86
|
+
except Exception:
|
|
87
|
+
time.sleep(2)
|
|
88
|
+
if not self._workflow_id:
|
|
89
|
+
raise Exception("Order has not started yet. Please wait for a few seconds and try again.")
|
|
90
|
+
return self._workflow_id
|
|
91
|
+
|
|
92
|
+
def _get_total_rapids(self):
|
|
93
|
+
workflow_id = self._get_workflow_id()
|
|
94
|
+
return self.openapi_service.workflow_api.workflow_get_progress_get(workflow_id).total
|
|
95
|
+
|
|
96
|
+
def _get_completed_rapids(self):
|
|
97
|
+
workflow_id = self._get_workflow_id()
|
|
98
|
+
return self.openapi_service.workflow_api.workflow_get_progress_get(workflow_id).completed
|
|
99
|
+
|
|
100
|
+
def get_progress_percentage(self):
|
|
101
|
+
workflow_id = self._get_workflow_id()
|
|
102
|
+
progress = self.openapi_service.workflow_api.workflow_get_progress_get(workflow_id)
|
|
103
|
+
return progress.completion_percentage
|
|
73
104
|
|
|
74
105
|
def get_results(self):
|
|
75
106
|
"""
|
|
@@ -28,6 +28,10 @@ from rapidata.service.openapi_service import OpenAPIService
|
|
|
28
28
|
|
|
29
29
|
from rapidata.rapidata_client.workflow.compare_workflow import CompareWorkflow
|
|
30
30
|
|
|
31
|
+
from rapidata.rapidata_client.assets import MediaAsset, TextAsset, MultiAsset
|
|
32
|
+
|
|
33
|
+
from typing import cast, Sequence
|
|
34
|
+
|
|
31
35
|
|
|
32
36
|
class RapidataOrderBuilder:
|
|
33
37
|
"""Builder object for creating Rapidata orders.
|
|
@@ -55,7 +59,6 @@ class RapidataOrderBuilder:
|
|
|
55
59
|
self._openapi_service = openapi_service
|
|
56
60
|
self._workflow: Workflow | None = None
|
|
57
61
|
self._referee: Referee | None = None
|
|
58
|
-
self._media_paths: list[str | list[str]] = []
|
|
59
62
|
self._metadata: list[Metadata] | None = None
|
|
60
63
|
self._aggregator: AggregatorType | None = None
|
|
61
64
|
self._validation_set_id: str | None = None
|
|
@@ -65,8 +68,7 @@ class RapidataOrderBuilder:
|
|
|
65
68
|
self._selections: list[Selection] = []
|
|
66
69
|
self._rapids_per_bag: int = 2
|
|
67
70
|
self._priority: int = 50
|
|
68
|
-
self.
|
|
69
|
-
self._media_paths: list[str | list[str]] = []
|
|
71
|
+
self._assets: list[MediaAsset] | list[TextAsset] | list[MultiAsset] = []
|
|
70
72
|
|
|
71
73
|
def _to_model(self) -> CreateOrderModel:
|
|
72
74
|
"""
|
|
@@ -143,8 +145,12 @@ class RapidataOrderBuilder:
|
|
|
143
145
|
if isinstance(
|
|
144
146
|
self._workflow, CompareWorkflow
|
|
145
147
|
): # Temporary fix; will be handled by backend in the future
|
|
148
|
+
assert all(isinstance(item, MultiAsset) for item in self._assets), (
|
|
149
|
+
"The media paths must be of type MultiAsset for comparison tasks."
|
|
150
|
+
)
|
|
151
|
+
media_paths = cast(list[MultiAsset], self._assets)
|
|
146
152
|
assert all(
|
|
147
|
-
[len(path) == 2 for path in
|
|
153
|
+
[len(path) == 2 for path in media_paths]
|
|
148
154
|
), "The media paths must come in pairs for comparison tasks."
|
|
149
155
|
|
|
150
156
|
result = self._openapi_service.order_api.order_create_post(
|
|
@@ -159,22 +165,18 @@ class RapidataOrderBuilder:
|
|
|
159
165
|
openapi_service=self._openapi_service,
|
|
160
166
|
)
|
|
161
167
|
|
|
162
|
-
if
|
|
163
|
-
raise ValueError(
|
|
164
|
-
"You cannot provide both media paths and texts to the same order."
|
|
165
|
-
)
|
|
166
|
-
|
|
167
|
-
if not self._media_paths and not self._texts:
|
|
168
|
+
if not self._assets:
|
|
168
169
|
raise ValueError(
|
|
169
|
-
"You must provide
|
|
170
|
+
"You must provide assets to start the order."
|
|
170
171
|
)
|
|
172
|
+
if all(isinstance(item, TextAsset) for item in self._assets):
|
|
173
|
+
assets = cast(list[TextAsset], self._assets)
|
|
174
|
+
order.dataset.add_texts(assets)
|
|
171
175
|
|
|
172
|
-
|
|
173
|
-
|
|
174
|
-
|
|
175
|
-
if self._media_paths:
|
|
176
|
+
elif all(isinstance(item, (MediaAsset, MultiAsset)) for item in self._assets):
|
|
177
|
+
assets = cast(list[MediaAsset | MultiAsset], self._assets)
|
|
176
178
|
order.dataset.add_media_from_paths(
|
|
177
|
-
|
|
179
|
+
assets, self._metadata, max_workers
|
|
178
180
|
)
|
|
179
181
|
|
|
180
182
|
if submit:
|
|
@@ -210,34 +212,21 @@ class RapidataOrderBuilder:
|
|
|
210
212
|
|
|
211
213
|
def media(
|
|
212
214
|
self,
|
|
213
|
-
|
|
214
|
-
metadata:
|
|
215
|
+
asset: list[MediaAsset] | list[TextAsset] | list[MultiAsset],
|
|
216
|
+
metadata: Sequence[Metadata] | None = None,
|
|
215
217
|
) -> "RapidataOrderBuilder":
|
|
216
218
|
"""
|
|
217
219
|
Set the media assets for the order.
|
|
218
220
|
|
|
219
221
|
Args:
|
|
220
|
-
media_paths (list[
|
|
222
|
+
media_paths (list[MediaAsset] | list[TextAsset] | list[MultiAsset]): The paths of the media assets to be set.
|
|
221
223
|
metadata (list[Metadata] | None, optional): Metadata for the media assets. Defaults to None.
|
|
222
224
|
|
|
223
225
|
Returns:
|
|
224
226
|
RapidataOrderBuilder: The updated RapidataOrderBuilder instance.
|
|
225
227
|
"""
|
|
226
|
-
self.
|
|
227
|
-
self._metadata = metadata
|
|
228
|
-
return self
|
|
229
|
-
|
|
230
|
-
def texts(self, texts: list[str]) -> "RapidataOrderBuilder":
|
|
231
|
-
"""
|
|
232
|
-
Set the TextAssets for the order.
|
|
233
|
-
|
|
234
|
-
Args:
|
|
235
|
-
texts (list[str]): The texts to be set.
|
|
236
|
-
|
|
237
|
-
Returns:
|
|
238
|
-
RapidataOrderBuilder: The updated RapidataOrderBuilder instance.
|
|
239
|
-
"""
|
|
240
|
-
self._texts = texts
|
|
228
|
+
self._assets = asset
|
|
229
|
+
self._metadata = metadata # type: ignore
|
|
241
230
|
return self
|
|
242
231
|
|
|
243
232
|
def feature_flags(self, feature_flags: FeatureFlags) -> "RapidataOrderBuilder":
|
|
@@ -3,3 +3,4 @@ from .demographic_selection import DemographicSelection
|
|
|
3
3
|
from .labeling_selection import LabelingSelection
|
|
4
4
|
from .validation_selection import ValidationSelection
|
|
5
5
|
from .conditional_validation_selection import ConditionalValidationSelection
|
|
6
|
+
from .capped_selection import CappedSelection
|
|
@@ -0,0 +1,25 @@
|
|
|
1
|
+
from rapidata.api_client.models.capped_selection import (
|
|
2
|
+
CappedSelection as CappedSelectionModel,
|
|
3
|
+
)
|
|
4
|
+
from rapidata.api_client.models.capped_selection_selections_inner import (
|
|
5
|
+
CappedSelectionSelectionsInner,
|
|
6
|
+
)
|
|
7
|
+
from rapidata.rapidata_client.selection.base_selection import Selection
|
|
8
|
+
from typing import Sequence
|
|
9
|
+
|
|
10
|
+
|
|
11
|
+
class CappedSelection(Selection):
|
|
12
|
+
|
|
13
|
+
def __init__(self, selections: Sequence[Selection], max_rapids: int):
|
|
14
|
+
self.selections = selections
|
|
15
|
+
self.max_rapids = max_rapids
|
|
16
|
+
|
|
17
|
+
def to_model(self):
|
|
18
|
+
return CappedSelectionModel(
|
|
19
|
+
_t="CappedSelection",
|
|
20
|
+
selections=[
|
|
21
|
+
CappedSelectionSelectionsInner(selection.to_model())
|
|
22
|
+
for selection in self.selections
|
|
23
|
+
],
|
|
24
|
+
max_rapids=self.max_rapids,
|
|
25
|
+
)
|
|
File without changes
|
|
@@ -7,6 +7,8 @@ from rapidata.rapidata_client.workflow.classify_workflow import ClassifyWorkflow
|
|
|
7
7
|
from rapidata.rapidata_client.selection.validation_selection import ValidationSelection
|
|
8
8
|
from rapidata.rapidata_client.selection.labeling_selection import LabelingSelection
|
|
9
9
|
from rapidata.service.openapi_service import OpenAPIService
|
|
10
|
+
from rapidata.rapidata_client.assets import MediaAsset
|
|
11
|
+
from typing import Sequence
|
|
10
12
|
|
|
11
13
|
class ClassificationOrderBuilder:
|
|
12
14
|
def __init__(self, name: str, question: str, options: list[str], media_paths: list[str], openapi_service: OpenAPIService):
|
|
@@ -19,7 +21,7 @@ class ClassificationOrderBuilder:
|
|
|
19
21
|
self._metadata = None
|
|
20
22
|
self._validation_set_id = None
|
|
21
23
|
|
|
22
|
-
def metadata(self, metadata:
|
|
24
|
+
def metadata(self, metadata: Sequence[Metadata]):
|
|
23
25
|
"""Set the metadata for the classification order. Has to be the same lenght as the media paths."""
|
|
24
26
|
self._metadata = metadata
|
|
25
27
|
return self
|
|
@@ -28,27 +30,29 @@ class ClassificationOrderBuilder:
|
|
|
28
30
|
"""Set the number of responses required for the classification order."""
|
|
29
31
|
self._responses_required = responses_required
|
|
30
32
|
return self
|
|
31
|
-
|
|
33
|
+
|
|
32
34
|
def probability_threshold(self, probability_threshold: float):
|
|
33
35
|
"""Set the probability threshold for early stopping."""
|
|
34
36
|
self._probability_threshold = probability_threshold
|
|
35
37
|
return self
|
|
36
|
-
|
|
38
|
+
|
|
37
39
|
def validation_set_id(self, validation_set_id: str):
|
|
38
40
|
"""Set the validation set ID for the classification order."""
|
|
39
41
|
self._validation_set_id = validation_set_id
|
|
40
42
|
return self
|
|
41
|
-
|
|
43
|
+
|
|
42
44
|
def create(self, submit: bool = True, max_upload_workers: int = 10):
|
|
43
45
|
if self._probability_threshold and self._responses_required:
|
|
44
46
|
referee = ClassifyEarlyStoppingReferee(
|
|
45
47
|
max_vote_count=self._responses_required,
|
|
46
48
|
threshold=self._probability_threshold
|
|
47
49
|
)
|
|
48
|
-
|
|
50
|
+
|
|
49
51
|
else:
|
|
50
52
|
referee = NaiveReferee(required_guesses=self._responses_required)
|
|
51
|
-
|
|
53
|
+
|
|
54
|
+
assets = [MediaAsset(path=media_path) for media_path in self._media_paths]
|
|
55
|
+
|
|
52
56
|
selection: list[Selection] = ([ValidationSelection(amount=1, validation_set_id=self._validation_set_id), LabelingSelection(amount=2)]
|
|
53
57
|
if self._validation_set_id
|
|
54
58
|
else [LabelingSelection(amount=3)])
|
|
@@ -61,14 +65,15 @@ class ClassificationOrderBuilder:
|
|
|
61
65
|
)
|
|
62
66
|
)
|
|
63
67
|
.referee(referee)
|
|
64
|
-
.media(
|
|
68
|
+
.media(assets, metadata=self._metadata) # type: ignore
|
|
65
69
|
.selections(selection)
|
|
66
70
|
.create(submit=submit, max_workers=max_upload_workers))
|
|
67
71
|
|
|
68
72
|
return order
|
|
69
|
-
|
|
73
|
+
|
|
70
74
|
|
|
71
75
|
class ClassificationMediaBuilder:
|
|
76
|
+
"test"
|
|
72
77
|
def __init__(self, name: str, question: str, options: list[str], openapi_service: OpenAPIService):
|
|
73
78
|
self._openapi_service = openapi_service
|
|
74
79
|
self._name = name
|
|
@@ -85,7 +90,7 @@ class ClassificationMediaBuilder:
|
|
|
85
90
|
if self._media_paths is None:
|
|
86
91
|
raise ValueError("Media paths are required")
|
|
87
92
|
return ClassificationOrderBuilder(self._name, self._question, self._options, self._media_paths, openapi_service=self._openapi_service)
|
|
88
|
-
|
|
93
|
+
|
|
89
94
|
|
|
90
95
|
class ClassificationOptionsBuilder:
|
|
91
96
|
def __init__(self, name: str, question: str, openapi_service: OpenAPIService):
|
|
@@ -1,11 +1,13 @@
|
|
|
1
1
|
from rapidata.service.openapi_service import OpenAPIService
|
|
2
|
-
from rapidata.rapidata_client.metadata
|
|
2
|
+
from rapidata.rapidata_client.metadata import Metadata
|
|
3
3
|
from rapidata.rapidata_client.order.rapidata_order_builder import RapidataOrderBuilder
|
|
4
4
|
from rapidata.rapidata_client.workflow.compare_workflow import CompareWorkflow
|
|
5
5
|
from rapidata.rapidata_client.referee.naive_referee import NaiveReferee
|
|
6
6
|
from rapidata.rapidata_client.selection.validation_selection import ValidationSelection
|
|
7
7
|
from rapidata.rapidata_client.selection.labeling_selection import LabelingSelection
|
|
8
8
|
from rapidata.rapidata_client.selection.base_selection import Selection
|
|
9
|
+
from rapidata.rapidata_client.assets import MultiAsset, MediaAsset
|
|
10
|
+
from typing import Sequence
|
|
9
11
|
|
|
10
12
|
class CompareOrderBuilder:
|
|
11
13
|
def __init__(self, name:str, criteria: str, media_paths: list[list[str]], openapi_service: OpenAPIService):
|
|
@@ -22,7 +24,7 @@ class CompareOrderBuilder:
|
|
|
22
24
|
self._responses_required = responses_required
|
|
23
25
|
return self
|
|
24
26
|
|
|
25
|
-
def metadata(self, metadata:
|
|
27
|
+
def metadata(self, metadata: Sequence[Metadata]) -> 'CompareOrderBuilder':
|
|
26
28
|
"""Set the metadata for the comparison order. Has to be the same shape as the media paths."""
|
|
27
29
|
self._metadata = metadata
|
|
28
30
|
return self
|
|
@@ -37,6 +39,7 @@ class CompareOrderBuilder:
|
|
|
37
39
|
if self._validation_set_id
|
|
38
40
|
else [LabelingSelection(amount=3)])
|
|
39
41
|
|
|
42
|
+
media_paths = [MultiAsset([MediaAsset(path=path) for path in paths]) for paths in self._media_paths]
|
|
40
43
|
order = (self._order_builder
|
|
41
44
|
.workflow(
|
|
42
45
|
CompareWorkflow(
|
|
@@ -44,7 +47,7 @@ class CompareOrderBuilder:
|
|
|
44
47
|
)
|
|
45
48
|
)
|
|
46
49
|
.referee(NaiveReferee(required_guesses=self._responses_required))
|
|
47
|
-
.media(
|
|
50
|
+
.media(media_paths, metadata=self._metadata) # type: ignore
|
|
48
51
|
.selections(selection)
|
|
49
52
|
.create(submit=submit, max_workers=max_upload_workers))
|
|
50
53
|
|