airbyte-cdk 6.20.2.dev0__py3-none-any.whl → 6.21.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (36)
  1. airbyte_cdk/sources/declarative/auth/oauth.py +34 -0
  2. airbyte_cdk/sources/declarative/checks/__init__.py +18 -2
  3. airbyte_cdk/sources/declarative/checks/check_dynamic_stream.py +51 -0
  4. airbyte_cdk/sources/declarative/concurrent_declarative_source.py +16 -80
  5. airbyte_cdk/sources/declarative/declarative_component_schema.yaml +123 -21
  6. airbyte_cdk/sources/declarative/decoders/__init__.py +9 -1
  7. airbyte_cdk/sources/declarative/decoders/composite_raw_decoder.py +43 -0
  8. airbyte_cdk/sources/declarative/decoders/zipfile_decoder.py +59 -0
  9. airbyte_cdk/sources/declarative/extractors/record_filter.py +5 -3
  10. airbyte_cdk/sources/declarative/incremental/__init__.py +0 -6
  11. airbyte_cdk/sources/declarative/incremental/global_substream_cursor.py +0 -3
  12. airbyte_cdk/sources/declarative/incremental/per_partition_cursor.py +0 -15
  13. airbyte_cdk/sources/declarative/manifest_declarative_source.py +2 -1
  14. airbyte_cdk/sources/declarative/models/declarative_component_schema.py +112 -27
  15. airbyte_cdk/sources/declarative/parsers/model_to_component_factory.py +127 -106
  16. airbyte_cdk/sources/declarative/requesters/README.md +56 -0
  17. airbyte_cdk/sources/declarative/requesters/http_job_repository.py +33 -4
  18. airbyte_cdk/sources/declarative/retrievers/simple_retriever.py +1 -1
  19. airbyte_cdk/sources/declarative/schema/dynamic_schema_loader.py +13 -3
  20. airbyte_cdk/sources/file_based/config/abstract_file_based_spec.py +11 -0
  21. airbyte_cdk/sources/file_based/exceptions.py +34 -0
  22. airbyte_cdk/sources/file_based/file_based_source.py +28 -5
  23. airbyte_cdk/sources/file_based/file_based_stream_reader.py +18 -4
  24. airbyte_cdk/sources/file_based/file_types/unstructured_parser.py +25 -2
  25. airbyte_cdk/sources/file_based/stream/default_file_based_stream.py +30 -2
  26. airbyte_cdk/sources/streams/concurrent/cursor.py +21 -30
  27. airbyte_cdk/sources/streams/http/requests_native_auth/abstract_oauth.py +33 -4
  28. airbyte_cdk/sources/streams/http/requests_native_auth/oauth.py +42 -4
  29. airbyte_cdk/sources/types.py +3 -0
  30. airbyte_cdk/sources/utils/transform.py +29 -3
  31. {airbyte_cdk-6.20.2.dev0.dist-info → airbyte_cdk-6.21.0.dist-info}/METADATA +1 -1
  32. {airbyte_cdk-6.20.2.dev0.dist-info → airbyte_cdk-6.21.0.dist-info}/RECORD +35 -33
  33. airbyte_cdk/sources/declarative/incremental/concurrent_partition_cursor.py +0 -331
  34. {airbyte_cdk-6.20.2.dev0.dist-info → airbyte_cdk-6.21.0.dist-info}/LICENSE.txt +0 -0
  35. {airbyte_cdk-6.20.2.dev0.dist-info → airbyte_cdk-6.21.0.dist-info}/WHEEL +0 -0
  36. {airbyte_cdk-6.20.2.dev0.dist-info → airbyte_cdk-6.21.0.dist-info}/entry_points.txt +0 -0
airbyte_cdk/sources/file_based/exceptions.py

@@ -111,6 +111,40 @@ class ErrorListingFiles(BaseFileBasedSourceError):
     pass
 
 
+class DuplicatedFilesError(BaseFileBasedSourceError):
+    def __init__(self, duplicated_files_names: List[dict[str, List[str]]], **kwargs: Any):
+        self._duplicated_files_names = duplicated_files_names
+        self._stream_name: str = kwargs["stream"]
+        super().__init__(self._format_duplicate_files_error_message(), **kwargs)
+
+    def _format_duplicate_files_error_message(self) -> str:
+        duplicated_files_messages = []
+        for duplicated_file in self._duplicated_files_names:
+            for duplicated_file_name, file_paths in duplicated_file.items():
+                file_duplicated_message = (
+                    f"{len(file_paths)} duplicates found for file name {duplicated_file_name}:\n\n"
+                    + "".join(f"\n - {file_paths}")
+                )
+                duplicated_files_messages.append(file_duplicated_message)
+
+        error_message = (
+            f"ERROR: Duplicate filenames found for stream {self._stream_name}. "
+            "Duplicate file names are not allowed if the Preserve Sub-Directories in File Paths option is disabled. "
+            "Please remove or rename the duplicate files before attempting to re-run the sync.\n\n"
+            + "\n".join(duplicated_files_messages)
+        )
+
+        return error_message
+
+    def __repr__(self) -> str:
+        """Return a string representation of the exception."""
+        class_name = self.__class__.__name__
+        properties_str = ", ".join(
+            f"{k}={v!r}" for k, v in self.__dict__.items() if not k.startswith("_")
+        )
+        return f"{class_name}({properties_str})"
+
+
 class CustomFileBasedException(AirbyteTracedException):
     """
     A specialized exception for file-based connectors.
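
A minimal usage sketch, not part of the diff: the `stream` kwarg and the `{file_name: [paths]}` payload shape follow the constructor above, while the duplicate values themselves are hypothetical.

```python
from airbyte_cdk.sources.file_based.exceptions import DuplicatedFilesError

# Hypothetical duplicate map: one entry per colliding file name.
duplicates = [{"report.csv": ["2023/report.csv", "2024/report.csv"]}]
try:
    raise DuplicatedFilesError(duplicated_files_names=duplicates, stream="my_stream")
except DuplicatedFilesError as err:
    print(err)  # "ERROR: Duplicate filenames found for stream my_stream. ..."
```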
airbyte_cdk/sources/file_based/file_based_source.py

@@ -242,7 +242,7 @@ class FileBasedSource(ConcurrentSourceAdapter, ABC):
                     stream=self._make_default_stream(
                         stream_config=stream_config,
                         cursor=cursor,
-                        use_file_transfer=self._use_file_transfer(parsed_config),
+                        parsed_config=parsed_config,
                     ),
                     source=self,
                     logger=self.logger,
@@ -273,7 +273,7 @@ class FileBasedSource(ConcurrentSourceAdapter, ABC):
                     stream=self._make_default_stream(
                         stream_config=stream_config,
                         cursor=cursor,
-                        use_file_transfer=self._use_file_transfer(parsed_config),
+                        parsed_config=parsed_config,
                     ),
                     source=self,
                     logger=self.logger,
@@ -285,7 +285,7 @@ class FileBasedSource(ConcurrentSourceAdapter, ABC):
                 stream = self._make_default_stream(
                     stream_config=stream_config,
                     cursor=cursor,
-                    use_file_transfer=self._use_file_transfer(parsed_config),
+                    parsed_config=parsed_config,
                 )
 
             streams.append(stream)
@@ -298,7 +298,7 @@ class FileBasedSource(ConcurrentSourceAdapter, ABC):
         self,
         stream_config: FileBasedStreamConfig,
         cursor: Optional[AbstractFileBasedCursor],
-        use_file_transfer: bool = False,
+        parsed_config: AbstractFileBasedSpec,
     ) -> AbstractFileBasedStream:
         return DefaultFileBasedStream(
             config=stream_config,
@@ -310,7 +310,8 @@ class FileBasedSource(ConcurrentSourceAdapter, ABC):
             validation_policy=self._validate_and_get_validation_policy(stream_config),
             errors_collector=self.errors_collector,
             cursor=cursor,
-            use_file_transfer=use_file_transfer,
+            use_file_transfer=self._use_file_transfer(parsed_config),
+            preserve_directory_structure=self._preserve_directory_structure(parsed_config),
         )
 
     def _get_stream_from_catalog(
@@ -385,3 +386,25 @@ class FileBasedSource(ConcurrentSourceAdapter, ABC):
             and parsed_config.delivery_method.delivery_type == "use_file_transfer"
         )
         return use_file_transfer
+
+    @staticmethod
+    def _preserve_directory_structure(parsed_config: AbstractFileBasedSpec) -> bool:
+        """
+        Determines whether to preserve directory structure during file transfer.
+
+        When enabled, files maintain their subdirectory paths in the destination.
+        When disabled, files are flattened to the root of the destination.
+
+        Args:
+            parsed_config: The parsed configuration containing delivery method settings
+
+        Returns:
+            True if directory structure should be preserved (default), False otherwise
+        """
+        if (
+            FileBasedSource._use_file_transfer(parsed_config)
+            and hasattr(parsed_config.delivery_method, "preserve_directory_structure")
+            and parsed_config.delivery_method.preserve_directory_structure is not None
+        ):
+            return parsed_config.delivery_method.preserve_directory_structure
+        return True
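
A sketch of how the new helper reads a config, assuming `_use_file_transfer` only inspects `delivery_method.delivery_type` as shown in the hunk above; `SimpleNamespace` stands in for the real `AbstractFileBasedSpec` pydantic model.

```python
from types import SimpleNamespace

from airbyte_cdk.sources.file_based.file_based_source import FileBasedSource

# Stand-in for a parsed spec with file transfer enabled and flattening requested.
parsed_config = SimpleNamespace(
    delivery_method=SimpleNamespace(
        delivery_type="use_file_transfer",
        preserve_directory_structure=False,
    )
)
assert FileBasedSource._preserve_directory_structure(parsed_config) is False
```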
airbyte_cdk/sources/file_based/file_based_stream_reader.py

@@ -135,6 +135,17 @@ class AbstractFileBasedStreamReader(ABC):
             return use_file_transfer
         return False
 
+    def preserve_directory_structure(self) -> bool:
+        # fall back to preserve subdirectories if config is not present or incomplete
+        if (
+            self.use_file_transfer()
+            and self.config
+            and hasattr(self.config.delivery_method, "preserve_directory_structure")
+            and self.config.delivery_method.preserve_directory_structure is not None
+        ):
+            return self.config.delivery_method.preserve_directory_structure
+        return True
+
     @abstractmethod
     def get_file(
         self, file: RemoteFile, local_directory: str, logger: logging.Logger
@@ -159,10 +170,13 @@ class AbstractFileBasedStreamReader(ABC):
         """
         ...
 
-    @staticmethod
-    def _get_file_transfer_paths(file: RemoteFile, local_directory: str) -> List[str]:
-        # Remove left slashes from source path format to make relative path for writing locally
-        file_relative_path = file.uri.lstrip("/")
+    def _get_file_transfer_paths(self, file: RemoteFile, local_directory: str) -> List[str]:
+        preserve_directory_structure = self.preserve_directory_structure()
+        if preserve_directory_structure:
+            # Remove left slashes from source path format to make relative path for writing locally
+            file_relative_path = file.uri.lstrip("/")
+        else:
+            file_relative_path = path.basename(file.uri)
         local_file_path = path.join(local_directory, file_relative_path)
 
         # Ensure the local directory exists
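
The effect on the local write path, as a standalone sketch; the URI and staging directory are illustrative.

```python
from os import path

uri = "/outer/inner/report.csv"          # remote file URI
local_directory = "/tmp/airbyte-files"   # local staging directory

# preserve_directory_structure() == True: keep the relative path
preserved = path.join(local_directory, uri.lstrip("/"))
# preserve_directory_structure() == False: flatten to the basename
flattened = path.join(local_directory, path.basename(uri))

print(preserved)  # /tmp/airbyte-files/outer/inner/report.csv
print(flattened)  # /tmp/airbyte-files/report.csv
```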
airbyte_cdk/sources/file_based/file_types/unstructured_parser.py

@@ -2,6 +2,7 @@
 # Copyright (c) 2023 Airbyte, Inc., all rights reserved.
 #
 import logging
+import os
 import traceback
 from datetime import datetime
 from io import BytesIO, IOBase
@@ -42,12 +43,34 @@ unstructured_partition_pdf = None
 unstructured_partition_docx = None
 unstructured_partition_pptx = None
 
+AIRBYTE_NLTK_DATA_DIR = "/airbyte/nltk_data"
+TMP_NLTK_DATA_DIR = "/tmp/nltk_data"
+
+
+def get_nltk_temp_folder() -> str:
+    """
+    For non-root connectors /tmp is not currently writable, but we should allow it in the future.
+    It's safe to use /airbyte for now. Fallback to /tmp for local development.
+    """
+    try:
+        nltk_data_dir = AIRBYTE_NLTK_DATA_DIR
+        os.makedirs(nltk_data_dir, exist_ok=True)
+    except OSError:
+        nltk_data_dir = TMP_NLTK_DATA_DIR
+        os.makedirs(nltk_data_dir, exist_ok=True)
+    return nltk_data_dir
+
+
 try:
+    nltk_data_dir = get_nltk_temp_folder()
+    nltk.data.path.append(nltk_data_dir)
     nltk.data.find("tokenizers/punkt.zip")
     nltk.data.find("tokenizers/punkt_tab.zip")
+    nltk.data.find("tokenizers/averaged_perceptron_tagger_eng.zip")
 except LookupError:
-    nltk.download("punkt")
-    nltk.download("punkt_tab")
+    nltk.download("punkt", download_dir=nltk_data_dir, quiet=True)
+    nltk.download("punkt_tab", download_dir=nltk_data_dir, quiet=True)
+    nltk.download("averaged_perceptron_tagger_eng", download_dir=nltk_data_dir, quiet=True)
 
 
 def optional_decode(contents: Union[str, bytes]) -> str:
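
The download setup above boils down to a try-the-preferred-directory pattern. A generic sketch, with the directory names copied from the diff:

```python
import os

def pick_writable_dir(preferred: str, fallback: str) -> str:
    """Return the first directory that can be created/written."""
    try:
        os.makedirs(preferred, exist_ok=True)
        return preferred
    except OSError:
        os.makedirs(fallback, exist_ok=True)
        return fallback

# Mirrors get_nltk_temp_folder(): /airbyte inside connector images,
# /tmp when developing locally.
print(pick_writable_dir("/airbyte/nltk_data", "/tmp/nltk_data"))
```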
airbyte_cdk/sources/file_based/stream/default_file_based_stream.py

@@ -5,14 +5,17 @@
 import asyncio
 import itertools
 import traceback
+from collections import defaultdict
 from copy import deepcopy
 from functools import cache
-from typing import Any, Dict, Iterable, List, Mapping, MutableMapping, Optional, Set, Union
+from os import path
+from typing import Any, Dict, Iterable, List, Mapping, MutableMapping, Optional, Set, Tuple, Union
 
 from airbyte_cdk.models import AirbyteLogMessage, AirbyteMessage, FailureType, Level
 from airbyte_cdk.models import Type as MessageType
 from airbyte_cdk.sources.file_based.config.file_based_stream_config import PrimaryKeyType
 from airbyte_cdk.sources.file_based.exceptions import (
+    DuplicatedFilesError,
     FileBasedSourceError,
     InvalidSchemaError,
     MissingSchemaError,
@@ -43,6 +46,8 @@ class DefaultFileBasedStream(AbstractFileBasedStream, IncrementalMixin):
     """
 
     FILE_TRANSFER_KW = "use_file_transfer"
+    PRESERVE_DIRECTORY_STRUCTURE_KW = "preserve_directory_structure"
+    FILES_KEY = "files"
     DATE_TIME_FORMAT = "%Y-%m-%dT%H:%M:%S.%fZ"
     ab_last_mod_col = "_ab_source_file_last_modified"
     ab_file_name_col = "_ab_source_file_url"
@@ -50,10 +55,15 @@ class DefaultFileBasedStream(AbstractFileBasedStream, IncrementalMixin):
     source_file_url = "source_file_url"
     airbyte_columns = [ab_last_mod_col, ab_file_name_col]
     use_file_transfer = False
+    preserve_directory_structure = True
 
     def __init__(self, **kwargs: Any):
         if self.FILE_TRANSFER_KW in kwargs:
             self.use_file_transfer = kwargs.pop(self.FILE_TRANSFER_KW, False)
+        if self.PRESERVE_DIRECTORY_STRUCTURE_KW in kwargs:
+            self.preserve_directory_structure = kwargs.pop(
+                self.PRESERVE_DIRECTORY_STRUCTURE_KW, True
+            )
         super().__init__(**kwargs)
 
     @property
@@ -98,15 +108,33 @@ class DefaultFileBasedStream(AbstractFileBasedStream, IncrementalMixin):
         else:
             return super()._filter_schema_invalid_properties(configured_catalog_json_schema)
 
+    def _duplicated_files_names(
+        self, slices: List[dict[str, List[RemoteFile]]]
+    ) -> List[dict[str, List[str]]]:
+        seen_file_names: Dict[str, List[str]] = defaultdict(list)
+        for file_slice in slices:
+            for file_found in file_slice[self.FILES_KEY]:
+                file_name = path.basename(file_found.uri)
+                seen_file_names[file_name].append(file_found.uri)
+        return [
+            {file_name: paths} for file_name, paths in seen_file_names.items() if len(paths) > 1
+        ]
+
     def compute_slices(self) -> Iterable[Optional[Mapping[str, Any]]]:
         # Sort files by last_modified, uri and return them grouped by last_modified
         all_files = self.list_files()
         files_to_read = self._cursor.get_files_to_sync(all_files, self.logger)
         sorted_files_to_read = sorted(files_to_read, key=lambda f: (f.last_modified, f.uri))
         slices = [
-            {"files": list(group[1])}
+            {self.FILES_KEY: list(group[1])}
            for group in itertools.groupby(sorted_files_to_read, lambda f: f.last_modified)
         ]
+        if slices and not self.preserve_directory_structure:
+            duplicated_files_names = self._duplicated_files_names(slices)
+            if duplicated_files_names:
+                raise DuplicatedFilesError(
+                    stream=self.name, duplicated_files_names=duplicated_files_names
+                )
         return slices
 
     def transform_record(
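
A standalone sketch of the duplicate scan in `_duplicated_files_names`, using plain URI strings in place of `RemoteFile` objects:

```python
from collections import defaultdict
from os import path

slices = [
    {"files": ["/2023/report.csv", "/2023/other.csv"]},
    {"files": ["/2024/report.csv"]},
]

# Group full URIs by basename; any basename with more than one URI collides.
seen_file_names: dict[str, list[str]] = defaultdict(list)
for file_slice in slices:
    for uri in file_slice["files"]:
        seen_file_names[path.basename(uri)].append(uri)

duplicates = [{name: uris} for name, uris in seen_file_names.items() if len(uris) > 1]
print(duplicates)  # [{'report.csv': ['/2023/report.csv', '/2024/report.csv']}]
```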
airbyte_cdk/sources/streams/concurrent/cursor.py

@@ -196,9 +196,7 @@ class ConcurrentCursor(Cursor):
 
     @property
     def state(self) -> MutableMapping[str, Any]:
-        return self._connector_state_converter.convert_to_state_message(
-            self.cursor_field, self._concurrent_state
-        )
+        return self._concurrent_state
 
     @property
     def cursor_field(self) -> CursorField:
@@ -243,10 +241,10 @@ class ConcurrentCursor(Cursor):
         return self._connector_state_converter.parse_value(self._cursor_field.extract_value(record))
 
     def close_partition(self, partition: Partition) -> None:
-        slice_count_before = len(self._concurrent_state.get("slices", []))
+        slice_count_before = len(self.state.get("slices", []))
         self._add_slice_to_state(partition)
         if slice_count_before < len(
-            self._concurrent_state["slices"]
+            self.state["slices"]
         ):  # only emit if at least one slice has been processed
             self._merge_partitions()
             self._emit_state_message()
@@ -258,11 +256,11 @@ class ConcurrentCursor(Cursor):
         )
 
         if self._slice_boundary_fields:
-            if "slices" not in self._concurrent_state:
+            if "slices" not in self.state:
                 raise RuntimeError(
                     f"The state for stream {self._stream_name} should have at least one slice to delineate the sync start time, but no slices are present. This is unexpected. Please contact Support."
                 )
-            self._concurrent_state["slices"].append(
+            self.state["slices"].append(
                 {
                     self._connector_state_converter.START_KEY: self._extract_from_slice(
                         partition, self._slice_boundary_fields[self._START_BOUNDARY]
@@ -290,7 +288,7 @@ class ConcurrentCursor(Cursor):
                     "expected. Please contact the Airbyte team."
                 )
 
-            self._concurrent_state["slices"].append(
+            self.state["slices"].append(
                 {
                     self._connector_state_converter.START_KEY: self.start,
                     self._connector_state_converter.END_KEY: most_recent_cursor_value,
@@ -302,7 +300,9 @@ class ConcurrentCursor(Cursor):
         self._connector_state_manager.update_state_for_stream(
             self._stream_name,
             self._stream_namespace,
-            self.state,
+            self._connector_state_converter.convert_to_state_message(
+                self._cursor_field, self.state
+            ),
         )
         state_message = self._connector_state_manager.create_state_message(
             self._stream_name, self._stream_namespace
@@ -310,9 +310,7 @@ class ConcurrentCursor(Cursor):
         self._message_repository.emit_message(state_message)
 
     def _merge_partitions(self) -> None:
-        self._concurrent_state["slices"] = self._connector_state_converter.merge_intervals(
-            self._concurrent_state["slices"]
-        )
+        self.state["slices"] = self._connector_state_converter.merge_intervals(self.state["slices"])
 
     def _extract_from_slice(self, partition: Partition, key: str) -> CursorValueType:
         try:
@@ -349,42 +347,36 @@ class ConcurrentCursor(Cursor):
         if self._start is not None and self._is_start_before_first_slice():
             yield from self._split_per_slice_range(
                 self._start,
-                self._concurrent_state["slices"][0][self._connector_state_converter.START_KEY],
+                self.state["slices"][0][self._connector_state_converter.START_KEY],
                 False,
             )
 
-        if len(self._concurrent_state["slices"]) == 1:
+        if len(self.state["slices"]) == 1:
             yield from self._split_per_slice_range(
                 self._calculate_lower_boundary_of_last_slice(
-                    self._concurrent_state["slices"][0][self._connector_state_converter.END_KEY]
+                    self.state["slices"][0][self._connector_state_converter.END_KEY]
                 ),
                 self._end_provider(),
                 True,
             )
-        elif len(self._concurrent_state["slices"]) > 1:
-            for i in range(len(self._concurrent_state["slices"]) - 1):
+        elif len(self.state["slices"]) > 1:
+            for i in range(len(self.state["slices"]) - 1):
                 if self._cursor_granularity:
                     yield from self._split_per_slice_range(
-                        self._concurrent_state["slices"][i][self._connector_state_converter.END_KEY]
+                        self.state["slices"][i][self._connector_state_converter.END_KEY]
                         + self._cursor_granularity,
-                        self._concurrent_state["slices"][i + 1][
-                            self._connector_state_converter.START_KEY
-                        ],
+                        self.state["slices"][i + 1][self._connector_state_converter.START_KEY],
                         False,
                     )
                 else:
                     yield from self._split_per_slice_range(
-                        self._concurrent_state["slices"][i][
-                            self._connector_state_converter.END_KEY
-                        ],
-                        self._concurrent_state["slices"][i + 1][
-                            self._connector_state_converter.START_KEY
-                        ],
+                        self.state["slices"][i][self._connector_state_converter.END_KEY],
+                        self.state["slices"][i + 1][self._connector_state_converter.START_KEY],
                         False,
                     )
             yield from self._split_per_slice_range(
                 self._calculate_lower_boundary_of_last_slice(
-                    self._concurrent_state["slices"][-1][self._connector_state_converter.END_KEY]
+                    self.state["slices"][-1][self._connector_state_converter.END_KEY]
                 ),
                 self._end_provider(),
                 True,
@@ -395,8 +387,7 @@ class ConcurrentCursor(Cursor):
     def _is_start_before_first_slice(self) -> bool:
         return (
             self._start is not None
-            and self._start
-            < self._concurrent_state["slices"][0][self._connector_state_converter.START_KEY]
+            and self._start < self.state["slices"][0][self._connector_state_converter.START_KEY]
         )
 
     def _calculate_lower_boundary_of_last_slice(
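
The refactor makes `state` return the internal slice-interval representation directly; conversion to the connector state message now happens only inside `_emit_state_message()`. The `merge_intervals` call that `_merge_partitions` delegates to behaves roughly like this sketch (integer bounds and the `start`/`end` key names are illustrative, not the converter's actual constants):

```python
# Illustrative sketch, not the CDK implementation: overlapping or touching
# slices collapse into a single interval.
def merge_intervals(slices: list[dict[str, int]]) -> list[dict[str, int]]:
    merged: list[dict[str, int]] = []
    for s in sorted(slices, key=lambda s: s["start"]):
        if merged and s["start"] <= merged[-1]["end"]:
            merged[-1]["end"] = max(merged[-1]["end"], s["end"])
        else:
            merged.append(dict(s))
    return merged

print(merge_intervals([{"start": 0, "end": 5}, {"start": 4, "end": 9}]))
# [{'start': 0, 'end': 9}]
```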
airbyte_cdk/sources/streams/http/requests_native_auth/abstract_oauth.py

@@ -81,10 +81,10 @@ class AbstractOauth2Authenticator(AuthBase):
         Override to define additional parameters
         """
         payload: MutableMapping[str, Any] = {
-            "grant_type": self.get_grant_type(),
-            "client_id": self.get_client_id(),
-            "client_secret": self.get_client_secret(),
-            "refresh_token": self.get_refresh_token(),
+            self.get_grant_type_name(): self.get_grant_type(),
+            self.get_client_id_name(): self.get_client_id(),
+            self.get_client_secret_name(): self.get_client_secret(),
+            self.get_refresh_token_name(): self.get_refresh_token(),
         }
 
         if self.get_scopes():
@@ -98,6 +98,14 @@ class AbstractOauth2Authenticator(AuthBase):
 
         return payload
 
+    def build_refresh_request_headers(self) -> Mapping[str, Any] | None:
+        """
+        Returns the request headers to set on the refresh request
+
+        """
+        headers = self.get_refresh_request_headers()
+        return headers if headers else None
+
     def _wrap_refresh_token_exception(
         self, exception: requests.exceptions.RequestException
     ) -> bool:
@@ -128,6 +136,7 @@ class AbstractOauth2Authenticator(AuthBase):
             method="POST",
             url=self.get_token_refresh_endpoint(),  # type: ignore # returns None, if not provided, but str | bytes is expected.
             data=self.build_refresh_request_body(),
+            headers=self.build_refresh_request_headers(),
         )
         if response.ok:
             response_json = response.json()
@@ -206,14 +215,26 @@ class AbstractOauth2Authenticator(AuthBase):
     def get_token_refresh_endpoint(self) -> Optional[str]:
         """Returns the endpoint to refresh the access token"""
 
+    @abstractmethod
+    def get_client_id_name(self) -> str:
+        """The client id name to authenticate"""
+
     @abstractmethod
     def get_client_id(self) -> str:
         """The client id to authenticate"""
 
+    @abstractmethod
+    def get_client_secret_name(self) -> str:
+        """The client secret name to authenticate"""
+
     @abstractmethod
     def get_client_secret(self) -> str:
         """The client secret to authenticate"""
 
+    @abstractmethod
+    def get_refresh_token_name(self) -> str:
+        """The refresh token name to authenticate"""
+
     @abstractmethod
     def get_refresh_token(self) -> Optional[str]:
         """The token used to refresh the access token when it expires"""
@@ -242,10 +263,18 @@ class AbstractOauth2Authenticator(AuthBase):
     def get_refresh_request_body(self) -> Mapping[str, Any]:
         """Returns the request body to set on the refresh request"""
 
+    @abstractmethod
+    def get_refresh_request_headers(self) -> Mapping[str, Any]:
+        """Returns the request headers to set on the refresh request"""
+
     @abstractmethod
     def get_grant_type(self) -> str:
         """Returns grant_type specified for requesting access_token"""
 
+    @abstractmethod
+    def get_grant_type_name(self) -> str:
+        """Returns grant_type specified name for requesting access_token"""
+
     @property
     @abstractmethod
     def access_token(self) -> str:
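
With the new `*_name` hooks, `build_refresh_request_body` keys the payload by whatever names the subclass returns. An illustrative result for a provider that expects non-standard field names (all values are placeholders):

```python
# Payload shape when get_client_id_name() == "app_id" and
# get_client_secret_name() == "app_secret"; the other names keep defaults.
payload = {
    "grant_type": "refresh_token",    # get_grant_type_name(): get_grant_type()
    "app_id": "my-client-id",         # get_client_id_name(): get_client_id()
    "app_secret": "my-client-secret", # get_client_secret_name(): get_client_secret()
    "refresh_token": "my-refresh-token",
}
```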
airbyte_cdk/sources/streams/http/requests_native_auth/oauth.py

@@ -30,12 +30,17 @@ class Oauth2Authenticator(AbstractOauth2Authenticator):
         client_id: str,
         client_secret: str,
         refresh_token: str,
+        client_id_name: str = "client_id",
+        client_secret_name: str = "client_secret",
+        refresh_token_name: str = "refresh_token",
         scopes: List[str] | None = None,
         token_expiry_date: pendulum.DateTime | None = None,
         token_expiry_date_format: str | None = None,
         access_token_name: str = "access_token",
         expires_in_name: str = "expires_in",
         refresh_request_body: Mapping[str, Any] | None = None,
+        refresh_request_headers: Mapping[str, Any] | None = None,
+        grant_type_name: str = "grant_type",
         grant_type: str = "refresh_token",
         token_expiry_is_time_of_expiration: bool = False,
         refresh_token_error_status_codes: Tuple[int, ...] = (),
@@ -43,13 +48,18 @@ class Oauth2Authenticator(AbstractOauth2Authenticator):
         refresh_token_error_values: Tuple[str, ...] = (),
     ):
         self._token_refresh_endpoint = token_refresh_endpoint
+        self._client_secret_name = client_secret_name
         self._client_secret = client_secret
+        self._client_id_name = client_id_name
         self._client_id = client_id
+        self._refresh_token_name = refresh_token_name
         self._refresh_token = refresh_token
         self._scopes = scopes
         self._access_token_name = access_token_name
         self._expires_in_name = expires_in_name
         self._refresh_request_body = refresh_request_body
+        self._refresh_request_headers = refresh_request_headers
+        self._grant_type_name = grant_type_name
         self._grant_type = grant_type
 
         self._token_expiry_date = token_expiry_date or pendulum.now().subtract(days=1)  # type: ignore [no-untyped-call]
@@ -63,12 +73,21 @@ class Oauth2Authenticator(AbstractOauth2Authenticator):
     def get_token_refresh_endpoint(self) -> str:
         return self._token_refresh_endpoint
 
+    def get_client_id_name(self) -> str:
+        return self._client_id_name
+
     def get_client_id(self) -> str:
         return self._client_id
 
+    def get_client_secret_name(self) -> str:
+        return self._client_secret_name
+
     def get_client_secret(self) -> str:
         return self._client_secret
 
+    def get_refresh_token_name(self) -> str:
+        return self._refresh_token_name
+
     def get_refresh_token(self) -> str:
         return self._refresh_token
 
@@ -84,6 +103,12 @@ class Oauth2Authenticator(AbstractOauth2Authenticator):
     def get_refresh_request_body(self) -> Mapping[str, Any]:
         return self._refresh_request_body  # type: ignore [return-value]
 
+    def get_refresh_request_headers(self) -> Mapping[str, Any]:
+        return self._refresh_request_headers  # type: ignore [return-value]
+
+    def get_grant_type_name(self) -> str:
+        return self._grant_type_name
+
     def get_grant_type(self) -> str:
         return self._grant_type
 
@@ -129,8 +154,12 @@ class SingleUseRefreshTokenOauth2Authenticator(Oauth2Authenticator):
         expires_in_name: str = "expires_in",
         refresh_token_name: str = "refresh_token",
         refresh_request_body: Mapping[str, Any] | None = None,
+        refresh_request_headers: Mapping[str, Any] | None = None,
+        grant_type_name: str = "grant_type",
         grant_type: str = "refresh_token",
+        client_id_name: str = "client_id",
         client_id: Optional[str] = None,
+        client_secret_name: str = "client_secret",
         client_secret: Optional[str] = None,
         access_token_config_path: Sequence[str] = ("credentials", "access_token"),
         refresh_token_config_path: Sequence[str] = ("credentials", "refresh_token"),
@@ -151,6 +180,7 @@ class SingleUseRefreshTokenOauth2Authenticator(Oauth2Authenticator):
             expires_in_name (str, optional): Name of the name of the field that characterizes when the current access token will expire, used to parse the refresh token response. Defaults to "expires_in".
             refresh_token_name (str, optional): Name of the name of the refresh token field, used to parse the refresh token response. Defaults to "refresh_token".
             refresh_request_body (Mapping[str, Any], optional): Custom key value pair that will be added to the refresh token request body. Defaults to None.
+            refresh_request_headers (Mapping[str, Any], optional): Custom key value pair that will be added to the refresh token request headers. Defaults to None.
             grant_type (str, optional): OAuth grant type. Defaults to "refresh_token".
             client_id (Optional[str]): The client id to authenticate. If not specified, defaults to credentials.client_id in the config object.
             client_secret (Optional[str]): The client secret to authenticate. If not specified, defaults to credentials.client_secret in the config object.
@@ -174,23 +204,31 @@ class SingleUseRefreshTokenOauth2Authenticator(Oauth2Authenticator):
                 ("credentials", "client_secret"),
             )
         )
+        self._client_id_name = client_id_name
+        self._client_secret_name = client_secret_name
         self._access_token_config_path = access_token_config_path
         self._refresh_token_config_path = refresh_token_config_path
         self._token_expiry_date_config_path = token_expiry_date_config_path
         self._token_expiry_date_format = token_expiry_date_format
         self._refresh_token_name = refresh_token_name
+        self._grant_type_name = grant_type_name
         self._connector_config = connector_config
         self.__message_repository = message_repository
         super().__init__(
-            token_refresh_endpoint,
-            self.get_client_id(),
-            self.get_client_secret(),
-            self.get_refresh_token(),
+            token_refresh_endpoint=token_refresh_endpoint,
+            client_id_name=self._client_id_name,
+            client_id=self.get_client_id(),
+            client_secret_name=self._client_secret_name,
+            client_secret=self.get_client_secret(),
+            refresh_token=self.get_refresh_token(),
+            refresh_token_name=self._refresh_token_name,
             scopes=scopes,
             token_expiry_date=self.get_token_expiry_date(),
             access_token_name=access_token_name,
            expires_in_name=expires_in_name,
             refresh_request_body=refresh_request_body,
+            refresh_request_headers=refresh_request_headers,
+            grant_type_name=self._grant_type_name,
             grant_type=grant_type,
             token_expiry_date_format=token_expiry_date_format,
             token_expiry_is_time_of_expiration=token_expiry_is_time_of_expiration,
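
Putting the new constructor parameters together in a hedged sketch; the endpoint, credentials, and header below are placeholders, and only the parameters shown in this diff are assumed to exist:

```python
from airbyte_cdk.sources.streams.http.requests_native_auth import Oauth2Authenticator

authenticator = Oauth2Authenticator(
    token_refresh_endpoint="https://api.example.com/oauth/token",
    client_id="my-client-id",
    client_secret="my-client-secret",
    refresh_token="my-refresh-token",
    client_id_name="app_id",          # renames the form field, default "client_id"
    client_secret_name="app_secret",  # default "client_secret"
    refresh_request_headers={"X-Custom-Header": "value"},  # sent on token refresh
)
```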
airbyte_cdk/sources/types.py

@@ -152,3 +152,6 @@ class StreamSlice(Mapping[str, Any]):
 
     def __hash__(self) -> int:
         return hash(orjson.dumps(self._stream_slice, option=orjson.OPT_SORT_KEYS))
+
+    def __bool__(self) -> bool:
+        return bool(self._stream_slice) or bool(self._extra_fields)
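
A quick check of the new truthiness semantics, assuming the keyword-only `partition`/`cursor_slice` constructor that `StreamSlice` exposes:

```python
from airbyte_cdk.sources.types import StreamSlice

# An empty slice is now falsy even though it remains a valid Mapping.
empty = StreamSlice(partition={}, cursor_slice={})
non_empty = StreamSlice(partition={"id": 1}, cursor_slice={})

print(bool(empty), bool(non_empty))  # False True
```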
airbyte_cdk/sources/utils/transform.py

@@ -9,6 +9,7 @@ from typing import Any, Callable, Dict, Generator, Mapping, Optional, cast
 
 from jsonschema import Draft7Validator, RefResolver, ValidationError, Validator, validators
 
+MAX_NESTING_DEPTH = 3
 json_to_python_simple = {
     "string": str,
     "number": float,
@@ -225,6 +226,31 @@ class TypeTransformer:
                 logger.warning(self.get_error_message(e))
 
     def get_error_message(self, e: ValidationError) -> str:
-        instance_json_type = python_to_json[type(e.instance)]
-        key_path = "." + ".".join(map(str, e.path))
-        return f"Failed to transform value {repr(e.instance)} of type '{instance_json_type}' to '{e.validator_value}', key path: '{key_path}'"
+        """
+        Construct a sanitized error message from a ValidationError instance.
+        """
+        field_path = ".".join(map(str, e.path))
+        type_structure = self._get_type_structure(e.instance)
+
+        return f"Failed to transform value from type '{type_structure}' to type '{e.validator_value}' at path: '{field_path}'"
+
+    def _get_type_structure(self, input_data: Any, current_depth: int = 0) -> Any:
+        """
+        Get the structure of a given input data for use in error message construction.
+        """
+        # Handle null values
+        if input_data is None:
+            return "null"
+
+        # Avoid recursing too deep
+        if current_depth >= MAX_NESTING_DEPTH:
+            return "object" if isinstance(input_data, dict) else python_to_json[type(input_data)]
+
+        if isinstance(input_data, dict):
+            return {
+                key: self._get_type_structure(field_value, current_depth + 1)
+                for key, field_value in input_data.items()
+            }
+
+        else:
+            return python_to_json[type(input_data)]
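
The change replaces raw record values in error messages with a type summary. A standalone sketch of that traversal; the real `python_to_json` table in the module covers more types than this stub:

```python
# Standalone sketch: values are replaced by their JSON type names,
# recursing at most MAX_NESTING_DEPTH levels.
python_to_json = {str: "string", int: "number", float: "number", bool: "boolean", list: "array"}
MAX_NESTING_DEPTH = 3

def get_type_structure(data, depth=0):
    if data is None:
        return "null"
    if depth >= MAX_NESTING_DEPTH:
        return "object" if isinstance(data, dict) else python_to_json[type(data)]
    if isinstance(data, dict):
        return {k: get_type_structure(v, depth + 1) for k, v in data.items()}
    return python_to_json[type(data)]

print(get_type_structure({"user": {"age": "42", "tags": ["a"]}}))
# {'user': {'age': 'string', 'tags': 'array'}}
```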