rapidata 2.36.1__py3-none-any.whl → 2.37.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of rapidata might be problematic. Click here for more details.
- rapidata/__init__.py +2 -2
- rapidata/api_client/__init__.py +2 -4
- rapidata/api_client/api/validation_set_api.py +54 -31
- rapidata/api_client/models/__init__.py +2 -4
- rapidata/api_client/models/add_validation_rapid_model.py +17 -2
- rapidata/api_client/models/asset_metadata.py +9 -1
- rapidata/api_client/models/boost_query_result.py +5 -17
- rapidata/api_client/models/campaign_query_result.py +3 -9
- rapidata/api_client/models/classification_metadata.py +12 -1
- rapidata/api_client/models/compare_workflow_config.py +22 -12
- rapidata/api_client/models/compare_workflow_config_model.py +12 -2
- rapidata/api_client/models/compare_workflow_model.py +12 -2
- rapidata/api_client/models/count_metadata.py +12 -1
- rapidata/api_client/models/create_demographic_rapid_model.py +18 -3
- rapidata/api_client/models/create_order_model.py +6 -48
- rapidata/api_client/models/effort_capped_selection.py +2 -11
- rapidata/api_client/models/evaluation_workflow_config.py +13 -3
- rapidata/api_client/models/evaluation_workflow_model.py +13 -3
- rapidata/api_client/models/file_type_metadata.py +11 -6
- rapidata/api_client/models/file_type_metadata_model.py +2 -8
- rapidata/api_client/models/filter.py +5 -23
- rapidata/api_client/models/get_datapoint_by_id_result.py +3 -9
- rapidata/api_client/models/get_rapid_responses_result.py +3 -9
- rapidata/api_client/models/get_recommended_validation_set_result.py +95 -0
- rapidata/api_client/models/get_standing_by_id_result.py +3 -9
- rapidata/api_client/models/get_validation_rapids_result.py +3 -9
- rapidata/api_client/models/get_workflow_progress_result.py +3 -9
- rapidata/api_client/models/get_workflow_results_result.py +3 -9
- rapidata/api_client/models/image_dimension_metadata.py +12 -1
- rapidata/api_client/models/labeling_selection.py +2 -11
- rapidata/api_client/models/location_metadata.py +12 -1
- rapidata/api_client/models/order_model.py +3 -9
- rapidata/api_client/models/original_filename_metadata.py +12 -1
- rapidata/api_client/models/participant_by_benchmark.py +3 -9
- rapidata/api_client/models/prompt_metadata.py +12 -1
- rapidata/api_client/models/rapid_model.py +3 -9
- rapidata/api_client/models/report_model.py +3 -9
- rapidata/api_client/models/response_count_filter.py +2 -8
- rapidata/api_client/models/response_count_user_filter_model.py +2 -8
- rapidata/api_client/models/root_filter.py +3 -12
- rapidata/api_client/models/runs_by_leaderboard_result.py +3 -9
- rapidata/api_client/models/simple_workflow_config.py +13 -3
- rapidata/api_client/models/simple_workflow_config_model.py +11 -3
- rapidata/api_client/models/simple_workflow_model.py +13 -3
- rapidata/api_client/models/sort_criterion.py +3 -9
- rapidata/api_client/models/source_url_metadata.py +12 -1
- rapidata/api_client/models/standing_by_leaderboard.py +3 -9
- rapidata/api_client/models/streams_metadata.py +12 -1
- rapidata/api_client/models/text_metadata.py +12 -1
- rapidata/api_client/models/transcription_metadata.py +9 -1
- rapidata/api_client/models/update_should_alert_model.py +1 -1
- rapidata/api_client/models/validation_set_model.py +12 -24
- rapidata/api_client/models/video_duration_metadata.py +12 -1
- rapidata/api_client/models/workflow_aggregation_step_model.py +3 -12
- rapidata/api_client_README.md +2 -4
- rapidata/rapidata_client/__init__.py +1 -1
- rapidata/rapidata_client/benchmark/participant/_participant.py +5 -5
- rapidata/rapidata_client/benchmark/rapidata_benchmark.py +2 -1
- rapidata/rapidata_client/benchmark/rapidata_benchmark_manager.py +10 -2
- rapidata/rapidata_client/config/__init__.py +1 -1
- rapidata/rapidata_client/config/rapidata_config.py +31 -0
- rapidata/rapidata_client/datapoints/__init__.py +10 -2
- rapidata/rapidata_client/datapoints/{datapoint.py → _datapoint.py} +105 -17
- rapidata/rapidata_client/datapoints/assets/_media_asset.py +80 -68
- rapidata/rapidata_client/datapoints/assets/_sessions.py +3 -3
- rapidata/rapidata_client/datapoints/assets/constants.py +7 -0
- rapidata/rapidata_client/exceptions/failed_upload_exception.py +42 -13
- rapidata/rapidata_client/filter/response_count_filter.py +16 -11
- rapidata/rapidata_client/order/_rapidata_dataset.py +8 -8
- rapidata/rapidata_client/order/_rapidata_order_builder.py +87 -8
- rapidata/rapidata_client/order/rapidata_order_manager.py +28 -4
- rapidata/rapidata_client/rapidata_client.py +6 -0
- rapidata/rapidata_client/selection/__init__.py +1 -1
- rapidata/rapidata_client/selection/effort_selection.py +18 -7
- rapidata/rapidata_client/selection/labeling_selection.py +19 -7
- rapidata/rapidata_client/selection/{retrieval_modes.py → rapidata_retrieval_modes.py} +7 -4
- rapidata/rapidata_client/validation/rapidata_validation_set.py +26 -1
- rapidata/rapidata_client/validation/rapids/rapids.py +46 -19
- rapidata/rapidata_client/validation/validation_set_manager.py +41 -4
- rapidata/rapidata_client/workflow/_base_workflow.py +27 -0
- rapidata/rapidata_client/workflow/_classify_workflow.py +25 -9
- rapidata/rapidata_client/workflow/_compare_workflow.py +11 -0
- rapidata/rapidata_client/workflow/_draw_workflow.py +15 -7
- rapidata/rapidata_client/workflow/_evaluation_workflow.py +8 -1
- rapidata/rapidata_client/workflow/_free_text_workflow.py +11 -0
- rapidata/rapidata_client/workflow/_locate_workflow.py +15 -7
- rapidata/rapidata_client/workflow/_ranking_workflow.py +39 -15
- rapidata/rapidata_client/workflow/_select_words_workflow.py +41 -7
- rapidata/rapidata_client/workflow/_timestamp_workflow.py +17 -8
- rapidata/service/openapi_service.py +1 -1
- {rapidata-2.36.1.dist-info → rapidata-2.37.0.dist-info}/METADATA +1 -1
- {rapidata-2.36.1.dist-info → rapidata-2.37.0.dist-info}/RECORD +94 -92
- rapidata/rapidata_client/config/config.py +0 -33
- {rapidata-2.36.1.dist-info → rapidata-2.37.0.dist-info}/LICENSE +0 -0
- {rapidata-2.36.1.dist-info → rapidata-2.37.0.dist-info}/WHEEL +0 -0
|
@@ -1,6 +1,6 @@
|
|
|
1
1
|
from rapidata.rapidata_client.datapoints.assets import MediaAsset, TextAsset, MultiAsset
|
|
2
2
|
from rapidata.rapidata_client.datapoints.metadata import Metadata
|
|
3
|
-
from typing import
|
|
3
|
+
from typing import Any, cast, Sequence
|
|
4
4
|
from rapidata.api_client.models.add_validation_rapid_model import (
|
|
5
5
|
AddValidationRapidModel,
|
|
6
6
|
)
|
|
@@ -10,23 +10,40 @@ from rapidata.api_client.models.add_validation_rapid_model_payload import (
|
|
|
10
10
|
from rapidata.api_client.models.add_validation_rapid_model_truth import (
|
|
11
11
|
AddValidationRapidModelTruth,
|
|
12
12
|
)
|
|
13
|
-
from rapidata.api_client.models.dataset_dataset_id_datapoints_post_request_metadata_inner import
|
|
13
|
+
from rapidata.api_client.models.dataset_dataset_id_datapoints_post_request_metadata_inner import (
|
|
14
|
+
DatasetDatasetIdDatapointsPostRequestMetadataInner,
|
|
15
|
+
)
|
|
14
16
|
from rapidata.service.openapi_service import OpenAPIService
|
|
15
17
|
|
|
16
18
|
from rapidata.rapidata_client.logging import logger
|
|
19
|
+
from rapidata.rapidata_client.settings._rapidata_setting import RapidataSetting
|
|
17
20
|
|
|
18
21
|
|
|
19
|
-
class Rapid
|
|
20
|
-
def __init__(
|
|
22
|
+
class Rapid:
|
|
23
|
+
def __init__(
|
|
24
|
+
self,
|
|
25
|
+
asset: MediaAsset | TextAsset | MultiAsset,
|
|
26
|
+
payload: Any,
|
|
27
|
+
metadata: Sequence[Metadata] | None = None,
|
|
28
|
+
truth: Any | None = None,
|
|
29
|
+
randomCorrectProbability: float | None = None,
|
|
30
|
+
explanation: str | None = None,
|
|
31
|
+
settings: Sequence[RapidataSetting] | None = None,
|
|
32
|
+
):
|
|
21
33
|
self.asset = asset
|
|
22
34
|
self.metadata = metadata
|
|
23
35
|
self.payload = payload
|
|
24
36
|
self.truth = truth
|
|
25
37
|
self.randomCorrectProbability = randomCorrectProbability
|
|
26
|
-
self.explanation = explanation
|
|
27
|
-
|
|
38
|
+
self.explanation = explanation
|
|
39
|
+
self.settings = settings
|
|
40
|
+
logger.debug(
|
|
41
|
+
f"Created Rapid with asset: {self.asset}, metadata: {self.metadata}, payload: {self.payload}, truth: {self.truth}, randomCorrectProbability: {self.randomCorrectProbability}, explanation: {self.explanation}"
|
|
42
|
+
)
|
|
28
43
|
|
|
29
|
-
def _add_to_validation_set(
|
|
44
|
+
def _add_to_validation_set(
|
|
45
|
+
self, validationSetId: str, openapi_service: OpenAPIService
|
|
46
|
+
) -> None:
|
|
30
47
|
model = self.__to_model()
|
|
31
48
|
assets = self.__convert_to_assets()
|
|
32
49
|
if isinstance(assets[0], TextAsset):
|
|
@@ -35,7 +52,7 @@ class Rapid():
|
|
|
35
52
|
openapi_service.validation_api.validation_set_validation_set_id_rapid_post(
|
|
36
53
|
validation_set_id=validationSetId,
|
|
37
54
|
model=model,
|
|
38
|
-
texts=[asset.text for asset in texts]
|
|
55
|
+
texts=[asset.text for asset in texts],
|
|
39
56
|
)
|
|
40
57
|
|
|
41
58
|
elif isinstance(assets[0], MediaAsset):
|
|
@@ -45,15 +62,14 @@ class Rapid():
|
|
|
45
62
|
validation_set_id=validationSetId,
|
|
46
63
|
model=model,
|
|
47
64
|
files=[asset.to_file() for asset in files if asset.is_local()],
|
|
48
|
-
urls=[asset.path for asset in files if not asset.is_local()]
|
|
65
|
+
urls=[asset.path for asset in files if not asset.is_local()],
|
|
49
66
|
)
|
|
50
|
-
|
|
67
|
+
|
|
51
68
|
else:
|
|
52
69
|
raise TypeError("The asset must be a MediaAsset, TextAsset, or MultiAsset")
|
|
53
|
-
|
|
54
|
-
|
|
70
|
+
|
|
55
71
|
def __convert_to_assets(self) -> list[MediaAsset | TextAsset]:
|
|
56
|
-
assets: list[MediaAsset | TextAsset] = []
|
|
72
|
+
assets: list[MediaAsset | TextAsset] = []
|
|
57
73
|
if isinstance(self.asset, MultiAsset):
|
|
58
74
|
for asset in self.asset.assets:
|
|
59
75
|
if isinstance(asset, MediaAsset):
|
|
@@ -61,7 +77,9 @@ class Rapid():
|
|
|
61
77
|
elif isinstance(asset, TextAsset):
|
|
62
78
|
assets.append(asset)
|
|
63
79
|
else:
|
|
64
|
-
raise TypeError(
|
|
80
|
+
raise TypeError(
|
|
81
|
+
"The asset is a multiasset, but not all assets are MediaAssets or TextAssets"
|
|
82
|
+
)
|
|
65
83
|
|
|
66
84
|
if isinstance(self.asset, TextAsset):
|
|
67
85
|
assets = [self.asset]
|
|
@@ -75,10 +93,19 @@ class Rapid():
|
|
|
75
93
|
return AddValidationRapidModel(
|
|
76
94
|
payload=AddValidationRapidModelPayload(self.payload),
|
|
77
95
|
truth=AddValidationRapidModelTruth(self.truth),
|
|
78
|
-
metadata=
|
|
79
|
-
|
|
80
|
-
|
|
81
|
-
|
|
96
|
+
metadata=(
|
|
97
|
+
[
|
|
98
|
+
DatasetDatasetIdDatapointsPostRequestMetadataInner(meta.to_model())
|
|
99
|
+
for meta in self.metadata
|
|
100
|
+
]
|
|
101
|
+
if self.metadata
|
|
102
|
+
else None
|
|
103
|
+
),
|
|
82
104
|
randomCorrectProbability=self.randomCorrectProbability,
|
|
83
|
-
explanation=self.explanation
|
|
105
|
+
explanation=self.explanation,
|
|
106
|
+
featureFlags=(
|
|
107
|
+
[setting._to_feature_flag() for setting in self.settings]
|
|
108
|
+
if self.settings
|
|
109
|
+
else None
|
|
110
|
+
),
|
|
84
111
|
)
|
|
@@ -18,6 +18,8 @@ from rapidata.api_client.models.page_info import PageInfo
|
|
|
18
18
|
from rapidata.api_client.models.root_filter import RootFilter
|
|
19
19
|
from rapidata.api_client.models.filter import Filter
|
|
20
20
|
from rapidata.api_client.models.sort_criterion import SortCriterion
|
|
21
|
+
from rapidata.api_client.models.sort_direction import SortDirection
|
|
22
|
+
from rapidata.api_client.models.filter_operator import FilterOperator
|
|
21
23
|
|
|
22
24
|
from rapidata.rapidata_client.validation.rapids.box import Box
|
|
23
25
|
|
|
@@ -27,6 +29,11 @@ from rapidata.rapidata_client.logging import (
|
|
|
27
29
|
RapidataOutputManager,
|
|
28
30
|
)
|
|
29
31
|
from tqdm import tqdm
|
|
32
|
+
from rapidata.rapidata_client.workflow import Workflow
|
|
33
|
+
from rapidata.rapidata_client.datapoints._datapoint import Datapoint
|
|
34
|
+
from rapidata.rapidata_client.validation.rapids.rapids import Rapid
|
|
35
|
+
from rapidata.rapidata_client.settings._rapidata_setting import RapidataSetting
|
|
36
|
+
from typing import Sequence
|
|
30
37
|
|
|
31
38
|
|
|
32
39
|
class ValidationSetManager:
|
|
@@ -42,6 +49,25 @@ class ValidationSetManager:
|
|
|
42
49
|
self.rapid = RapidsManager()
|
|
43
50
|
logger.debug("ValidationSetManager initialized")
|
|
44
51
|
|
|
52
|
+
def _create_order_validation_set(
|
|
53
|
+
self,
|
|
54
|
+
workflow: Workflow,
|
|
55
|
+
order_name: str,
|
|
56
|
+
datapoints: list[Datapoint],
|
|
57
|
+
settings: Sequence[RapidataSetting] | None = None,
|
|
58
|
+
) -> RapidataValidationSet:
|
|
59
|
+
rapids: list[Rapid] = []
|
|
60
|
+
for datapoint in datapoints:
|
|
61
|
+
rapids.append(
|
|
62
|
+
Rapid(
|
|
63
|
+
asset=datapoint.asset,
|
|
64
|
+
payload=workflow._to_payload(datapoint),
|
|
65
|
+
metadata=datapoint.metadata,
|
|
66
|
+
settings=settings,
|
|
67
|
+
)
|
|
68
|
+
)
|
|
69
|
+
return self._submit(name=order_name, rapids=rapids, dimensions=[])
|
|
70
|
+
|
|
45
71
|
def create_classification_set(
|
|
46
72
|
self,
|
|
47
73
|
name: str,
|
|
@@ -543,7 +569,10 @@ class ValidationSetManager:
|
|
|
543
569
|
return self._submit(name=name, rapids=rapids, dimensions=dimensions)
|
|
544
570
|
|
|
545
571
|
def _submit(
|
|
546
|
-
self,
|
|
572
|
+
self,
|
|
573
|
+
name: str,
|
|
574
|
+
rapids: list[Rapid],
|
|
575
|
+
dimensions: list[str] | None,
|
|
547
576
|
) -> RapidataValidationSet:
|
|
548
577
|
logger.debug("Creating validation set")
|
|
549
578
|
validation_set_id = (
|
|
@@ -590,7 +619,7 @@ class ValidationSetManager:
|
|
|
590
619
|
managed_print()
|
|
591
620
|
managed_print(
|
|
592
621
|
f"Validation set '{name}' created with ID {validation_set_id}\n",
|
|
593
|
-
f"Now viewable under:
|
|
622
|
+
f"Now viewable under: {validation_set.validation_set_details_page}",
|
|
594
623
|
sep="",
|
|
595
624
|
)
|
|
596
625
|
|
|
@@ -637,10 +666,18 @@ class ValidationSetManager:
|
|
|
637
666
|
QueryModel(
|
|
638
667
|
page=PageInfo(index=1, size=amount),
|
|
639
668
|
filter=RootFilter(
|
|
640
|
-
filters=[
|
|
669
|
+
filters=[
|
|
670
|
+
Filter(
|
|
671
|
+
field="Name",
|
|
672
|
+
operator=FilterOperator.CONTAINS,
|
|
673
|
+
value=name,
|
|
674
|
+
)
|
|
675
|
+
]
|
|
641
676
|
),
|
|
642
677
|
sortCriteria=[
|
|
643
|
-
SortCriterion(
|
|
678
|
+
SortCriterion(
|
|
679
|
+
direction=SortDirection.DESC, propertyName="CreatedAt"
|
|
680
|
+
)
|
|
644
681
|
],
|
|
645
682
|
)
|
|
646
683
|
)
|
|
@@ -5,9 +5,21 @@ from rapidata.api_client.models.simple_workflow_model import SimpleWorkflowModel
|
|
|
5
5
|
from rapidata.api_client.models.evaluation_workflow_model import EvaluationWorkflowModel
|
|
6
6
|
from rapidata.api_client.models.compare_workflow_model import CompareWorkflowModel
|
|
7
7
|
from rapidata.rapidata_client.referee._base_referee import Referee
|
|
8
|
+
from rapidata.api_client import (
|
|
9
|
+
ClassifyPayload,
|
|
10
|
+
ComparePayload,
|
|
11
|
+
LocatePayload,
|
|
12
|
+
ScrubPayload,
|
|
13
|
+
TranscriptionPayload,
|
|
14
|
+
LinePayload,
|
|
15
|
+
FreeTextPayload,
|
|
16
|
+
)
|
|
17
|
+
from rapidata.rapidata_client.datapoints._datapoint import Datapoint
|
|
18
|
+
from rapidata.api_client.models.rapid_modality import RapidModality
|
|
8
19
|
|
|
9
20
|
|
|
10
21
|
class Workflow(ABC):
|
|
22
|
+
modality: RapidModality
|
|
11
23
|
|
|
12
24
|
def __init__(self, type: str):
|
|
13
25
|
self._type = type
|
|
@@ -18,6 +30,21 @@ class Workflow(ABC):
|
|
|
18
30
|
"_t": self._type,
|
|
19
31
|
}
|
|
20
32
|
|
|
33
|
+
@abstractmethod
|
|
34
|
+
def _to_payload(
|
|
35
|
+
self,
|
|
36
|
+
datapoint: Datapoint,
|
|
37
|
+
) -> (
|
|
38
|
+
ClassifyPayload
|
|
39
|
+
| ComparePayload
|
|
40
|
+
| LocatePayload
|
|
41
|
+
| ScrubPayload
|
|
42
|
+
| LinePayload
|
|
43
|
+
| FreeTextPayload
|
|
44
|
+
| TranscriptionPayload
|
|
45
|
+
):
|
|
46
|
+
pass
|
|
47
|
+
|
|
21
48
|
@abstractmethod
|
|
22
49
|
def _to_model(
|
|
23
50
|
self,
|
|
@@ -1,8 +1,15 @@
|
|
|
1
1
|
from typing import Any
|
|
2
|
-
from rapidata.api_client.models.attach_category_rapid_blueprint import
|
|
2
|
+
from rapidata.api_client.models.attach_category_rapid_blueprint import (
|
|
3
|
+
AttachCategoryRapidBlueprint,
|
|
4
|
+
)
|
|
3
5
|
from rapidata.api_client.models.simple_workflow_model import SimpleWorkflowModel
|
|
4
|
-
from rapidata.api_client.models.simple_workflow_model_blueprint import
|
|
6
|
+
from rapidata.api_client.models.simple_workflow_model_blueprint import (
|
|
7
|
+
SimpleWorkflowModelBlueprint,
|
|
8
|
+
)
|
|
5
9
|
from rapidata.rapidata_client.workflow import Workflow
|
|
10
|
+
from rapidata.api_client import ClassifyPayload
|
|
11
|
+
from rapidata.rapidata_client.datapoints._datapoint import Datapoint
|
|
12
|
+
from rapidata.api_client.models.rapid_modality import RapidModality
|
|
6
13
|
|
|
7
14
|
|
|
8
15
|
class ClassifyWorkflow(Workflow):
|
|
@@ -21,29 +28,38 @@ class ClassifyWorkflow(Workflow):
|
|
|
21
28
|
_options (list[str]): The list of classification options.
|
|
22
29
|
"""
|
|
23
30
|
|
|
31
|
+
modality = RapidModality.CLASSIFY
|
|
32
|
+
|
|
24
33
|
def __init__(self, instruction: str, answer_options: list[str]):
|
|
25
34
|
super().__init__(type="SimpleWorkflowConfig")
|
|
26
|
-
self.
|
|
27
|
-
self.
|
|
35
|
+
self._instruction = instruction
|
|
36
|
+
self._answer_options = answer_options
|
|
28
37
|
|
|
29
38
|
def _to_dict(self) -> dict[str, Any]:
|
|
30
39
|
return {
|
|
31
40
|
**super()._to_dict(),
|
|
32
41
|
"blueprint": {
|
|
33
42
|
"_t": "ClassifyBlueprint",
|
|
34
|
-
"title": self.
|
|
35
|
-
"possibleCategories": self.
|
|
36
|
-
}
|
|
43
|
+
"title": self._instruction,
|
|
44
|
+
"possibleCategories": self._answer_options,
|
|
45
|
+
},
|
|
37
46
|
}
|
|
38
47
|
|
|
39
48
|
def _to_model(self) -> SimpleWorkflowModel:
|
|
40
49
|
blueprint = AttachCategoryRapidBlueprint(
|
|
41
50
|
_t="ClassifyBlueprint",
|
|
42
|
-
title=self.
|
|
43
|
-
possibleCategories=self.
|
|
51
|
+
title=self._instruction,
|
|
52
|
+
possibleCategories=self._answer_options,
|
|
44
53
|
)
|
|
45
54
|
|
|
46
55
|
return SimpleWorkflowModel(
|
|
47
56
|
_t="SimpleWorkflow",
|
|
48
57
|
blueprint=SimpleWorkflowModelBlueprint(blueprint),
|
|
49
58
|
)
|
|
59
|
+
|
|
60
|
+
def _to_payload(self, datapoint: Datapoint) -> ClassifyPayload:
|
|
61
|
+
return ClassifyPayload(
|
|
62
|
+
_t="ClassifyPayload",
|
|
63
|
+
possibleCategories=self._answer_options,
|
|
64
|
+
title=self._instruction,
|
|
65
|
+
)
|
|
@@ -5,6 +5,9 @@ from rapidata.api_client.models.simple_workflow_model_blueprint import (
|
|
|
5
5
|
from rapidata.rapidata_client.workflow import Workflow
|
|
6
6
|
from rapidata.api_client.models.compare_rapid_blueprint import CompareRapidBlueprint
|
|
7
7
|
from rapidata.api_client.models.simple_workflow_model import SimpleWorkflowModel
|
|
8
|
+
from rapidata.api_client import ComparePayload
|
|
9
|
+
from rapidata.rapidata_client.datapoints._datapoint import Datapoint
|
|
10
|
+
from rapidata.api_client.models.rapid_modality import RapidModality
|
|
8
11
|
|
|
9
12
|
|
|
10
13
|
class CompareWorkflow(Workflow):
|
|
@@ -21,6 +24,8 @@ class CompareWorkflow(Workflow):
|
|
|
21
24
|
instruction (str): The instruction to be used for comparison.
|
|
22
25
|
"""
|
|
23
26
|
|
|
27
|
+
modality = RapidModality.COMPARE
|
|
28
|
+
|
|
24
29
|
def __init__(self, instruction: str, a_b_names: list[str] | None = None):
|
|
25
30
|
super().__init__(type="CompareWorkflowConfig")
|
|
26
31
|
self._instruction = instruction
|
|
@@ -44,3 +49,9 @@ class CompareWorkflow(Workflow):
|
|
|
44
49
|
_t="SimpleWorkflow",
|
|
45
50
|
blueprint=SimpleWorkflowModelBlueprint(blueprint),
|
|
46
51
|
)
|
|
52
|
+
|
|
53
|
+
def _to_payload(self, datapoint: Datapoint) -> ComparePayload:
|
|
54
|
+
return ComparePayload(
|
|
55
|
+
_t="ComparePayload",
|
|
56
|
+
criteria=self._instruction,
|
|
57
|
+
)
|
|
@@ -1,22 +1,30 @@
|
|
|
1
1
|
from rapidata.api_client.models.simple_workflow_model import SimpleWorkflowModel
|
|
2
|
-
from rapidata.api_client.models.simple_workflow_model_blueprint import
|
|
2
|
+
from rapidata.api_client.models.simple_workflow_model_blueprint import (
|
|
3
|
+
SimpleWorkflowModelBlueprint,
|
|
4
|
+
)
|
|
3
5
|
from rapidata.api_client.models.line_rapid_blueprint import LineRapidBlueprint
|
|
4
6
|
from rapidata.rapidata_client.workflow._base_workflow import Workflow
|
|
7
|
+
from rapidata.api_client import LinePayload
|
|
8
|
+
from rapidata.rapidata_client.datapoints._datapoint import Datapoint
|
|
9
|
+
from rapidata.api_client.models.rapid_modality import RapidModality
|
|
5
10
|
|
|
6
11
|
|
|
7
12
|
class DrawWorkflow(Workflow):
|
|
13
|
+
modality = RapidModality.LINE
|
|
8
14
|
|
|
9
15
|
def __init__(self, target: str):
|
|
10
16
|
super().__init__(type="SimpleWorkflowConfig")
|
|
11
17
|
self._target = target
|
|
12
18
|
|
|
13
19
|
def _to_model(self) -> SimpleWorkflowModel:
|
|
14
|
-
blueprint = LineRapidBlueprint(
|
|
15
|
-
_t="LineBlueprint",
|
|
16
|
-
target=self._target
|
|
17
|
-
)
|
|
20
|
+
blueprint = LineRapidBlueprint(_t="LineBlueprint", target=self._target)
|
|
18
21
|
|
|
19
22
|
return SimpleWorkflowModel(
|
|
20
|
-
_t="SimpleWorkflow",
|
|
21
|
-
|
|
23
|
+
_t="SimpleWorkflow", blueprint=SimpleWorkflowModelBlueprint(blueprint)
|
|
24
|
+
)
|
|
25
|
+
|
|
26
|
+
def _to_payload(self, datapoint: Datapoint) -> LinePayload:
|
|
27
|
+
return LinePayload(
|
|
28
|
+
_t="LinePayload",
|
|
29
|
+
target=self._target,
|
|
22
30
|
)
|
|
@@ -1,5 +1,7 @@
|
|
|
1
1
|
from rapidata.api_client.models.evaluation_workflow_model import EvaluationWorkflowModel
|
|
2
2
|
from rapidata.rapidata_client.workflow._base_workflow import Workflow
|
|
3
|
+
from rapidata.rapidata_client.datapoints._datapoint import Datapoint
|
|
4
|
+
from rapidata.api_client.models.rapid_modality import RapidModality
|
|
3
5
|
|
|
4
6
|
|
|
5
7
|
class EvaluationWorkflow(Workflow):
|
|
@@ -7,12 +9,14 @@ class EvaluationWorkflow(Workflow):
|
|
|
7
9
|
A workflow to run evaluation orders.
|
|
8
10
|
|
|
9
11
|
This is used internally only and should not be necessary to be used by clients.
|
|
10
|
-
|
|
12
|
+
|
|
11
13
|
Args:
|
|
12
14
|
validation_set_id (str): a source for the tasks that will be sent to the user
|
|
13
15
|
should_accept_incorrect (bool): indicates if the user should get feedback on their answers if they answer wrong. If set to true the user will not notice that he was tested.
|
|
14
16
|
"""
|
|
15
17
|
|
|
18
|
+
modality = RapidModality.NONE
|
|
19
|
+
|
|
16
20
|
def __init__(self, validation_set_id: str, should_accept_incorrect: bool):
|
|
17
21
|
super().__init__("EvaluationWorkflow")
|
|
18
22
|
self.validation_set_id = validation_set_id
|
|
@@ -24,3 +28,6 @@ class EvaluationWorkflow(Workflow):
|
|
|
24
28
|
validationSetId=self.validation_set_id,
|
|
25
29
|
shouldAcceptIncorrect=self.should_accept_incorrect,
|
|
26
30
|
)
|
|
31
|
+
|
|
32
|
+
def _to_payload(self, datapoint: Datapoint):
|
|
33
|
+
raise NotImplementedError("EvaluationWorkflow does not have a payload")
|
|
@@ -5,6 +5,9 @@ from rapidata.api_client.models.simple_workflow_model_blueprint import (
|
|
|
5
5
|
)
|
|
6
6
|
from rapidata.rapidata_client.workflow import Workflow
|
|
7
7
|
from rapidata.api_client.models.free_text_rapid_blueprint import FreeTextRapidBlueprint
|
|
8
|
+
from rapidata.api_client import FreeTextPayload
|
|
9
|
+
from rapidata.rapidata_client.datapoints._datapoint import Datapoint
|
|
10
|
+
from rapidata.api_client.models.rapid_modality import RapidModality
|
|
8
11
|
|
|
9
12
|
|
|
10
13
|
class FreeTextWorkflow(Workflow):
|
|
@@ -23,6 +26,8 @@ class FreeTextWorkflow(Workflow):
|
|
|
23
26
|
Should always specify that the LLM should respond with 'not spam' or 'spam'.
|
|
24
27
|
"""
|
|
25
28
|
|
|
29
|
+
modality = RapidModality.FREETEXT
|
|
30
|
+
|
|
26
31
|
def __init__(self, instruction: str, validation_system_prompt: str | None = None):
|
|
27
32
|
super().__init__(type="SimpleWorkflowConfig")
|
|
28
33
|
self._instruction = instruction
|
|
@@ -49,3 +54,9 @@ class FreeTextWorkflow(Workflow):
|
|
|
49
54
|
_t="SimpleWorkflow",
|
|
50
55
|
blueprint=SimpleWorkflowModelBlueprint(blueprint),
|
|
51
56
|
)
|
|
57
|
+
|
|
58
|
+
def _to_payload(self, datapoint: Datapoint) -> FreeTextPayload:
|
|
59
|
+
return FreeTextPayload(
|
|
60
|
+
_t="FreeTextPayload",
|
|
61
|
+
question=self._instruction,
|
|
62
|
+
)
|
|
@@ -1,22 +1,30 @@
|
|
|
1
1
|
from rapidata.api_client.models.simple_workflow_model import SimpleWorkflowModel
|
|
2
|
-
from rapidata.api_client.models.simple_workflow_model_blueprint import
|
|
2
|
+
from rapidata.api_client.models.simple_workflow_model_blueprint import (
|
|
3
|
+
SimpleWorkflowModelBlueprint,
|
|
4
|
+
)
|
|
3
5
|
from rapidata.api_client.models.locate_rapid_blueprint import LocateRapidBlueprint
|
|
4
6
|
from rapidata.rapidata_client.workflow._base_workflow import Workflow
|
|
7
|
+
from rapidata.api_client import LocatePayload
|
|
8
|
+
from rapidata.rapidata_client.datapoints._datapoint import Datapoint
|
|
9
|
+
from rapidata.api_client.models.rapid_modality import RapidModality
|
|
5
10
|
|
|
6
11
|
|
|
7
12
|
class LocateWorkflow(Workflow):
|
|
13
|
+
modality = RapidModality.LOCATE
|
|
8
14
|
|
|
9
15
|
def __init__(self, target: str):
|
|
10
16
|
super().__init__(type="SimpleWorkflowConfig")
|
|
11
17
|
self._target = target
|
|
12
18
|
|
|
13
19
|
def _to_model(self) -> SimpleWorkflowModel:
|
|
14
|
-
blueprint = LocateRapidBlueprint(
|
|
15
|
-
_t="LocateBlueprint",
|
|
16
|
-
target=self._target
|
|
17
|
-
)
|
|
20
|
+
blueprint = LocateRapidBlueprint(_t="LocateBlueprint", target=self._target)
|
|
18
21
|
|
|
19
22
|
return SimpleWorkflowModel(
|
|
20
|
-
_t="SimpleWorkflow",
|
|
21
|
-
|
|
23
|
+
_t="SimpleWorkflow", blueprint=SimpleWorkflowModelBlueprint(blueprint)
|
|
24
|
+
)
|
|
25
|
+
|
|
26
|
+
def _to_payload(self, datapoint: Datapoint) -> LocatePayload:
|
|
27
|
+
return LocatePayload(
|
|
28
|
+
_t="LocatePayload",
|
|
29
|
+
target=self._target,
|
|
22
30
|
)
|
|
@@ -1,30 +1,48 @@
|
|
|
1
|
-
from rapidata.api_client import
|
|
1
|
+
from rapidata.api_client import (
|
|
2
|
+
CompareWorkflowModelPairMakerConfig,
|
|
3
|
+
OnlinePairMakerConfigModel,
|
|
4
|
+
EloConfigModel,
|
|
5
|
+
)
|
|
2
6
|
from rapidata.api_client.models.compare_workflow_model import CompareWorkflowModel
|
|
3
7
|
from rapidata.rapidata_client.workflow._base_workflow import Workflow
|
|
4
8
|
from rapidata.rapidata_client.datapoints.metadata import PromptMetadata
|
|
5
|
-
from rapidata.api_client.models.dataset_dataset_id_datapoints_post_request_metadata_inner import
|
|
9
|
+
from rapidata.api_client.models.dataset_dataset_id_datapoints_post_request_metadata_inner import (
|
|
10
|
+
DatasetDatasetIdDatapointsPostRequestMetadataInner,
|
|
11
|
+
)
|
|
12
|
+
from rapidata.api_client import ComparePayload
|
|
13
|
+
from rapidata.rapidata_client.datapoints._datapoint import Datapoint
|
|
14
|
+
from rapidata.api_client.models.rapid_modality import RapidModality
|
|
6
15
|
|
|
7
16
|
|
|
8
17
|
class RankingWorkflow(Workflow):
|
|
9
|
-
|
|
10
|
-
|
|
11
|
-
|
|
12
|
-
|
|
13
|
-
|
|
14
|
-
|
|
15
|
-
|
|
16
|
-
|
|
17
|
-
|
|
18
|
+
modality = RapidModality.COMPARE
|
|
19
|
+
|
|
20
|
+
def __init__(
|
|
21
|
+
self,
|
|
22
|
+
criteria: str,
|
|
23
|
+
total_comparison_budget: int,
|
|
24
|
+
random_comparisons_ratio,
|
|
25
|
+
elo_start: int = 1200,
|
|
26
|
+
elo_k_factor: int = 40,
|
|
27
|
+
elo_scaling_factor: int = 400,
|
|
28
|
+
context: str | None = None,
|
|
29
|
+
):
|
|
18
30
|
super().__init__(type="CompareWorkflowConfig")
|
|
19
31
|
|
|
20
|
-
self.context =
|
|
21
|
-
|
|
22
|
-
|
|
32
|
+
self.context = (
|
|
33
|
+
[
|
|
34
|
+
DatasetDatasetIdDatapointsPostRequestMetadataInner(
|
|
35
|
+
PromptMetadata(context).to_model()
|
|
36
|
+
)
|
|
37
|
+
]
|
|
38
|
+
if context
|
|
39
|
+
else None
|
|
40
|
+
)
|
|
23
41
|
|
|
24
42
|
self.criteria = criteria
|
|
25
43
|
self.pair_maker_config = CompareWorkflowModelPairMakerConfig(
|
|
26
44
|
OnlinePairMakerConfigModel(
|
|
27
|
-
_t=
|
|
45
|
+
_t="OnlinePairMaker",
|
|
28
46
|
totalComparisonBudget=total_comparison_budget,
|
|
29
47
|
randomMatchesRatio=random_comparisons_ratio,
|
|
30
48
|
)
|
|
@@ -45,3 +63,9 @@ class RankingWorkflow(Workflow):
|
|
|
45
63
|
pairMakerConfig=self.pair_maker_config,
|
|
46
64
|
metadata=self.context,
|
|
47
65
|
)
|
|
66
|
+
|
|
67
|
+
def _to_payload(self, datapoint: Datapoint) -> ComparePayload:
|
|
68
|
+
return ComparePayload(
|
|
69
|
+
_t="ComparePayload",
|
|
70
|
+
criteria=self.criteria,
|
|
71
|
+
)
|
|
@@ -1,14 +1,24 @@
|
|
|
1
1
|
from rapidata.api_client.models.simple_workflow_model import SimpleWorkflowModel
|
|
2
|
-
from rapidata.api_client.models.simple_workflow_model_blueprint import
|
|
3
|
-
|
|
2
|
+
from rapidata.api_client.models.simple_workflow_model_blueprint import (
|
|
3
|
+
SimpleWorkflowModelBlueprint,
|
|
4
|
+
)
|
|
5
|
+
from rapidata.api_client.models.transcription_rapid_blueprint import (
|
|
6
|
+
TranscriptionRapidBlueprint,
|
|
7
|
+
)
|
|
4
8
|
from rapidata.rapidata_client.workflow._base_workflow import Workflow
|
|
9
|
+
from rapidata.api_client import TranscriptionPayload, TranscriptionWord
|
|
10
|
+
from rapidata.rapidata_client.datapoints._datapoint import Datapoint
|
|
11
|
+
from rapidata.rapidata_client.datapoints.metadata._select_words_metadata import (
|
|
12
|
+
SelectWordsMetadata,
|
|
13
|
+
)
|
|
14
|
+
from rapidata.api_client.models.rapid_modality import RapidModality
|
|
5
15
|
|
|
6
16
|
|
|
7
17
|
class SelectWordsWorkflow(Workflow):
|
|
8
18
|
"""
|
|
9
19
|
A workflow for select words tasks.
|
|
10
20
|
|
|
11
|
-
This class represents a select words workflow
|
|
21
|
+
This class represents a select words workflow
|
|
12
22
|
where datapoints have a sentence attached to them where words can be selected.
|
|
13
23
|
|
|
14
24
|
Attributes:
|
|
@@ -18,17 +28,41 @@ class SelectWordsWorkflow(Workflow):
|
|
|
18
28
|
instruction (str): The instruction to be provided for the select words task.
|
|
19
29
|
"""
|
|
20
30
|
|
|
31
|
+
modality = RapidModality.TRANSCRIPTION
|
|
32
|
+
|
|
21
33
|
def __init__(self, instruction: str):
|
|
22
34
|
super().__init__(type="SimpleWorkflowConfig")
|
|
23
35
|
self._instruction = instruction
|
|
24
36
|
|
|
25
37
|
def _to_model(self) -> SimpleWorkflowModel:
|
|
26
38
|
blueprint = TranscriptionRapidBlueprint(
|
|
27
|
-
_t="TranscriptionBlueprint",
|
|
28
|
-
title=self._instruction
|
|
39
|
+
_t="TranscriptionBlueprint", title=self._instruction
|
|
29
40
|
)
|
|
30
41
|
|
|
31
42
|
return SimpleWorkflowModel(
|
|
32
|
-
_t="SimpleWorkflow",
|
|
33
|
-
|
|
43
|
+
_t="SimpleWorkflow", blueprint=SimpleWorkflowModelBlueprint(blueprint)
|
|
44
|
+
)
|
|
45
|
+
|
|
46
|
+
def _to_payload(self, datapoint: Datapoint) -> TranscriptionPayload:
|
|
47
|
+
assert (
|
|
48
|
+
datapoint.metadata is not None
|
|
49
|
+
), "SelectWordsWorkflow requires a metadata datapoint"
|
|
50
|
+
|
|
51
|
+
assert any(
|
|
52
|
+
isinstance(metadata, SelectWordsMetadata) for metadata in datapoint.metadata
|
|
53
|
+
), "SelectWordsWorkflow requires a SelectWordsMetadata datapoint"
|
|
54
|
+
|
|
55
|
+
select_words_metadata = next(
|
|
56
|
+
metadata
|
|
57
|
+
for metadata in datapoint.metadata
|
|
58
|
+
if isinstance(metadata, SelectWordsMetadata)
|
|
59
|
+
)
|
|
60
|
+
|
|
61
|
+
return TranscriptionPayload(
|
|
62
|
+
_t="TranscriptionPayload",
|
|
63
|
+
title=self._instruction,
|
|
64
|
+
transcription=[
|
|
65
|
+
TranscriptionWord(word=word, wordIndex=i)
|
|
66
|
+
for i, word in enumerate(select_words_metadata.select_words.split())
|
|
67
|
+
],
|
|
34
68
|
)
|