rapidata 2.21.4__py3-none-any.whl → 2.22.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of rapidata might be problematic; see the registry's advisory page for this release for more details.

Files changed (62)
  1. rapidata/__init__.py +5 -0
  2. rapidata/api_client/__init__.py +8 -4
  3. rapidata/api_client/api/__init__.py +1 -0
  4. rapidata/api_client/api/evaluation_workflow_api.py +372 -0
  5. rapidata/api_client/api/identity_api.py +268 -0
  6. rapidata/api_client/api/rapid_api.py +353 -1987
  7. rapidata/api_client/api/simple_workflow_api.py +6 -6
  8. rapidata/api_client/models/__init__.py +7 -4
  9. rapidata/api_client/models/add_campaign_model.py +25 -1
  10. rapidata/api_client/models/add_validation_rapid_model_truth.py +24 -10
  11. rapidata/api_client/models/compare_result.py +2 -0
  12. rapidata/api_client/models/create_order_model.py +43 -2
  13. rapidata/api_client/models/evaluation_workflow_model1.py +115 -0
  14. rapidata/api_client/models/filter.py +2 -2
  15. rapidata/api_client/models/get_validation_rapids_result.py +11 -4
  16. rapidata/api_client/models/get_validation_rapids_result_truth.py +24 -10
  17. rapidata/api_client/models/get_workflow_by_id_result_workflow.py +23 -9
  18. rapidata/api_client/models/get_workflow_results_result.py +118 -0
  19. rapidata/api_client/models/get_workflow_results_result_paged_result.py +105 -0
  20. rapidata/api_client/models/google_one_tap_login_model.py +87 -0
  21. rapidata/api_client/models/labeling_selection.py +22 -3
  22. rapidata/api_client/models/logic_operator.py +1 -0
  23. rapidata/api_client/models/rapid_response.py +3 -1
  24. rapidata/api_client/models/retrieval_mode.py +38 -0
  25. rapidata/api_client/models/root_filter.py +2 -2
  26. rapidata/api_client/models/skip_truth.py +94 -0
  27. rapidata/api_client/models/sticky_state.py +38 -0
  28. rapidata/api_client/models/update_validation_rapid_model.py +11 -4
  29. rapidata/api_client/models/update_validation_rapid_model_truth.py +24 -10
  30. rapidata/api_client/rest.py +1 -0
  31. rapidata/api_client_README.md +10 -11
  32. rapidata/rapidata_client/__init__.py +7 -0
  33. rapidata/rapidata_client/api/rapidata_exception.py +5 -3
  34. rapidata/rapidata_client/assets/__init__.py +1 -0
  35. rapidata/rapidata_client/assets/_media_asset.py +16 -10
  36. rapidata/rapidata_client/assets/_multi_asset.py +6 -0
  37. rapidata/rapidata_client/assets/_sessions.py +35 -0
  38. rapidata/rapidata_client/assets/_text_asset.py +6 -0
  39. rapidata/rapidata_client/demographic/demographic_manager.py +2 -35
  40. rapidata/rapidata_client/logging/__init__.py +2 -0
  41. rapidata/rapidata_client/logging/logger.py +47 -0
  42. rapidata/rapidata_client/logging/output_manager.py +16 -0
  43. rapidata/rapidata_client/order/_rapidata_dataset.py +11 -15
  44. rapidata/rapidata_client/order/_rapidata_order_builder.py +15 -2
  45. rapidata/rapidata_client/order/rapidata_order.py +23 -14
  46. rapidata/rapidata_client/order/rapidata_order_manager.py +4 -2
  47. rapidata/rapidata_client/order/rapidata_results.py +2 -1
  48. rapidata/rapidata_client/rapidata_client.py +6 -1
  49. rapidata/rapidata_client/selection/__init__.py +1 -0
  50. rapidata/rapidata_client/selection/labeling_selection.py +8 -2
  51. rapidata/rapidata_client/selection/retrieval_modes.py +9 -0
  52. rapidata/rapidata_client/settings/alert_on_fast_response.py +2 -1
  53. rapidata/rapidata_client/settings/free_text_minimum_characters.py +2 -1
  54. rapidata/rapidata_client/validation/rapidata_validation_set.py +4 -34
  55. rapidata/rapidata_client/validation/rapids/rapids.py +6 -7
  56. rapidata/rapidata_client/validation/validation_set_manager.py +39 -36
  57. rapidata/service/credential_manager.py +22 -30
  58. rapidata/service/openapi_service.py +11 -0
  59. {rapidata-2.21.4.dist-info → rapidata-2.22.0.dist-info}/METADATA +2 -1
  60. {rapidata-2.21.4.dist-info → rapidata-2.22.0.dist-info}/RECORD +62 -49
  61. {rapidata-2.21.4.dist-info → rapidata-2.22.0.dist-info}/WHEEL +1 -1
  62. {rapidata-2.21.4.dist-info → rapidata-2.22.0.dist-info}/LICENSE +0 -0
