rapidata 2.35.1__py3-none-any.whl → 2.35.3__py3-none-any.whl

This diff shows the contents of publicly released package versions as published to a supported registry. It is provided for informational purposes only and reflects the changes between the two versions as they appear in that registry.

This version of rapidata might be problematic.

@@ -2,34 +2,39 @@ import requests
 from requests.adapters import HTTPAdapter
 from urllib3.util.retry import Retry
 
+from rapidata.rapidata_client.config.config import rapidata_config
+
+
 class SessionManager:
     _session = None
-
+
     @classmethod
-    def get_session(cls, ) -> requests.Session:
+    def get_session(
+        cls,
+    ) -> requests.Session:
         """Get a singleton requests session with retry logic.
 
         Returns:
             requests.Session: A singleton requests session with retry logic.
         """
         if cls._session is None:
-            max_retries: int = 5
-            max_workers: int = 10
+            max_retries: int = rapidata_config.upload_max_retries
+            max_workers: int = rapidata_config.max_upload_workers
             cls._session = requests.Session()
             retries = Retry(
                 total=max_retries,
                 backoff_factor=1,
                 status_forcelist=[500, 502, 503, 504],
                 allowed_methods=["GET"],
-                respect_retry_after_header=True
+                respect_retry_after_header=True,
             )
 
             adapter = HTTPAdapter(
                 pool_connections=max_workers * 2,
                 pool_maxsize=max_workers * 4,
-                max_retries=retries
+                max_retries=retries,
            )
-            cls._session.mount('http://', adapter)
-            cls._session.mount('https://', adapter)
+            cls._session.mount("http://", adapter)
+            cls._session.mount("https://", adapter)
 
         return cls._session
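
The hunk above replaces the hard-coded retry and worker counts with values read from rapidata_config. A minimal, self-contained sketch of the same session-with-retries pattern, using a hypothetical UploadConfig dataclass in place of rapidata_config:

from dataclasses import dataclass

import requests
from requests.adapters import HTTPAdapter
from urllib3.util.retry import Retry


@dataclass
class UploadConfig:  # hypothetical stand-in for rapidata_config
    upload_max_retries: int = 5
    max_upload_workers: int = 10


def build_session(config: UploadConfig) -> requests.Session:
    # Retries and pool sizes are driven by config, mirroring the change above.
    session = requests.Session()
    retries = Retry(
        total=config.upload_max_retries,
        backoff_factor=1,
        status_forcelist=[500, 502, 503, 504],
        allowed_methods=["GET"],
        respect_retry_after_header=True,
    )
    adapter = HTTPAdapter(
        pool_connections=config.max_upload_workers * 2,
        pool_maxsize=config.max_upload_workers * 4,
        max_retries=retries,
    )
    session.mount("http://", adapter)
    session.mount("https://", adapter)
    return session
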
@@ -1,23 +1,28 @@
-from itertools import zip_longest
-
-from rapidata.api_client.models.create_datapoint_from_text_sources_model import CreateDatapointFromTextSourcesModel
-from rapidata.api_client.models.dataset_dataset_id_datapoints_post_request_metadata_inner import DatasetDatasetIdDatapointsPostRequestMetadataInner
 from rapidata.rapidata_client.datapoints.datapoint import Datapoint
-from rapidata.rapidata_client.datapoints.metadata import Metadata
-from rapidata.rapidata_client.datapoints.assets import TextAsset, MediaAsset, MultiAsset, BaseAsset
+from rapidata.rapidata_client.datapoints.assets import TextAsset, MediaAsset
 from rapidata.service import LocalFileService
 from rapidata.service.openapi_service import OpenAPIService
 from concurrent.futures import ThreadPoolExecutor, as_completed
 from tqdm import tqdm
 
-from typing import cast, Sequence, Generator
-from rapidata.rapidata_client.logging import logger, managed_print, RapidataOutputManager
+from typing import Generator
+from rapidata.rapidata_client.logging import (
+    logger,
+    managed_print,
+    RapidataOutputManager,
+)
 import time
 import threading
+from rapidata.rapidata_client.api.rapidata_exception import (
+    suppress_rapidata_error_logging,
+)
+from rapidata.rapidata_client.config.config import rapidata_config
+
 
 def chunk_list(lst: list, chunk_size: int) -> Generator:
     for i in range(0, len(lst), chunk_size):
-        yield lst[i:i + chunk_size]
+        yield lst[i : i + chunk_size]
+
 
 class RapidataDataset:
     def __init__(self, dataset_id: str, openapi_service: OpenAPIService):
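
The reformatted chunk_list helper above simply yields consecutive slices of at most chunk_size items; a quick usage example:

def chunk_list(lst: list, chunk_size: int):
    for i in range(0, len(lst), chunk_size):
        yield lst[i : i + chunk_size]


# [[1, 2], [3, 4], [5]]
print(list(chunk_list([1, 2, 3, 4, 5], 2)))
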
@@ -31,39 +36,49 @@ class RapidataDataset:
     ) -> tuple[list[Datapoint], list[Datapoint]]:
         if not datapoints:
             return [], []
-
+
         effective_asset_type = datapoints[0]._get_effective_asset_type()
-
+
+        logger.debug(f"Config for datapoint upload: {rapidata_config}")
+
         if issubclass(effective_asset_type, MediaAsset):
-            return self._add_media_from_paths(datapoints)
+            return self._add_media_from_paths(
+                datapoints,
+            )
         elif issubclass(effective_asset_type, TextAsset):
             return self._add_texts(datapoints)
         else:
             raise ValueError(f"Unsupported asset type: {effective_asset_type}")
 
     def _add_texts(
-        self,
-        datapoints: list[Datapoint],
-        max_workers: int = 10,
+        self, datapoints: list[Datapoint]
     ) -> tuple[list[Datapoint], list[Datapoint]]:
-
+
         def upload_text_datapoint(datapoint: Datapoint, index: int) -> Datapoint:
             model = datapoint.create_text_upload_model(index)
-
-            self.openapi_service.dataset_api.dataset_dataset_id_datapoints_texts_post(dataset_id=self.id, create_datapoint_from_text_sources_model=model)
+
+            self.openapi_service.dataset_api.dataset_dataset_id_datapoints_texts_post(
+                dataset_id=self.id, create_datapoint_from_text_sources_model=model
+            )
             return datapoint
 
         successful_uploads: list[Datapoint] = []
         failed_uploads: list[Datapoint] = []
 
         total_uploads = len(datapoints)
-        with ThreadPoolExecutor(max_workers=max_workers) as executor:
+        with ThreadPoolExecutor(
+            max_workers=rapidata_config.max_upload_workers
+        ) as executor:
             future_to_datapoint = {
                 executor.submit(upload_text_datapoint, datapoint, index=i): datapoint
                 for i, datapoint in enumerate(datapoints)
             }
 
-            with tqdm(total=total_uploads, desc="Uploading text datapoints", disable=RapidataOutputManager.silent_mode) as pbar:
+            with tqdm(
+                total=total_uploads,
+                desc="Uploading text datapoints",
+                disable=RapidataOutputManager.silent_mode,
+            ) as pbar:
                 for future in as_completed(future_to_datapoint.keys()):
                     datapoint = future_to_datapoint[future]
                     try:
@@ -72,7 +87,7 @@ class RapidataDataset:
                         successful_uploads.append(result)
                     except Exception as e:
                         failed_uploads.append(datapoint)
-                        logger.error(f"Upload failed for {datapoint}: {str(e)}")
+                        logger.error("Upload failed for %s: %s", datapoint, str(e))
 
         return successful_uploads, failed_uploads
 
@@ -80,20 +95,21 @@ class RapidataDataset:
         self,
         datapoint: Datapoint,
         index: int,
-        max_retries: int = 3,
     ) -> tuple[list[Datapoint], list[Datapoint]]:
         """
         Process single upload with retry logic and error tracking.
-
+
         Args:
             media_asset: MediaAsset or MultiAsset to upload
             meta_list: Optional sequence of metadata for the asset
             index: Sort index for the upload
             max_retries: Maximum number of retry attempts (default: 3)
-
+
         Returns:
             tuple[list[Datapoint], list[Datapoint]]: Lists of successful and failed datapoints
         """
+        logger.debug("Processing single upload for %s with index %s", datapoint, index)
+
         local_successful: list[Datapoint] = []
         local_failed: list[Datapoint] = []
 
@@ -103,44 +119,52 @@ class RapidataDataset:
         urls = datapoint.get_urls()
 
         last_exception = None
-        for attempt in range(max_retries):
+        for attempt in range(rapidata_config.upload_max_retries):
             try:
-                self.openapi_service.dataset_api.dataset_dataset_id_datapoints_post(
-                    dataset_id=self.id,
-                    file=local_paths,
-                    url=urls,
-                    metadata=metadata,
-                    sort_index=index,
-                )
-
+                with suppress_rapidata_error_logging():
+                    self.openapi_service.dataset_api.dataset_dataset_id_datapoints_post(
+                        dataset_id=self.id,
+                        file=local_paths,
+                        url=urls,
+                        metadata=metadata,
+                        sort_index=index,
+                    )
+
                 local_successful.append(datapoint)
 
                 return local_successful, local_failed
-
+
             except Exception as e:
                 last_exception = e
-                if attempt < max_retries - 1:
+                if attempt < rapidata_config.upload_max_retries - 1:
                     # Exponential backoff: wait 1s, then 2s, then 4s
-                    retry_delay = 2 ** attempt
+                    retry_delay = 2**attempt
                     time.sleep(retry_delay)
-                    managed_print(f"\nRetrying {attempt + 1} of {max_retries}...\n")
-
+                    logger.debug("Error: %s", str(last_exception))
+                    logger.debug(
+                        "Retrying %s of %s...",
+                        attempt + 1,
+                        rapidata_config.upload_max_retries,
+                    )
+
         # If we get here, all retries failed
         local_failed.append(datapoint)
-        logger.error(f"\nUpload failed for {datapoint} after {max_retries} attempts. Final error: {str(last_exception)}")
+        tqdm.write(
+            f"Upload failed for {datapoint} after {rapidata_config.upload_max_retries} attempts. \nFinal error: \n{str(last_exception)}"
+        )
 
         return local_successful, local_failed
 
     def _get_progress_tracker(
-        self,
-        total_uploads: int,
-        stop_event: threading.Event,
+        self,
+        total_uploads: int,
+        stop_event: threading.Event,
         progress_error_event: threading.Event,
         progress_poll_interval: float,
     ) -> threading.Thread:
         """
         Create and return a progress tracking thread that shows actual API progress.
-
+
         Args:
             total_uploads: Total number of uploads to track
             initial_ready: Initial number of ready items
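
The retry loop above now takes its attempt count from rapidata_config.upload_max_retries and wraps the POST in suppress_rapidata_error_logging, retrying with exponential backoff (1s, 2s, 4s, ...). A minimal sketch of that backoff pattern, with upload_once as a hypothetical placeholder for the actual datapoint POST:

import time


def upload_with_backoff(upload_once, max_retries: int = 3) -> bool:
    last_exception = None
    for attempt in range(max_retries):
        try:
            upload_once()
            return True
        except Exception as exc:
            last_exception = exc
            if attempt < max_retries - 1:
                # Exponential backoff: 1s, 2s, 4s, ...
                time.sleep(2**attempt)
    print(f"Upload failed after {max_retries} attempts: {last_exception}")
    return False
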
@@ -148,84 +172,97 @@ class RapidataDataset:
             stop_event: Event to signal thread to stop
             progress_error_event: Event to signal an error in progress tracking
             progress_poll_interval: Time between progress checks
-
+
         Returns:
             threading.Thread: The progress tracking thread
         """
+
         def progress_tracking_thread():
             try:
                 # Initialize progress bar with 0 completions
-                with tqdm(total=total_uploads, desc="Uploading datapoints", disable=RapidataOutputManager.silent_mode) as pbar:
+                with tqdm(
+                    total=total_uploads,
+                    desc="Uploading datapoints",
+                    disable=RapidataOutputManager.silent_mode,
+                ) as pbar:
                     prev_ready = 0
                     prev_failed = 0
                     stall_count = 0
                     last_progress_time = time.time()
-
+
                     # We'll wait for all uploads to finish + some extra time
                     # for the backend to fully process everything
                     all_uploads_complete = threading.Event()
-
+
                     while not stop_event.is_set() or not all_uploads_complete.is_set():
                         try:
-                            current_progress = self.openapi_service.dataset_api.dataset_dataset_id_progress_get(self.id)
-
+                            current_progress = self.openapi_service.dataset_api.dataset_dataset_id_progress_get(
+                                self.id
+                            )
+
                             # Calculate items completed since our initialization
                             completed_ready = current_progress.ready
                             completed_failed = current_progress.failed
                             total_completed = completed_ready + completed_failed
-
+
                             # Calculate newly completed items since our last check
                             new_ready = current_progress.ready - prev_ready
                             new_failed = current_progress.failed - prev_failed
-
+
                             # Update progress bar position to show actual completed items
                             # First reset to match the actual completed count
                             pbar.n = total_completed
                             pbar.refresh()
-
+
                             if new_ready > 0 or new_failed > 0:
                                 # We saw progress
                                 stall_count = 0
                                 last_progress_time = time.time()
                             else:
                                 stall_count += 1
-
+
                             # Update our tracking variables
                             prev_ready = current_progress.ready
                             prev_failed = current_progress.failed or 0
-
+
                             # Check if stop_event was set (all uploads submitted)
                             if stop_event.is_set():
-                                elapsed_since_last_progress = time.time() - last_progress_time
-
+                                elapsed_since_last_progress = (
+                                    time.time() - last_progress_time
+                                )
+
                                 # If we haven't seen progress for a while after all uploads were submitted
                                 if elapsed_since_last_progress > 5.0:
                                     # If we're at 100%, we're done
                                     if total_completed >= total_uploads:
                                         all_uploads_complete.set()
                                         break
-
+
                                     # If we're not at 100% but it's been a while with no progress
                                     if stall_count > 5:
                                         # We've polled several times with no progress, assume we're done
-                                        logger.warning(f"\nProgress seems stalled at {total_completed}/{total_uploads}. Please try again.")
+                                        logger.warning(
+                                            "\nProgress seems stalled at %s/%s.",
+                                            total_completed,
+                                            total_uploads,
+                                        )
                                         break
-
+
                         except Exception as e:
-                            logger.error(f"\nError checking progress: {str(e)}")
+                            logger.error("\nError checking progress: %s", str(e))
                             stall_count += 1
-
+
                             if stall_count > 10:  # Too many consecutive errors
                                 progress_error_event.set()
                                 break
-
+
                         # Sleep before next poll
                         time.sleep(progress_poll_interval)
-
+
             except Exception as e:
-                logger.error(f"Progress tracking thread error: {str(e)}")
+                logger.error("Progress tracking thread error: %s", str(e))
                 progress_error_event.set()
-
+
         # Create and return the thread
         progress_thread = threading.Thread(target=progress_tracking_thread)
         progress_thread.daemon = True
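
The hunk above mostly reflows the progress-tracking thread and switches its log calls to lazy %-style formatting. The underlying pattern is a daemon thread that polls a progress endpoint until a stop event is set and the work is accounted for; a simplified sketch, with get_progress as a hypothetical callable returning the number of completed items:

import threading
import time


def start_progress_tracker(get_progress, total: int, stop_event: threading.Event,
                           poll_interval: float = 0.5) -> threading.Thread:
    def poll():
        while True:
            done = get_progress()
            # Stop once all uploads were submitted and everything is accounted for.
            if stop_event.is_set() and done >= total:
                break
            time.sleep(poll_interval)

    thread = threading.Thread(target=poll, daemon=True)
    thread.start()
    return thread
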
@@ -234,68 +271,70 @@ class RapidataDataset:
     def _process_uploads_in_chunks(
         self,
         datapoints: list[Datapoint],
-        max_workers: int,
         chunk_size: int,
         stop_progress_tracking: threading.Event,
-        progress_tracking_error: threading.Event
+        progress_tracking_error: threading.Event,
     ) -> tuple[list[Datapoint], list[Datapoint]]:
         """
         Process uploads in chunks with a ThreadPoolExecutor.
-
+
         Args:
             media_paths: List of assets to upload
             multi_metadata: Optional sequence of sequences of metadata
-            max_workers: Maximum number of concurrent workers
             chunk_size: Number of items to process in each batch
             stop_progress_tracking: Event to signal progress tracking to stop
             progress_tracking_error: Event to detect progress tracking errors
-
+
         Returns:
             tuple[list[str], list[str]]: Lists of successful and failed uploads
         """
         successful_uploads: list[Datapoint] = []
         failed_uploads: list[Datapoint] = []
-
+
         try:
-            with ThreadPoolExecutor(max_workers=max_workers) as executor:
+            with ThreadPoolExecutor(
+                max_workers=rapidata_config.max_upload_workers
+            ) as executor:
                 # Process uploads in chunks to avoid overwhelming the system
                 for chunk_idx, chunk in enumerate(chunk_list(datapoints, chunk_size)):
                     futures = [
                         executor.submit(
-                            self._process_single_upload,
-                            datapoint,
-                            index=(chunk_idx * chunk_size + i)
+                            self._process_single_upload,
+                            datapoint,
+                            index=(chunk_idx * chunk_size + i),
                         )
                         for i, datapoint in enumerate(chunk)
                     ]
-
+
                     # Wait for this chunk to complete before starting the next one
                     for future in as_completed(futures):
                         if progress_tracking_error.is_set():
-                            raise RuntimeError("Progress tracking failed, aborting uploads")
-
+                            raise RuntimeError(
+                                "Progress tracking failed, aborting uploads"
+                            )
+
                         try:
                             chunk_successful, chunk_failed = future.result()
                             successful_uploads.extend(chunk_successful)
                             failed_uploads.extend(chunk_failed)
                         except Exception as e:
-                            logger.error(f"Future execution failed: {str(e)}")
+                            logger.error("Future execution failed: %s", str(e))
         finally:
             # Signal to the progress tracking thread that all uploads have been submitted
             stop_progress_tracking.set()
-
+
         return successful_uploads, failed_uploads
 
     def _log_final_progress(
-        self,
-        total_uploads: int,
+        self,
+        total_uploads: int,
         progress_poll_interval: float,
         successful_uploads: list[Datapoint],
-        failed_uploads: list[Datapoint]
+        failed_uploads: list[Datapoint],
    ) -> None:
         """
         Log the final progress of the upload operation.
-
+
         Args:
             total_uploads: Total number of uploads
             initial_ready: Initial number of ready items
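
_process_uploads_in_chunks now sizes its ThreadPoolExecutor from rapidata_config.max_upload_workers instead of a max_workers argument; the chunking itself is unchanged: each batch of chunk_size datapoints completes before the next is submitted. A standalone sketch of that pattern, with process_item as a hypothetical worker function:

from concurrent.futures import ThreadPoolExecutor, as_completed


def process_in_chunks(items, process_item, max_workers=10, chunk_size=50):
    results, failures = [], []
    with ThreadPoolExecutor(max_workers=max_workers) as executor:
        for start in range(0, len(items), chunk_size):
            chunk = items[start : start + chunk_size]
            futures = [executor.submit(process_item, item) for item in chunk]
            # Wait for the whole chunk before submitting the next one.
            for future in as_completed(futures):
                try:
                    results.append(future.result())
                except Exception as exc:
                    failures.append(exc)
    return results, failures
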
@@ -304,93 +343,105 @@ class RapidataDataset:
             successful_uploads: List of successful uploads for fallback reporting
             failed_uploads: List of failed uploads for fallback reporting
         """
-        try:
+        try:
             # Get final progress
-            final_progress = self.openapi_service.dataset_api.dataset_dataset_id_progress_get(self.id)
+            final_progress = (
+                self.openapi_service.dataset_api.dataset_dataset_id_progress_get(
+                    self.id
+                )
+            )
             total_ready = final_progress.ready
             total_failed = final_progress.failed
-
+
             # Make sure we account for all uploads
             if total_ready + total_failed < total_uploads:
                 # Try one more time after a longer wait
                 time.sleep(5 * progress_poll_interval)
-                final_progress = self.openapi_service.dataset_api.dataset_dataset_id_progress_get(self.id)
+                final_progress = (
+                    self.openapi_service.dataset_api.dataset_dataset_id_progress_get(
+                        self.id
+                    )
+                )
                 total_ready = final_progress.ready
                 total_failed = final_progress.failed
-
-            success_rate = (total_ready / total_uploads * 100) if total_uploads > 0 else 0
-
-            logger.info(f"Upload complete: {total_ready} ready, {total_uploads-total_ready} failed ({success_rate:.1f}% success rate)")
+
+            success_rate = (
+                (total_ready / total_uploads * 100) if total_uploads > 0 else 0
+            )
+
+            logger.info(
+                "Upload complete: %s ready, %s failed (%s%% success rate)",
+                total_ready,
+                total_uploads - total_ready,
+                success_rate,
+            )
         except Exception as e:
-            logger.error(f"Error getting final progress: {str(e)}")
-            logger.info(f"Upload summary from local tracking: {len(successful_uploads)} succeeded, {len(failed_uploads)} failed")
+            logger.error("Error getting final progress: %s", str(e))
+            logger.info(
+                "Upload summary from local tracking: %s succeeded, %s failed",
+                len(successful_uploads),
+                len(failed_uploads),
+            )
 
         if failed_uploads:
-            logger.error(f"Failed uploads: {failed_uploads}")
+            logger.error("Failed uploads: %s", failed_uploads)
 
     def _add_media_from_paths(
         self,
         datapoints: list[Datapoint],
-        max_workers: int = 10,
         chunk_size: int = 50,
         progress_poll_interval: float = 0.5,
     ) -> tuple[list[Datapoint], list[Datapoint]]:
         """
         Upload media paths in chunks with managed resources.
-
+
         Args:
             datapoints: List of Datapoint objects to upload
-            max_workers: Maximum number of concurrent upload workers
             chunk_size: Number of items to process in each batch
             progress_poll_interval: Time in seconds between progress checks
-
         Returns:
             tuple[list[Datapoint], list[Datapoint]]: Lists of successful and failed datapoints
-
+
         Raises:
             ValueError: If multi_metadata lengths don't match media_paths length
         """
-
+
         # Setup tracking variables
         total_uploads = len(datapoints)
-
+
         # Create thread control events
         stop_progress_tracking = threading.Event()
         progress_tracking_error = threading.Event()
-
+
         # Create and start progress tracking thread
         progress_thread = self._get_progress_tracker(
-            total_uploads,
-            stop_progress_tracking,
+            total_uploads,
+            stop_progress_tracking,
             progress_tracking_error,
-            progress_poll_interval
+            progress_poll_interval,
        )
         progress_thread.start()
-
+
         # Process uploads in chunks
         try:
             successful_uploads, failed_uploads = self._process_uploads_in_chunks(
                 datapoints,
-                max_workers,
                 chunk_size,
                 stop_progress_tracking,
-                progress_tracking_error
+                progress_tracking_error,
             )
         finally:
             progress_thread.join(10)  # Add margin to the timeout for tqdm
-
+
         # Log final progress
         self._log_final_progress(
-            total_uploads,
-            progress_poll_interval,
-            successful_uploads,
-            failed_uploads
+            total_uploads, progress_poll_interval, successful_uploads, failed_uploads
         )
 
         return successful_uploads, failed_uploads
 
     def __str__(self) -> str:
         return f"RapidataDataset(id={self.id})"
-
+
     def __repr__(self) -> str:
         return self.__str__()