rapidata-2.35.1-py3-none-any.whl → rapidata-2.35.3-py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of rapidata might be problematic. Click here for more details.
- rapidata/__init__.py +2 -1
- rapidata/api_client/api/leaderboard_api.py +3 -3
- rapidata/api_client_README.md +1 -1
- rapidata/rapidata_client/__init__.py +5 -13
- rapidata/rapidata_client/api/rapidata_exception.py +61 -32
- rapidata/rapidata_client/benchmark/participant/_participant.py +45 -26
- rapidata/rapidata_client/benchmark/rapidata_benchmark_manager.py +73 -30
- rapidata/rapidata_client/config/__init__.py +1 -0
- rapidata/rapidata_client/config/config.py +33 -0
- rapidata/rapidata_client/datapoints/assets/_multi_asset.py +7 -7
- rapidata/rapidata_client/datapoints/assets/_sessions.py +13 -8
- rapidata/rapidata_client/order/_rapidata_dataset.py +166 -115
- rapidata/rapidata_client/order/_rapidata_order_builder.py +54 -22
- rapidata/rapidata_client/order/rapidata_order.py +109 -48
- rapidata/rapidata_client/rapidata_client.py +19 -14
- rapidata/rapidata_client/validation/rapidata_validation_set.py +13 -7
- rapidata/rapidata_client/validation/validation_set_manager.py +167 -98
- rapidata/service/credential_manager.py +13 -13
- rapidata/service/openapi_service.py +22 -13
- {rapidata-2.35.1.dist-info → rapidata-2.35.3.dist-info}/METADATA +1 -1
- {rapidata-2.35.1.dist-info → rapidata-2.35.3.dist-info}/RECORD +23 -21
- {rapidata-2.35.1.dist-info → rapidata-2.35.3.dist-info}/LICENSE +0 -0
- {rapidata-2.35.1.dist-info → rapidata-2.35.3.dist-info}/WHEEL +0 -0
rapidata/__init__.py
CHANGED
|
rapidata/api_client/api/leaderboard_api.py
CHANGED
|
@@ -3117,7 +3117,7 @@ class LeaderboardApi:
|
|
|
3117
3117
|
_headers: Optional[Dict[StrictStr, Any]] = None,
|
|
3118
3118
|
_host_index: Annotated[StrictInt, Field(ge=0, le=0)] = 0,
|
|
3119
3119
|
) -> None:
|
|
3120
|
-
"""Updates the
|
|
3120
|
+
"""Updates the response config of a leaderboard.
|
|
3121
3121
|
|
|
3122
3122
|
|
|
3123
3123
|
:param leaderboard_id: (required)
|
|
@@ -3187,7 +3187,7 @@ class LeaderboardApi:
|
|
|
3187
3187
|
_headers: Optional[Dict[StrictStr, Any]] = None,
|
|
3188
3188
|
_host_index: Annotated[StrictInt, Field(ge=0, le=0)] = 0,
|
|
3189
3189
|
) -> ApiResponse[None]:
|
|
3190
|
-
"""Updates the
|
|
3190
|
+
"""Updates the response config of a leaderboard.
|
|
3191
3191
|
|
|
3192
3192
|
|
|
3193
3193
|
:param leaderboard_id: (required)
|
|
@@ -3257,7 +3257,7 @@ class LeaderboardApi:
|
|
|
3257
3257
|
_headers: Optional[Dict[StrictStr, Any]] = None,
|
|
3258
3258
|
_host_index: Annotated[StrictInt, Field(ge=0, le=0)] = 0,
|
|
3259
3259
|
) -> RESTResponseType:
|
|
3260
|
-
"""Updates the
|
|
3260
|
+
"""Updates the response config of a leaderboard.
|
|
3261
3261
|
|
|
3262
3262
|
|
|
3263
3263
|
:param leaderboard_id: (required)
|
rapidata/api_client_README.md
CHANGED
|
@@ -135,7 +135,7 @@ Class | Method | HTTP request | Description
|
|
|
135
135
|
*LeaderboardApi* | [**leaderboard_leaderboard_id_participants_post**](rapidata/api_client/docs/LeaderboardApi.md#leaderboard_leaderboard_id_participants_post) | **POST** /leaderboard/{leaderboardId}/participants | Creates a participant in a leaderboard.
|
|
136
136
|
*LeaderboardApi* | [**leaderboard_leaderboard_id_prompts_get**](rapidata/api_client/docs/LeaderboardApi.md#leaderboard_leaderboard_id_prompts_get) | **GET** /leaderboard/{leaderboardId}/prompts | returns the paged prompts of a leaderboard by its ID.
|
|
137
137
|
*LeaderboardApi* | [**leaderboard_leaderboard_id_prompts_post**](rapidata/api_client/docs/LeaderboardApi.md#leaderboard_leaderboard_id_prompts_post) | **POST** /leaderboard/{leaderboardId}/prompts | adds a new prompt to a leaderboard.
|
|
138
|
-
*LeaderboardApi* | [**leaderboard_leaderboard_id_response_config_put**](rapidata/api_client/docs/LeaderboardApi.md#leaderboard_leaderboard_id_response_config_put) | **PUT** /leaderboard/{leaderboardId}/response-config | Updates the
|
|
138
|
+
*LeaderboardApi* | [**leaderboard_leaderboard_id_response_config_put**](rapidata/api_client/docs/LeaderboardApi.md#leaderboard_leaderboard_id_response_config_put) | **PUT** /leaderboard/{leaderboardId}/response-config | Updates the response config of a leaderboard.
|
|
139
139
|
*LeaderboardApi* | [**leaderboard_leaderboard_id_runs_get**](rapidata/api_client/docs/LeaderboardApi.md#leaderboard_leaderboard_id_runs_get) | **GET** /leaderboard/{leaderboardId}/runs | Gets the runs related to a leaderboard
|
|
140
140
|
*LeaderboardApi* | [**leaderboard_leaderboard_id_standings_get**](rapidata/api_client/docs/LeaderboardApi.md#leaderboard_leaderboard_id_standings_get) | **GET** /leaderboard/{leaderboardId}/standings | queries all the participants connected to leaderboard by its ID.
|
|
141
141
|
*LeaderboardApi* | [**leaderboard_post**](rapidata/api_client/docs/LeaderboardApi.md#leaderboard_post) | **POST** /leaderboard | Creates a new leaderboard with the specified name and criteria.
|
|
rapidata/rapidata_client/__init__.py
CHANGED
|
@@ -16,14 +16,9 @@ from .datapoints.metadata import (
|
|
|
16
16
|
PromptMetadata,
|
|
17
17
|
SelectWordsMetadata,
|
|
18
18
|
)
|
|
19
|
-
from .datapoints.assets import
|
|
20
|
-
MediaAsset,
|
|
21
|
-
TextAsset,
|
|
22
|
-
MultiAsset,
|
|
23
|
-
RapidataDataTypes
|
|
24
|
-
)
|
|
19
|
+
from .datapoints.assets import MediaAsset, TextAsset, MultiAsset, RapidataDataTypes
|
|
25
20
|
from .settings import (
|
|
26
|
-
RapidataSettings,
|
|
21
|
+
RapidataSettings,
|
|
27
22
|
TranslationBehaviourOptions,
|
|
28
23
|
AlertOnFastResponse,
|
|
29
24
|
TranslationBehaviour,
|
|
@@ -32,7 +27,7 @@ from .settings import (
|
|
|
32
27
|
PlayVideoUntilTheEnd,
|
|
33
28
|
CustomSetting,
|
|
34
29
|
AllowNeitherBoth,
|
|
35
|
-
|
|
30
|
+
)
|
|
36
31
|
from .country_codes import CountryCodes
|
|
37
32
|
from .filter import (
|
|
38
33
|
CountryFilter,
|
|
@@ -49,11 +44,8 @@ from .filter import (
|
|
|
49
44
|
ResponseCountFilter,
|
|
50
45
|
)
|
|
51
46
|
|
|
52
|
-
from .logging import
|
|
53
|
-
configure_logger,
|
|
54
|
-
logger,
|
|
55
|
-
RapidataOutputManager
|
|
56
|
-
)
|
|
47
|
+
from .logging import configure_logger, logger, RapidataOutputManager
|
|
57
48
|
|
|
58
49
|
from .validation import Box
|
|
59
50
|
from .exceptions import FailedUploadException
|
|
51
|
+
from .config import rapidata_config
|
|
rapidata/rapidata_client/api/rapidata_exception.py
CHANGED
|
@@ -2,106 +2,135 @@ from typing import Optional, Any
|
|
|
2
2
|
from rapidata.api_client.api_client import ApiClient, rest, ApiResponse, ApiResponseT
|
|
3
3
|
from rapidata.api_client.exceptions import ApiException
|
|
4
4
|
import json
|
|
5
|
+
import threading
|
|
6
|
+
from contextlib import contextmanager
|
|
5
7
|
from rapidata.rapidata_client.logging import logger
|
|
6
8
|
|
|
9
|
+
# Thread-local storage for controlling error logging
|
|
10
|
+
_thread_local = threading.local()
|
|
11
|
+
|
|
12
|
+
|
|
13
|
+
@contextmanager
|
|
14
|
+
def suppress_rapidata_error_logging():
|
|
15
|
+
"""Context manager to suppress error logging for RapidataApiClient calls."""
|
|
16
|
+
old_value = getattr(_thread_local, "suppress_error_logging", False)
|
|
17
|
+
_thread_local.suppress_error_logging = True
|
|
18
|
+
try:
|
|
19
|
+
yield
|
|
20
|
+
finally:
|
|
21
|
+
_thread_local.suppress_error_logging = old_value
|
|
22
|
+
|
|
23
|
+
|
|
24
|
+
def _should_suppress_error_logging() -> bool:
|
|
25
|
+
"""Check if error logging should be suppressed for the current thread."""
|
|
26
|
+
return getattr(_thread_local, "suppress_error_logging", False)
|
|
27
|
+
|
|
28
|
+
|
|
7
29
|
class RapidataError(Exception):
|
|
8
30
|
"""Custom error class for Rapidata API errors."""
|
|
9
|
-
|
|
31
|
+
|
|
10
32
|
def __init__(
|
|
11
|
-
self,
|
|
12
|
-
status_code: Optional[int] = None,
|
|
13
|
-
message: str | None = None,
|
|
33
|
+
self,
|
|
34
|
+
status_code: Optional[int] = None,
|
|
35
|
+
message: str | None = None,
|
|
14
36
|
original_exception: Exception | None = None,
|
|
15
|
-
details: Any = None
|
|
37
|
+
details: Any = None,
|
|
16
38
|
):
|
|
17
39
|
self.status_code = status_code
|
|
18
40
|
self.message = message
|
|
19
41
|
self.original_exception = original_exception
|
|
20
42
|
self.details = details
|
|
21
|
-
|
|
43
|
+
|
|
22
44
|
# Create a nice error message
|
|
23
45
|
error_msg = "Rapidata API Error"
|
|
24
46
|
if status_code:
|
|
25
47
|
error_msg += f" ({status_code})"
|
|
26
48
|
if message:
|
|
27
49
|
error_msg += f": {message}"
|
|
28
|
-
|
|
50
|
+
|
|
29
51
|
super().__init__(error_msg)
|
|
30
52
|
|
|
31
53
|
def __str__(self):
|
|
32
|
-
"""Return a string representation of the error."""
|
|
54
|
+
"""Return a string representation of the error."""
|
|
33
55
|
# Extract information from message if available
|
|
34
56
|
title = None
|
|
35
57
|
errors = None
|
|
36
58
|
trace_id = None
|
|
37
|
-
|
|
59
|
+
|
|
38
60
|
# Try to extract from details if available and is a dict
|
|
39
61
|
if self.details and isinstance(self.details, dict):
|
|
40
|
-
title = self.details.get(
|
|
41
|
-
errors = self.details.get(
|
|
42
|
-
trace_id = self.details.get(
|
|
43
|
-
|
|
62
|
+
title = self.details.get("title")
|
|
63
|
+
errors = self.details.get("errors")
|
|
64
|
+
trace_id = self.details.get("traceId")
|
|
65
|
+
|
|
44
66
|
# Build the error string
|
|
45
67
|
error_parts = []
|
|
46
|
-
|
|
68
|
+
|
|
47
69
|
# Main error line
|
|
48
70
|
if title:
|
|
49
71
|
error_parts.append(f"{title}")
|
|
50
72
|
else:
|
|
51
73
|
error_parts.append(f"{self.message or 'Unknown error'}")
|
|
52
|
-
|
|
74
|
+
|
|
53
75
|
# Reasons
|
|
54
76
|
if errors:
|
|
55
77
|
if isinstance(errors, dict):
|
|
56
78
|
error_parts.append(f"Reasons: {json.dumps({'errors': errors})}")
|
|
57
79
|
else:
|
|
58
80
|
error_parts.append(f"Reasons: {errors}")
|
|
59
|
-
|
|
81
|
+
|
|
60
82
|
# Trace ID
|
|
61
83
|
if trace_id:
|
|
62
84
|
error_parts.append(f"Trace Id: {trace_id}")
|
|
63
85
|
else:
|
|
64
86
|
error_parts.append("Trace Id: N/A")
|
|
65
|
-
|
|
87
|
+
|
|
66
88
|
return "\n".join(error_parts)
|
|
67
89
|
|
|
90
|
+
|
|
68
91
|
class RapidataApiClient(ApiClient):
|
|
69
92
|
"""Custom API client that wraps errors in RapidataError."""
|
|
70
93
|
|
|
71
94
|
def response_deserialize(
|
|
72
95
|
self,
|
|
73
96
|
response_data: rest.RESTResponse,
|
|
74
|
-
response_types_map: Optional[dict[str, ApiResponseT]] = None
|
|
97
|
+
response_types_map: Optional[dict[str, ApiResponseT]] = None,
|
|
75
98
|
) -> ApiResponse[ApiResponseT]:
|
|
76
99
|
"""Override the response_deserialize method to catch and convert exceptions."""
|
|
77
100
|
try:
|
|
78
101
|
return super().response_deserialize(response_data, response_types_map)
|
|
79
102
|
except ApiException as e:
|
|
80
|
-
status_code = getattr(e,
|
|
103
|
+
status_code = getattr(e, "status", None)
|
|
81
104
|
message = str(e)
|
|
82
105
|
details = None
|
|
83
|
-
|
|
106
|
+
|
|
84
107
|
# Extract more detailed error message from response body if available
|
|
85
|
-
if hasattr(e,
|
|
108
|
+
if hasattr(e, "body") and e.body:
|
|
86
109
|
try:
|
|
87
110
|
body_json = json.loads(e.body)
|
|
88
111
|
if isinstance(body_json, dict):
|
|
89
|
-
if
|
|
90
|
-
message = body_json[
|
|
91
|
-
elif
|
|
92
|
-
message = body_json[
|
|
93
|
-
|
|
112
|
+
if "message" in body_json:
|
|
113
|
+
message = body_json["message"]
|
|
114
|
+
elif "error" in body_json:
|
|
115
|
+
message = body_json["error"]
|
|
116
|
+
|
|
94
117
|
# Store the full error details for debugging
|
|
95
118
|
details = body_json
|
|
96
119
|
except (json.JSONDecodeError, AttributeError):
|
|
97
120
|
# If we can't parse the body as JSON, use the original message
|
|
98
121
|
pass
|
|
99
|
-
|
|
100
|
-
error_formatted =
|
|
101
|
-
status_code=status_code,
|
|
102
|
-
message=message,
|
|
122
|
+
|
|
123
|
+
error_formatted = RapidataError(
|
|
124
|
+
status_code=status_code,
|
|
125
|
+
message=message,
|
|
103
126
|
original_exception=e,
|
|
104
|
-
details=details
|
|
127
|
+
details=details,
|
|
105
128
|
)
|
|
106
|
-
|
|
129
|
+
|
|
130
|
+
# Only log error if not suppressed
|
|
131
|
+
if not _should_suppress_error_logging():
|
|
132
|
+
logger.error("Error: %s", error_formatted)
|
|
133
|
+
else:
|
|
134
|
+
logger.debug("Suppressed Error: %s", error_formatted)
|
|
135
|
+
|
|
107
136
|
raise error_formatted from None
|
|
rapidata/rapidata_client/benchmark/participant/_participant.py
CHANGED
|
@@ -1,4 +1,5 @@
|
|
|
1
1
|
from concurrent.futures import ThreadPoolExecutor, as_completed
|
|
2
|
+
import time
|
|
2
3
|
from tqdm import tqdm
|
|
3
4
|
|
|
4
5
|
from rapidata.rapidata_client.datapoints.assets import MediaAsset
|
|
@@ -6,6 +7,10 @@ from rapidata.rapidata_client.logging import logger
|
|
|
6
7
|
from rapidata.rapidata_client.logging.output_manager import RapidataOutputManager
|
|
7
8
|
from rapidata.api_client.models.create_sample_model import CreateSampleModel
|
|
8
9
|
from rapidata.service.openapi_service import OpenAPIService
|
|
10
|
+
from rapidata.rapidata_client.config.config import rapidata_config
|
|
11
|
+
from rapidata.rapidata_client.api.rapidata_exception import (
|
|
12
|
+
suppress_rapidata_error_logging,
|
|
13
|
+
)
|
|
9
14
|
|
|
10
15
|
|
|
11
16
|
class BenchmarkParticipant:
|
|
@@ -21,11 +26,11 @@ class BenchmarkParticipant:
|
|
|
21
26
|
) -> tuple[MediaAsset | None, MediaAsset | None]:
|
|
22
27
|
"""
|
|
23
28
|
Process single sample upload with retry logic and error tracking.
|
|
24
|
-
|
|
29
|
+
|
|
25
30
|
Args:
|
|
26
31
|
asset: MediaAsset to upload
|
|
27
32
|
identifier: Identifier for the sample
|
|
28
|
-
|
|
33
|
+
|
|
29
34
|
Returns:
|
|
30
35
|
tuple[MediaAsset | None, MediaAsset | None]: (successful_asset, failed_asset)
|
|
31
36
|
"""
|
|
@@ -37,20 +42,30 @@ class BenchmarkParticipant:
|
|
|
37
42
|
urls = [asset.path]
|
|
38
43
|
|
|
39
44
|
last_exception = None
|
|
40
|
-
|
|
41
|
-
|
|
42
|
-
|
|
43
|
-
|
|
44
|
-
|
|
45
|
-
|
|
46
|
-
|
|
47
|
-
|
|
48
|
-
|
|
49
|
-
|
|
50
|
-
|
|
51
|
-
|
|
52
|
-
|
|
53
|
-
|
|
45
|
+
for attempt in range(rapidata_config.upload_max_retries):
|
|
46
|
+
try:
|
|
47
|
+
with suppress_rapidata_error_logging():
|
|
48
|
+
self.__openapi_service.participant_api.participant_participant_id_sample_post(
|
|
49
|
+
participant_id=self.id,
|
|
50
|
+
model=CreateSampleModel(identifier=identifier),
|
|
51
|
+
files=files,
|
|
52
|
+
urls=urls,
|
|
53
|
+
)
|
|
54
|
+
|
|
55
|
+
return asset, None
|
|
56
|
+
|
|
57
|
+
except Exception as e:
|
|
58
|
+
last_exception = e
|
|
59
|
+
if attempt < rapidata_config.upload_max_retries - 1:
|
|
60
|
+
# Exponential backoff: wait 1s, then 2s, then 4s
|
|
61
|
+
retry_delay = 2**attempt
|
|
62
|
+
time.sleep(retry_delay)
|
|
63
|
+
logger.debug("Error: %s", str(last_exception))
|
|
64
|
+
logger.debug(
|
|
65
|
+
"Retrying %s of %s...",
|
|
66
|
+
attempt + 1,
|
|
67
|
+
rapidata_config.upload_max_retries,
|
|
68
|
+
)
|
|
54
69
|
|
|
55
70
|
logger.error(f"Upload failed for {identifier}. Error: {str(last_exception)}")
|
|
56
71
|
return None, asset
|
|
@@ -59,24 +74,24 @@ class BenchmarkParticipant:
|
|
|
59
74
|
self,
|
|
60
75
|
assets: list[MediaAsset],
|
|
61
76
|
identifiers: list[str],
|
|
62
|
-
max_workers: int = 10,
|
|
63
77
|
) -> tuple[list[MediaAsset], list[MediaAsset]]:
|
|
64
78
|
"""
|
|
65
79
|
Upload samples concurrently with proper error handling and progress tracking.
|
|
66
|
-
|
|
80
|
+
|
|
67
81
|
Args:
|
|
68
82
|
assets: List of MediaAsset objects to upload
|
|
69
83
|
identifiers: List of identifiers matching the assets
|
|
70
|
-
|
|
71
|
-
|
|
84
|
+
|
|
72
85
|
Returns:
|
|
73
86
|
tuple[list[str], list[str]]: Lists of successful and failed identifiers
|
|
74
87
|
"""
|
|
75
88
|
successful_uploads: list[MediaAsset] = []
|
|
76
89
|
failed_uploads: list[MediaAsset] = []
|
|
77
90
|
total_uploads = len(assets)
|
|
78
|
-
|
|
79
|
-
with ThreadPoolExecutor(
|
|
91
|
+
|
|
92
|
+
with ThreadPoolExecutor(
|
|
93
|
+
max_workers=rapidata_config.max_upload_workers
|
|
94
|
+
) as executor:
|
|
80
95
|
futures = [
|
|
81
96
|
executor.submit(
|
|
82
97
|
self._process_single_sample_upload,
|
|
@@ -85,8 +100,12 @@ class BenchmarkParticipant:
|
|
|
85
100
|
)
|
|
86
101
|
for asset, identifier in zip(assets, identifiers)
|
|
87
102
|
]
|
|
88
|
-
|
|
89
|
-
with tqdm(
|
|
103
|
+
|
|
104
|
+
with tqdm(
|
|
105
|
+
total=total_uploads,
|
|
106
|
+
desc="Uploading media",
|
|
107
|
+
disable=RapidataOutputManager.silent_mode,
|
|
108
|
+
) as pbar:
|
|
90
109
|
for future in as_completed(futures):
|
|
91
110
|
try:
|
|
92
111
|
successful_id, failed_id = future.result()
|
|
@@ -96,7 +115,7 @@ class BenchmarkParticipant:
|
|
|
96
115
|
failed_uploads.append(failed_id)
|
|
97
116
|
except Exception as e:
|
|
98
117
|
logger.error(f"Future execution failed: {str(e)}")
|
|
99
|
-
|
|
118
|
+
|
|
100
119
|
pbar.update(1)
|
|
101
|
-
|
|
120
|
+
|
|
102
121
|
return successful_uploads, failed_uploads
|
|
rapidata/rapidata_client/benchmark/rapidata_benchmark_manager.py
CHANGED
|
@@ -8,6 +8,7 @@ from rapidata.api_client.models.root_filter import RootFilter
|
|
|
8
8
|
from rapidata.api_client.models.filter import Filter
|
|
9
9
|
from rapidata.api_client.models.sort_criterion import SortCriterion
|
|
10
10
|
|
|
11
|
+
|
|
11
12
|
class RapidataBenchmarkManager:
|
|
12
13
|
"""
|
|
13
14
|
A manager for benchmarks.
|
|
@@ -19,16 +20,18 @@ class RapidataBenchmarkManager:
|
|
|
19
20
|
Args:
|
|
20
21
|
openapi_service: The OpenAPIService instance for API interaction.
|
|
21
22
|
"""
|
|
23
|
+
|
|
22
24
|
def __init__(self, openapi_service: OpenAPIService):
|
|
23
25
|
self.__openapi_service = openapi_service
|
|
24
26
|
|
|
25
|
-
def create_new_benchmark(
|
|
26
|
-
|
|
27
|
-
|
|
28
|
-
|
|
29
|
-
|
|
30
|
-
|
|
31
|
-
|
|
27
|
+
def create_new_benchmark(
|
|
28
|
+
self,
|
|
29
|
+
name: str,
|
|
30
|
+
identifiers: list[str],
|
|
31
|
+
prompts: Optional[list[str | None]] = None,
|
|
32
|
+
prompt_assets: Optional[list[str | None]] = None,
|
|
33
|
+
tags: Optional[list[list[str] | None]] = None,
|
|
34
|
+
) -> RapidataBenchmark:
|
|
32
35
|
"""
|
|
33
36
|
Creates a new benchmark with the given name, identifiers, prompts, and media assets.
|
|
34
37
|
Everything is matched up by the indexes of the lists.
|
|
@@ -41,31 +44,54 @@ class RapidataBenchmarkManager:
|
|
|
41
44
|
prompts: The prompts that will be registered for the benchmark.
|
|
42
45
|
prompt_assets: The prompt assets that will be registered for the benchmark.
|
|
43
46
|
tags: The tags that will be associated with the prompts to use for filtering the leaderboard results. They will NOT be shown to the users.
|
|
47
|
+
|
|
48
|
+
Example:
|
|
49
|
+
```python
|
|
50
|
+
name = "Example Benchmark"
|
|
51
|
+
identifiers = ["id1", "id2", "id3"]
|
|
52
|
+
prompts = ["prompt 1", "prompt 2", "prompt 3"]
|
|
53
|
+
prompt_assets = ["https://assets.rapidata.ai/prompt_1.jpg", "https://assets.rapidata.ai/prompt_2.jpg", "https://assets.rapidata.ai/prompt_3.jpg"]
|
|
54
|
+
tags = [["tag1", "tag2"], ["tag2"], ["tag2", "tag3"]]
|
|
55
|
+
|
|
56
|
+
benchmark = create_new_benchmark(name=name, identifiers=identifiers, prompts=prompts, prompt_assets=prompt_assets, tags=tags)
|
|
57
|
+
```
|
|
44
58
|
"""
|
|
45
59
|
if not isinstance(name, str):
|
|
46
60
|
raise ValueError("Name must be a string.")
|
|
47
|
-
|
|
48
|
-
if prompts and (
|
|
61
|
+
|
|
62
|
+
if prompts and (
|
|
63
|
+
not isinstance(prompts, list)
|
|
64
|
+
or not all(isinstance(prompt, str) or prompt is None for prompt in prompts)
|
|
65
|
+
):
|
|
49
66
|
raise ValueError("Prompts must be a list of strings or None.")
|
|
50
|
-
|
|
51
|
-
if prompt_assets and (
|
|
67
|
+
|
|
68
|
+
if prompt_assets and (
|
|
69
|
+
not isinstance(prompt_assets, list)
|
|
70
|
+
or not all(
|
|
71
|
+
isinstance(asset, str) or asset is None for asset in prompt_assets
|
|
72
|
+
)
|
|
73
|
+
):
|
|
52
74
|
raise ValueError("Media assets must be a list of strings or None.")
|
|
53
|
-
|
|
54
|
-
if not isinstance(identifiers, list) or not all(
|
|
75
|
+
|
|
76
|
+
if not isinstance(identifiers, list) or not all(
|
|
77
|
+
isinstance(identifier, str) for identifier in identifiers
|
|
78
|
+
):
|
|
55
79
|
raise ValueError("Identifiers must be a list of strings.")
|
|
56
|
-
|
|
80
|
+
|
|
57
81
|
if prompts and len(identifiers) != len(prompts):
|
|
58
82
|
raise ValueError("Identifiers and prompts must have the same length.")
|
|
59
|
-
|
|
83
|
+
|
|
60
84
|
if prompt_assets and len(identifiers) != len(prompt_assets):
|
|
61
85
|
raise ValueError("Identifiers and media assets must have the same length.")
|
|
62
|
-
|
|
86
|
+
|
|
63
87
|
if not prompts and not prompt_assets:
|
|
64
|
-
raise ValueError(
|
|
65
|
-
|
|
88
|
+
raise ValueError(
|
|
89
|
+
"At least one of prompts or media assets must be provided."
|
|
90
|
+
)
|
|
91
|
+
|
|
66
92
|
if len(set(identifiers)) != len(identifiers):
|
|
67
93
|
raise ValueError("Identifiers must be unique.")
|
|
68
|
-
|
|
94
|
+
|
|
69
95
|
if tags and len(identifiers) != len(tags):
|
|
70
96
|
raise ValueError("Identifiers and tags must have the same length.")
|
|
71
97
|
|
|
@@ -78,32 +104,49 @@ class RapidataBenchmarkManager:
|
|
|
78
104
|
benchmark = RapidataBenchmark(name, benchmark_result.id, self.__openapi_service)
|
|
79
105
|
|
|
80
106
|
prompts_list = prompts if prompts is not None else [None] * len(identifiers)
|
|
81
|
-
media_assets_list =
|
|
107
|
+
media_assets_list = (
|
|
108
|
+
prompt_assets if prompt_assets is not None else [None] * len(identifiers)
|
|
109
|
+
)
|
|
82
110
|
tags_list = tags if tags is not None else [None] * len(identifiers)
|
|
83
111
|
|
|
84
|
-
for identifier, prompt, asset, tag in zip(
|
|
112
|
+
for identifier, prompt, asset, tag in zip(
|
|
113
|
+
identifiers, prompts_list, media_assets_list, tags_list
|
|
114
|
+
):
|
|
85
115
|
benchmark.add_prompt(identifier, prompt, asset, tag)
|
|
86
116
|
|
|
87
117
|
return benchmark
|
|
88
|
-
|
|
118
|
+
|
|
89
119
|
def get_benchmark_by_id(self, id: str) -> RapidataBenchmark:
|
|
90
120
|
"""
|
|
91
121
|
Returns a benchmark by its ID.
|
|
92
122
|
"""
|
|
93
|
-
benchmark_result =
|
|
94
|
-
|
|
123
|
+
benchmark_result = (
|
|
124
|
+
self.__openapi_service.benchmark_api.benchmark_benchmark_id_get(
|
|
125
|
+
benchmark_id=id
|
|
126
|
+
)
|
|
95
127
|
)
|
|
96
|
-
return RapidataBenchmark(
|
|
97
|
-
|
|
98
|
-
|
|
128
|
+
return RapidataBenchmark(
|
|
129
|
+
benchmark_result.name, benchmark_result.id, self.__openapi_service
|
|
130
|
+
)
|
|
131
|
+
|
|
132
|
+
def find_benchmarks(
|
|
133
|
+
self, name: str = "", amount: int = 10
|
|
134
|
+
) -> list[RapidataBenchmark]:
|
|
99
135
|
"""
|
|
100
136
|
Returns a list of benchmarks by their name.
|
|
101
137
|
"""
|
|
102
138
|
benchmark_result = self.__openapi_service.benchmark_api.benchmarks_get(
|
|
103
139
|
QueryModel(
|
|
104
140
|
page=PageInfo(index=1, size=amount),
|
|
105
|
-
filter=RootFilter(
|
|
106
|
-
|
|
141
|
+
filter=RootFilter(
|
|
142
|
+
filters=[Filter(field="Name", operator="Contains", value=name)]
|
|
143
|
+
),
|
|
144
|
+
sortCriteria=[
|
|
145
|
+
SortCriterion(direction="Desc", propertyName="CreatedAt")
|
|
146
|
+
],
|
|
107
147
|
)
|
|
108
148
|
)
|
|
109
|
-
return [
|
|
149
|
+
return [
|
|
150
|
+
RapidataBenchmark(benchmark.name, benchmark.id, self.__openapi_service)
|
|
151
|
+
for benchmark in benchmark_result.items
|
|
152
|
+
]
|
|
rapidata/rapidata_client/config/__init__.py
ADDED
|
@@ -0,0 +1 @@
|
|
|
1
|
+
from .config import rapidata_config
|
|
rapidata/rapidata_client/config/config.py
ADDED
|
@@ -0,0 +1,33 @@
|
|
|
1
|
+
class _RapidataConfig:
|
|
2
|
+
def __init__(self):
|
|
3
|
+
self.__maxUploadWorkers: int = 10
|
|
4
|
+
self.__uploadMaxRetries: int = 3
|
|
5
|
+
|
|
6
|
+
@property
|
|
7
|
+
def max_upload_workers(self) -> int:
|
|
8
|
+
return self.__maxUploadWorkers
|
|
9
|
+
|
|
10
|
+
@max_upload_workers.setter
|
|
11
|
+
def max_upload_workers(self, value: int) -> None:
|
|
12
|
+
if value < 1:
|
|
13
|
+
raise ValueError("max_upload_workers must be greater than 0")
|
|
14
|
+
self.__maxUploadWorkers = value
|
|
15
|
+
|
|
16
|
+
@property
|
|
17
|
+
def upload_max_retries(self) -> int:
|
|
18
|
+
return self.__uploadMaxRetries
|
|
19
|
+
|
|
20
|
+
@upload_max_retries.setter
|
|
21
|
+
def upload_max_retries(self, value: int) -> None:
|
|
22
|
+
if value < 1:
|
|
23
|
+
raise ValueError("upload_max_retries must be greater than 0")
|
|
24
|
+
self.__uploadMaxRetries = value
|
|
25
|
+
|
|
26
|
+
def __str__(self) -> str:
|
|
27
|
+
return f"RapidataConfig(max_upload_workers={self.__maxUploadWorkers}, upload_max_retries={self.__uploadMaxRetries})"
|
|
28
|
+
|
|
29
|
+
def __repr__(self) -> str:
|
|
30
|
+
return self.__str__()
|
|
31
|
+
|
|
32
|
+
|
|
33
|
+
rapidata_config = _RapidataConfig()
|
|
rapidata/rapidata_client/datapoints/assets/_multi_asset.py
CHANGED
|
@@ -5,7 +5,7 @@ Defines the MultiAsset class for handling multiple BaseAsset instances.
|
|
|
5
5
|
|
|
6
6
|
from rapidata.rapidata_client.datapoints.assets._base_asset import BaseAsset
|
|
7
7
|
from rapidata.rapidata_client.datapoints.assets import MediaAsset, TextAsset
|
|
8
|
-
from typing import Iterator, Sequence
|
|
8
|
+
from typing import Iterator, Sequence
|
|
9
9
|
|
|
10
10
|
|
|
11
11
|
class MultiAsset(BaseAsset):
|
|
@@ -26,16 +26,16 @@ class MultiAsset(BaseAsset):
|
|
|
26
26
|
"""
|
|
27
27
|
if len(assets) != 2:
|
|
28
28
|
raise ValueError("Assets must come in pairs for comparison tasks.")
|
|
29
|
-
|
|
29
|
+
|
|
30
30
|
for asset in assets:
|
|
31
31
|
if not isinstance(asset, (TextAsset, MediaAsset)):
|
|
32
|
-
raise TypeError("All assets must be a TextAsset or MediaAsset.")
|
|
33
|
-
|
|
32
|
+
raise TypeError("All assets must be a TextAsset or MediaAsset.")
|
|
33
|
+
|
|
34
34
|
if not all(isinstance(asset, type(assets[0])) for asset in assets):
|
|
35
35
|
raise ValueError("All assets must be of the same type.")
|
|
36
|
-
|
|
36
|
+
|
|
37
37
|
self.assets = assets
|
|
38
|
-
|
|
38
|
+
|
|
39
39
|
def __len__(self) -> int:
|
|
40
40
|
"""
|
|
41
41
|
Get the number of assets in the MultiAsset.
|
|
@@ -56,6 +56,6 @@ class MultiAsset(BaseAsset):
|
|
|
56
56
|
|
|
57
57
|
def __str__(self) -> str:
|
|
58
58
|
return f"MultiAsset(assets={self.assets})"
|
|
59
|
-
|
|
59
|
+
|
|
60
60
|
def __repr__(self) -> str:
|
|
61
61
|
return f"MultiAsset(assets={self.assets})"
|