@@ -29,6 +29,8 @@ from rapidata.rapidata_client.assets import MediaAsset, TextAsset, MultiAsset, B
29
29
 
30
30
  from typing import Optional, cast, Sequence
31
31
 
32
+ from rapidata.rapidata_client.logging import logger, managed_print
33
+
32
34
 
33
35
  class RapidataOrderBuilder:
34
36
  """Builder object for creating Rapidata orders.
@@ -73,7 +75,7 @@ class RapidataOrderBuilder:
73
75
  raise ValueError("You must provide a workflow to create an order.")
74
76
 
75
77
  if self.__referee is None:
76
- print("No referee provided, using default NaiveReferee.")
78
+ managed_print("No referee provided, using default NaiveReferee.")
77
79
  self.__referee = NaiveReferee()
78
80
 
79
81
  return CreateOrderModel(
@@ -113,6 +115,7 @@ class RapidataOrderBuilder:
113
115
  RapidataOrder: The created RapidataOrder instance.
114
116
  """
115
117
  order_model = self._to_model()
118
+ logger.debug(f"Creating order with model: {order_model}")
116
119
  if isinstance(
117
120
  self.__workflow, CompareWorkflow
118
121
  ): # Temporary fix; will be handled by backend in the future
@@ -125,12 +128,17 @@ class RapidataOrderBuilder:
125
128
  )
126
129
 
127
130
  self.order_id = str(result.order_id)
131
+ logger.debug(f"Order created with ID: {self.order_id}")
128
132
 
129
133
  self.__dataset = (
130
134
  RapidataDataset(result.dataset_id, self.__openapi_service)
131
135
  if result.dataset_id
132
136
  else None
133
137
  )
138
+ if self.__dataset:
139
+ logger.debug(f"Dataset created with ID: {self.__dataset.dataset_id}")
140
+ else:
141
+ logger.warning("No dataset created for this order.")
134
142
 
135
143
  order = RapidataOrder(
136
144
  order_id=self.order_id,
@@ -138,6 +146,9 @@ class RapidataOrderBuilder:
138
146
  name=self._name,
139
147
  )
140
148
 
149
+ logger.debug(f"Order created: {order}")
150
+ logger.debug("Adding media to the order.")
151
+
141
152
  if all(isinstance(item, MediaAsset) for item in self.__assets) and self.__dataset:
142
153
  assets = cast(list[MediaAsset], self.__assets)
143
154
  self.__dataset._add_media_from_paths(assets, self.__metadata, max_upload_workers)
@@ -183,6 +194,8 @@ class RapidataOrderBuilder:
183
194
  "Media paths must all be of the same type: MediaAsset, TextAsset, or MultiAsset."
184
195
  )
185
196
 
197
+ logger.debug("Media added to the order.")
198
+ logger.debug("Setting order to preview")
186
199
  self.__openapi_service.order_api.order_order_id_preview_post(self.order_id)
187
200
 
188
201
  return order
@@ -291,7 +304,7 @@ class RapidataOrderBuilder:
291
304
  raise TypeError("Filters must be of type Filter.")
292
305
 
293
306
  if len(self.__user_filters) > 0:
294
- print("Overwriting existing user filters.")
307
+ managed_print("Overwriting existing user filters.")
295
308
 
296
309
  self.__user_filters = filters
297
310
  return self
@@ -17,6 +17,7 @@ from rapidata.api_client.models.preliminary_download_model import PreliminaryDow
17
17
  from rapidata.api_client.models.workflow_artifact_model import WorkflowArtifactModel
18
18
  from rapidata.rapidata_client.order.rapidata_results import RapidataResults
19
19
  from rapidata.service.openapi_service import OpenAPIService
20
+ from rapidata.rapidata_client.logging import logger, managed_print
20
21
 
21
22
 
22
23
  class RapidataOrder:
@@ -47,23 +48,29 @@ class RapidataOrder:
47
48
  self._max_retries = 10
48
49
  self._retry_delay = 2
49
50
  self.order_details_page = f"https://app.{self.__openapi_service.environment}/order/detail/{self.order_id}"
51
+ logger.debug("RapidataOrder initialized")
50
52
 
51
- def run(self, print_link: bool = True) -> "RapidataOrder":
53
+ def run(self) -> "RapidataOrder":
52
54
  """Runs the order to start collecting responses."""
53
- self.__openapi_service.order_api.order_submit_post(self.order_id)
54
- if print_link:
55
- print(f"Order '{self.name}' is now viewable under: {self.order_details_page}")
55
+ logger.info(f"Starting order '{self}'")
56
+ self.__openapi_service.order_api.order_order_id_submit_post(self.order_id)
57
+ logger.debug(f"Order '{self}' has been started.")
58
+ managed_print(f"Order '{self.name}' is now viewable under: {self.order_details_page}")
56
59
  return self
57
60
 
58
61
  def pause(self) -> None:
59
62
  """Pauses the order."""
63
+ logger.info(f"Pausing order '{self}'")
60
64
  self.__openapi_service.order_api.order_pause_post(self.order_id)
61
- print(f"Order '{self}' has been paused.")
65
+ logger.debug(f"Order '{self}' has been paused.")
66
+ managed_print(f"Order '{self}' has been paused.")
62
67
 
63
68
  def unpause(self) -> None:
64
69
  """Unpauses/resumes the order."""
70
+ logger.info(f"Unpausing order '{self}'")
65
71
  self.__openapi_service.order_api.order_resume_post(self.order_id)
66
- print(f"Order '{self}' has been unpaused.")
72
+ logger.debug(f"Order '{self}' has been unpaused.")
73
+ managed_print(f"Order '{self}' has been unpaused.")
67
74
 
68
75
  def get_status(self) -> str:
69
76
  """
@@ -95,12 +102,12 @@ class RapidataOrder:
95
102
  raise Exception("Order has not been started yet. Please start it first.")
96
103
 
97
104
  while self.get_status() == OrderState.SUBMITTED:
98
- print(f"Order '{self.name}' is submitted and being reviewed. Standby...", end="\r")
105
+ managed_print(f"Order '{self}' is submitted and being reviewed. Standby...", end="\r")
99
106
  sleep(1)
100
107
 
101
108
  if self.get_status() == OrderState.MANUALREVIEW:
102
109
  raise Exception(
103
- f"Order '{self.name}' is in manual review. It might take some time to start. "
110
+ f"Order '{self}' is in manual review. It might take some time to start. "
104
111
  "To speed up the process, contact support (info@rapidata.ai).\n"
105
112
  "Once started, run this method again to display the progress bar."
106
113
  )
@@ -145,12 +152,12 @@ class RapidataOrder:
145
152
  Note that preliminary results are not final and may not contain all the datapoints & responses. Only the onese that are already available.
146
153
  This will throw an exception if there are no responses available yet.
147
154
  """
148
-
155
+ logger.info(f"Getting results for order '{self}'...")
149
156
  if preliminary_results and self.get_status() not in [OrderState.COMPLETED]:
150
157
  return self.__get_preliminary_results()
151
158
 
152
159
  elif preliminary_results and self.get_status() in [OrderState.COMPLETED]:
153
- print("Order is already completed. Returning final results.")
160
+ managed_print("Order is already completed. Returning final results.")
154
161
 
155
162
  while self.get_status() not in [OrderState.COMPLETED, OrderState.PAUSED, OrderState.MANUALREVIEW, OrderState.FAILED]:
156
163
  sleep(5)
@@ -167,10 +174,11 @@ class RapidataOrder:
167
174
  Raises:
168
175
  Exception: If the order is not in processing state.
169
176
  """
177
+ logger.info("Opening order details page in browser...")
170
178
  could_open_browser = webbrowser.open(self.order_details_page)
171
179
  if not could_open_browser:
172
180
  encoded_url = urllib.parse.quote(self.order_details_page, safe="%/:=&?~#+!$,;'@()*[]")
173
- print(Fore.RED + f'Please open this URL in your browser: "{encoded_url}"' + Fore.RESET)
181
+ managed_print(Fore.RED + f'Please open this URL in your browser: "{encoded_url}"' + Fore.RESET)
174
182
 
175
183
  def preview(self) -> None:
176
184
  """
@@ -178,13 +186,14 @@ class RapidataOrder:
178
186
 
179
187
  Raises:
180
188
  Exception: If the order is not in processing state.
181
- """
189
+ """
190
+ logger.info("Opening order preview in browser...")
182
191
  campaign_id = self.__get_campaign_id()
183
192
  auth_url = f"https://app.{self.__openapi_service.environment}/order/detail/{self.order_id}/preview?campaignId={campaign_id}"
184
193
  could_open_browser = webbrowser.open(auth_url)
185
194
  if not could_open_browser:
186
195
  encoded_url = urllib.parse.quote(auth_url, safe="%/:=&?~#+!$,;'@()*[]")
187
- print(Fore.RED + f'Please open this URL in your browser: "{encoded_url}"' + Fore.RESET)
196
+ managed_print(Fore.RED + f'Please open this URL in your browser: "{encoded_url}"' + Fore.RESET)
188
197
 
189
198
  def __get_pipeline_id(self) -> str:
190
199
  """Internal method to fetch and cache the pipeline ID."""
@@ -249,7 +258,7 @@ class RapidataOrder:
249
258
  raise Exception(f"Failed to get preliminary results: {str(e)}") from e
250
259
 
251
260
  def __str__(self) -> str:
252
- return f"name: '{self.name}' order id: {self.order_id}"
261
+ return f"RapidataOrder(name='{self.name}', order_id='{self.order_id}')"
253
262
 
254
263
  def __repr__(self) -> str:
255
264
  return f"RapidataOrder(name='{self.name}', order_id='{self.order_id}')"
@@ -34,6 +34,7 @@ from rapidata.api_client.models.page_info import PageInfo
34
34
  from rapidata.api_client.models.root_filter import RootFilter
35
35
  from rapidata.api_client.models.filter import Filter
36
36
  from rapidata.api_client.models.sort_criterion import SortCriterion
37
+ from rapidata.rapidata_client.logging import logger
37
38
 
38
39
  from tqdm import tqdm
39
40
 
@@ -53,6 +54,7 @@ class RapidataOrderManager:
53
54
  self.settings = RapidataSettings
54
55
  self.selections = RapidataSelections
55
56
  self.__priority = 50
57
+ logger.debug("RapidataOrderManager initialized")
56
58
 
57
59
  def __get_selections(self, validation_set_id: str | None, labeling_amount=3) -> Sequence[RapidataSelection]:
58
60
  if validation_set_id:
@@ -85,7 +87,7 @@ class RapidataOrderManager:
85
87
  raise ValueError("You can only use contexts or sentences, not both")
86
88
 
87
89
  if contexts and data_type == RapidataDataTypes.TEXT:
88
- print("Warning: Contexts are not supported for text data type. Ignoring contexts.")
90
+ logger.warning("Warning: Contexts are not supported for text data type. Ignoring contexts.")
89
91
 
90
92
  if not confidence_threshold:
91
93
  referee = NaiveReferee(responses=responses_per_datapoint)
@@ -98,7 +100,7 @@ class RapidataOrderManager:
98
100
  order_builder = RapidataOrderBuilder(name=name, openapi_service=self._openapi_service)
99
101
 
100
102
  if selections and validation_set_id:
101
- print("Warning: You provided both selections and validation_set_id. Ignoring validation_set_id.")
103
+ logger.warning("Warning: Both selections and validation_set_id provided. Ignoring validation_set_id.")
102
104
 
103
105
  if selections is None:
104
106
  selections = self.__get_selections(validation_set_id, labeling_amount=default_labeling_amount)
@@ -2,6 +2,7 @@ import pandas as pd
2
2
  from typing import Any
3
3
  from pandas.core.indexes.base import Index
4
4
  import json
5
+ from rapidata.rapidata_client.logging import managed_print
5
6
 
6
7
  class RapidataResults(dict):
7
8
  """
@@ -32,7 +33,7 @@ class RapidataResults(dict):
32
33
  return pd.DataFrame()
33
34
 
34
35
  if self["info"].get("orderType") is None:
35
- print("Warning: Results are old and Order type is not specified. Dataframe might be wrong.")
36
+ managed_print("Warning: Results are old and Order type is not specified. Dataframe might be wrong.")
36
37
 
37
38
  # Check for detailed results if split_details is True
38
39
  if split_details:
@@ -8,6 +8,7 @@ from rapidata.rapidata_client.validation.validation_set_manager import (
8
8
 
9
9
  from rapidata.rapidata_client.demographic.demographic_manager import DemographicManager
10
10
 
11
+ from rapidata.rapidata_client.logging import logger
11
12
 
12
13
  class RapidataClient:
13
14
  """The Rapidata client is the main entry point for interacting with the Rapidata API. It allows you to create orders and validation sets."""
@@ -38,6 +39,7 @@ class RapidataClient:
38
39
  order (RapidataOrderManager): The RapidataOrderManager instance.
39
40
  validation (ValidationSetManager): The ValidationSetManager instance.
40
41
  """
42
+ logger.debug("Initializing OpenAPIService")
41
43
  self._openapi_service = OpenAPIService(
42
44
  client_id=client_id,
43
45
  client_secret=client_secret,
@@ -48,12 +50,15 @@ class RapidataClient:
48
50
  leeway=leeway,
49
51
  )
50
52
 
53
+ logger.debug("Initializing RapidataOrderManager")
51
54
  self.order = RapidataOrderManager(openapi_service=self._openapi_service)
52
55
 
56
+ logger.debug("Initializing ValidationSetManager")
53
57
  self.validation = ValidationSetManager(openapi_service=self._openapi_service)
54
58
 
59
+ logger.debug("Initializing DemographicManager")
55
60
  self._demographic = DemographicManager(openapi_service=self._openapi_service)
56
-
61
+
57
62
  def reset_credentials(self):
58
63
  """Reset the credentials saved in the configuration file for the current environment."""
59
64
  self._openapi_service.reset_credentials()
@@ -7,3 +7,4 @@ from .capped_selection import CappedSelection
7
7
  from .shuffling_selection import ShufflingSelection
8
8
  from .ab_test_selection import AbTestSelection
9
9
  from .static_selection import StaticSelection
10
+ from .retrieval_modes import RetrievalMode
@@ -1,5 +1,6 @@
1
1
  from typing import Any
2
2
  from rapidata.rapidata_client.selection._base_selection import RapidataSelection
3
+ from rapidata.rapidata_client.selection.retrieval_modes import RetrievalMode
3
4
  from rapidata.api_client.models.labeling_selection import (
4
5
  LabelingSelection as LabelingSelectionModel,
5
6
  )
@@ -12,10 +13,15 @@ class LabelingSelection(RapidataSelection):
12
13
 
13
14
  Args:
14
15
  amount (int): The amount of labeling rapids that will be shown per session.
16
+ retrieval_mode (RetrievalMode): The retrieval mode to use. Defaults to "Random".
17
+ max_iterations (int | None): The maximum number an annotator can see the same task. Defaults to None.
18
+ This parameter is only taken into account when using "Shuffled" or "Sequential" retrieval modes.
15
19
  """
16
20
 
17
- def __init__(self, amount: int):
21
+ def __init__(self, amount: int, retrieval_mode: RetrievalMode = RetrievalMode.Random, max_iterations: int | None = None):
18
22
  self.amount = amount
23
+ self.retrieval_mode = retrieval_mode
24
+ self.max_iterations = max_iterations
19
25
 
20
26
  def _to_model(self) -> Any:
21
- return LabelingSelectionModel(_t="LabelingSelection", amount=self.amount)
27
+ return LabelingSelectionModel(_t="LabelingSelection", amount=self.amount, retrievalMode=self.retrieval_mode.value, maxIterations=self.max_iterations)
@@ -0,0 +1,9 @@
1
+ from enum import Enum
2
+
3
+ class RetrievalMode(Enum):
4
+ # Will just randomly shuffle the datapoints. This is the default and will NOT take into account the "max_iterations" parameter.
5
+ Random = "Random"
6
+ # Will shuffle the datapoints randomly for each user. The user will then see the datapoints in that order. This will take into account the "max_iterations" parameter.
7
+ Shuffled = "Shuffled"
8
+ # Will show the datapoints in the order they are in the dataset. This will take into account the "max_iterations" parameter.
9
+ Sequential = "Sequential"
@@ -1,4 +1,5 @@
1
1
  from rapidata.rapidata_client.settings._rapidata_setting import RapidataSetting
2
+ from rapidata.rapidata_client.logging import managed_print
2
3
 
3
4
  class AlertOnFastResponse(RapidataSetting):
4
5
  """
@@ -12,7 +13,7 @@ class AlertOnFastResponse(RapidataSetting):
12
13
  if not isinstance(threshold, int):
13
14
  raise ValueError("The alert must be an integer.")
14
15
  if threshold < 10:
15
- print(f"Warning: Are you sure you want to set the threshold so low ({threshold} milliseconds)?")
16
+ managed_print(f"Warning: Are you sure you want to set the threshold so low ({threshold} milliseconds)?")
16
17
  if threshold > 25000:
17
18
  raise ValueError("The alert must be less than 25000 milliseconds.")
18
19
  if threshold < 0:
@@ -1,4 +1,5 @@
1
1
  from rapidata.rapidata_client.settings._rapidata_setting import RapidataSetting
2
+ from rapidata.rapidata_client.logging import managed_print, logger
2
3
 
3
4
  class FreeTextMinimumCharacters(RapidataSetting):
4
5
  """
@@ -12,5 +13,5 @@ class FreeTextMinimumCharacters(RapidataSetting):
12
13
  if value < 1:
13
14
  raise ValueError("The minimum number of characters must be greater than or equal to 1.")
14
15
  if value > 40:
15
- print(f"Warning: Are you sure you want to set the minimum number of characters at {value}?")
16
+ managed_print(f"Warning: Are you sure you want to set the minimum number of characters at {value}?")
16
17
  super().__init__(key="free_text_minimum_characters", value=value)
@@ -1,8 +1,8 @@
1
1
  from rapidata.rapidata_client.validation.rapids.rapids import Rapid
2
2
  from rapidata.service.openapi_service import OpenAPIService
3
- from requests.adapters import HTTPAdapter, Retry
4
- import requests
3
+ from rapidata.rapidata_client.logging import logger
5
4
  from rapidata.api_client.models.update_dimensions_model import UpdateDimensionsModel
5
+ from rapidata.rapidata_client.assets._sessions import SessionManager
6
6
 
7
7
  class RapidataValidationSet:
8
8
  """A class for interacting with a Rapidata validation set.
@@ -20,7 +20,6 @@ class RapidataValidationSet:
20
20
  self.id = validation_set_id
21
21
  self.name = name
22
22
  self.__openapi_service = openapi_service
23
- self.__session = self._get_session()
24
23
 
25
24
  def add_rapid(self, rapid: Rapid):
26
25
  """Add a Rapid to the validation set.
@@ -28,7 +27,7 @@ class RapidataValidationSet:
28
27
  Args:
29
28
  rapid (Rapid): The Rapid to add to the validation set.
30
29
  """
31
- rapid._add_to_validation_set(self.id, self.__openapi_service, self.__session)
30
+ rapid._add_to_validation_set(self.id, self.__openapi_service)
32
31
  return self
33
32
 
34
33
  def update_dimensions(self, dimensions: list[str]):
@@ -37,39 +36,10 @@ class RapidataValidationSet:
37
36
  Args:
38
37
  dimensions (list[str]): The new dimensions of the validation set.
39
38
  """
39
+ logger.debug(f"Updating dimensions for validation set {self.id} to {dimensions}")
40
40
  self.__openapi_service.validation_api.validation_validation_set_id_dimensions_patch(self.id, UpdateDimensionsModel(dimensions=dimensions))
41
41
  return self
42
42
 
43
- def _get_session(self, max_retries: int = 5, max_workers: int = 10) -> requests.Session:
44
- """Get a requests session with retry logic.
45
-
46
-
47
- Args:
48
- max_retries (int): The maximum number of retries.
49
- max_workers (int): The maximum number of workers.
50
-
51
- Returns:
52
- requests.Session: A requests session with retry logic.
53
- """
54
- session = requests.Session()
55
- retries = Retry(
56
- total=max_retries,
57
- backoff_factor=1,
58
- status_forcelist=[500, 502, 503, 504],
59
- allowed_methods=["GET"],
60
- respect_retry_after_header=True
61
- )
62
-
63
- adapter = HTTPAdapter(
64
- pool_connections=max_workers * 2,
65
- pool_maxsize=max_workers * 4,
66
- max_retries=retries
67
- )
68
- session.mount('http://', adapter)
69
- session.mount('https://', adapter)
70
-
71
- return session
72
-
73
43
  def __str__(self):
74
44
  return f"name: '{self.name}' id: {self.id}"
75
45
 
@@ -19,7 +19,8 @@ from rapidata.api_client.models.create_datapoint_from_files_model_metadata_inner
19
19
 
20
20
  from rapidata.service.openapi_service import OpenAPIService
21
21
 
22
- import requests
22
+ from rapidata.rapidata_client.logging import logger
23
+
23
24
 
24
25
  class Rapid():
25
26
  def __init__(self, asset: MediaAsset | TextAsset | MultiAsset, metadata: Sequence[Metadata], payload: Any, truth: Any, randomCorrectProbability: float, explanation: str | None):
@@ -29,15 +30,16 @@ class Rapid():
29
30
  self.truth = truth
30
31
  self.randomCorrectProbability = randomCorrectProbability
31
32
  self.explanation = explanation
33
+ logger.debug(f"Created Rapid with asset: {self.asset}, metadata: {self.metadata}, payload: {self.payload}, truth: {self.truth}, randomCorrectProbability: {self.randomCorrectProbability}, explanation: {self.explanation}")
32
34
 
33
- def _add_to_validation_set(self, validationSetId: str, openapi_service: OpenAPIService, session: requests.Session) -> None:
35
+ def _add_to_validation_set(self, validationSetId: str, openapi_service: OpenAPIService) -> None:
34
36
  if isinstance(self.asset, TextAsset) or (isinstance(self.asset, MultiAsset) and isinstance(self.asset.assets[0], TextAsset)):
35
37
  openapi_service.validation_api.validation_add_validation_text_rapid_post(
36
38
  add_validation_text_rapid_model=self.__to_text_model(validationSetId)
37
39
  )
38
40
 
39
41
  elif isinstance(self.asset, MediaAsset) or (isinstance(self.asset, MultiAsset) and isinstance(self.asset.assets[0], MediaAsset)):
40
- model = self.__to_media_model(validationSetId, session=session)
42
+ model = self.__to_media_model(validationSetId)
41
43
  openapi_service.validation_api.validation_add_validation_rapid_post(
42
44
  model=model[0], files=model[1]
43
45
  )
@@ -45,7 +47,7 @@ class Rapid():
45
47
  else:
46
48
  raise TypeError("The asset must be a MediaAsset, TextAsset, or MultiAsset")
47
49
 
48
- def __to_media_model(self, validationSetId: str, session: requests.Session) -> tuple[AddValidationRapidModel, list[StrictStr | tuple[StrictStr, StrictBytes] | StrictBytes]]:
50
+ def __to_media_model(self, validationSetId: str) -> tuple[AddValidationRapidModel, list[StrictStr | tuple[StrictStr, StrictBytes] | StrictBytes]]:
49
51
  assets: list[MediaAsset] = []
50
52
  if isinstance(self.asset, MultiAsset):
51
53
  for asset in self.asset.assets:
@@ -60,9 +62,6 @@ class Rapid():
60
62
  if isinstance(self.asset, MediaAsset):
61
63
  assets = [self.asset]
62
64
 
63
- for asset in assets:
64
- asset.session = session
65
-
66
65
  return (AddValidationRapidModel(
67
66
  validationSetId=validationSetId,
68
67
  payload=AddValidationRapidModelPayload(self.payload),