airbyte-cdk 6.45.4__py3-none-any.whl → 6.45.4.post13.dev14543731901__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (42)
  1. airbyte_cdk/models/__init__.py +1 -0
  2. airbyte_cdk/models/airbyte_protocol.py +1 -3
  3. airbyte_cdk/sources/concurrent_source/concurrent_read_processor.py +1 -1
  4. airbyte_cdk/sources/declarative/auth/oauth.py +2 -2
  5. airbyte_cdk/sources/declarative/concurrent_declarative_source.py +8 -0
  6. airbyte_cdk/sources/declarative/declarative_component_schema.yaml +36 -0
  7. airbyte_cdk/sources/declarative/extractors/record_selector.py +6 -1
  8. airbyte_cdk/sources/declarative/models/declarative_component_schema.py +31 -0
  9. airbyte_cdk/sources/declarative/parsers/model_to_component_factory.py +39 -1
  10. airbyte_cdk/sources/declarative/retrievers/file_uploader.py +93 -0
  11. airbyte_cdk/sources/declarative/stream_slicers/declarative_partition_generator.py +9 -4
  12. airbyte_cdk/sources/file_based/file_based_stream_reader.py +41 -16
  13. airbyte_cdk/sources/file_based/file_record_data.py +23 -0
  14. airbyte_cdk/sources/file_based/file_types/file_transfer.py +8 -15
  15. airbyte_cdk/sources/file_based/schema_helpers.py +11 -1
  16. airbyte_cdk/sources/file_based/stream/concurrent/adapters.py +3 -12
  17. airbyte_cdk/sources/file_based/stream/default_file_based_stream.py +15 -38
  18. airbyte_cdk/sources/file_based/stream/permissions_file_based_stream.py +1 -3
  19. airbyte_cdk/sources/streams/concurrent/default_stream.py +3 -0
  20. airbyte_cdk/sources/streams/http/requests_native_auth/abstract_oauth.py +28 -11
  21. airbyte_cdk/sources/streams/http/requests_native_auth/oauth.py +4 -27
  22. airbyte_cdk/sources/types.py +11 -2
  23. airbyte_cdk/sources/utils/files_directory.py +15 -0
  24. airbyte_cdk/sources/utils/record_helper.py +8 -8
  25. airbyte_cdk/test/entrypoint_wrapper.py +4 -0
  26. airbyte_cdk/test/mock_http/response_builder.py +8 -0
  27. airbyte_cdk/test/standard_tests/__init__.py +46 -0
  28. airbyte_cdk/test/standard_tests/_job_runner.py +159 -0
  29. airbyte_cdk/test/standard_tests/connector_base.py +148 -0
  30. airbyte_cdk/test/standard_tests/declarative_sources.py +92 -0
  31. airbyte_cdk/test/standard_tests/destination_base.py +16 -0
  32. airbyte_cdk/test/standard_tests/models/__init__.py +7 -0
  33. airbyte_cdk/test/standard_tests/models/scenario.py +74 -0
  34. airbyte_cdk/test/standard_tests/pytest_hooks.py +61 -0
  35. airbyte_cdk/test/standard_tests/source_base.py +140 -0
  36. {airbyte_cdk-6.45.4.dist-info → airbyte_cdk-6.45.4.post13.dev14543731901.dist-info}/METADATA +3 -2
  37. {airbyte_cdk-6.45.4.dist-info → airbyte_cdk-6.45.4.post13.dev14543731901.dist-info}/RECORD +41 -30
  38. airbyte_cdk/models/file_transfer_record_message.py +0 -13
  39. {airbyte_cdk-6.45.4.dist-info → airbyte_cdk-6.45.4.post13.dev14543731901.dist-info}/LICENSE.txt +0 -0
  40. {airbyte_cdk-6.45.4.dist-info → airbyte_cdk-6.45.4.post13.dev14543731901.dist-info}/LICENSE_SHORT +0 -0
  41. {airbyte_cdk-6.45.4.dist-info → airbyte_cdk-6.45.4.post13.dev14543731901.dist-info}/WHEEL +0 -0
  42. {airbyte_cdk-6.45.4.dist-info → airbyte_cdk-6.45.4.post13.dev14543731901.dist-info}/entry_points.txt +0 -0
@@ -4,7 +4,7 @@
4
4
 
5
5
  import copy
6
6
  import logging
7
- from functools import cache, lru_cache
7
+ from functools import lru_cache
8
8
  from typing import TYPE_CHECKING, Any, Iterable, List, Mapping, MutableMapping, Optional, Union
9
9
 
10
10
  from typing_extensions import deprecated
@@ -258,19 +258,14 @@ class FileBasedStreamPartition(Partition):
258
258
  and record_data.record is not None
259
259
  ):
260
260
  # `AirbyteMessage`s of type `Record` should also be yielded so they are enqueued
261
- # If stream is flagged for file_transfer the record should data in file key
262
- record_message_data = (
263
- record_data.record.file
264
- if self._use_file_transfer()
265
- else record_data.record.data
266
- )
261
+ record_message_data = record_data.record.data
267
262
  if not record_message_data:
268
263
  raise ExceptionWithDisplayMessage("A record without data was found")
269
264
  else:
270
265
  yield Record(
271
266
  data=record_message_data,
272
267
  stream_name=self.stream_name(),
273
- is_file_transfer_message=self._use_file_transfer(),
268
+ file_reference=record_data.record.file_reference,
274
269
  )
275
270
  else:
276
271
  self._message_repository.emit_message(record_data)
@@ -306,10 +301,6 @@ class FileBasedStreamPartition(Partition):
306
301
  def stream_name(self) -> str:
307
302
  return self._stream.name
308
303
 
309
- @cache
310
- def _use_file_transfer(self) -> bool:
311
- return hasattr(self._stream, "use_file_transfer") and self._stream.use_file_transfer
312
-
313
304
  def __repr__(self) -> str:
314
305
  return f"FileBasedStreamPartition({self._stream.name}, {self._slice})"
315
306
 
@@ -11,7 +11,7 @@ from functools import cache
11
11
  from os import path
12
12
  from typing import Any, Dict, Iterable, List, Mapping, MutableMapping, Optional, Set, Tuple, Union
13
13
 
14
- from airbyte_cdk.models import AirbyteLogMessage, AirbyteMessage, FailureType, Level
14
+ from airbyte_cdk.models import AirbyteLogMessage, AirbyteMessage, AirbyteStream, FailureType, Level
15
15
  from airbyte_cdk.models import Type as MessageType
16
16
  from airbyte_cdk.sources.file_based.config.file_based_stream_config import PrimaryKeyType
17
17
  from airbyte_cdk.sources.file_based.exceptions import (
@@ -56,6 +56,7 @@ class DefaultFileBasedStream(AbstractFileBasedStream, IncrementalMixin):
56
56
  airbyte_columns = [ab_last_mod_col, ab_file_name_col]
57
57
  use_file_transfer = False
58
58
  preserve_directory_structure = True
59
+ _file_transfer = FileTransfer()
59
60
 
60
61
  def __init__(self, **kwargs: Any):
61
62
  if self.FILE_TRANSFER_KW in kwargs:
@@ -93,21 +94,6 @@ class DefaultFileBasedStream(AbstractFileBasedStream, IncrementalMixin):
93
94
  self.config
94
95
  )
95
96
 
96
- def _filter_schema_invalid_properties(
97
- self, configured_catalog_json_schema: Dict[str, Any]
98
- ) -> Dict[str, Any]:
99
- if self.use_file_transfer:
100
- return {
101
- "type": "object",
102
- "properties": {
103
- "file_path": {"type": "string"},
104
- "file_size": {"type": "string"},
105
- self.ab_file_name_col: {"type": "string"},
106
- },
107
- }
108
- else:
109
- return super()._filter_schema_invalid_properties(configured_catalog_json_schema)
110
-
111
97
  def _duplicated_files_names(
112
98
  self, slices: List[dict[str, List[RemoteFile]]]
113
99
  ) -> List[dict[str, List[str]]]:
@@ -145,14 +131,6 @@ class DefaultFileBasedStream(AbstractFileBasedStream, IncrementalMixin):
145
131
  record[self.ab_file_name_col] = file.uri
146
132
  return record
147
133
 
148
- def transform_record_for_file_transfer(
149
- self, record: dict[str, Any], file: RemoteFile
150
- ) -> dict[str, Any]:
151
- # timstamp() returns a float representing the number of seconds since the unix epoch
152
- record[self.modified] = int(file.last_modified.timestamp()) * 1000
153
- record[self.source_file_url] = file.uri
154
- return record
155
-
156
134
  def read_records_from_slice(self, stream_slice: StreamSlice) -> Iterable[AirbyteMessage]:
157
135
  """
158
136
  Yield all records from all remote files in `list_files_for_this_sync`.
@@ -173,19 +151,13 @@ class DefaultFileBasedStream(AbstractFileBasedStream, IncrementalMixin):
173
151
 
174
152
  try:
175
153
  if self.use_file_transfer:
176
- self.logger.info(f"{self.name}: {file} file-based syncing")
177
- # todo: complete here the code to not rely on local parser
178
- file_transfer = FileTransfer()
179
- for record in file_transfer.get_file(
180
- self.config, file, self.stream_reader, self.logger
154
+ for file_record_data, file_reference in self._file_transfer.upload(
155
+ file=file, stream_reader=self.stream_reader, logger=self.logger
181
156
  ):
182
- line_no += 1
183
- if not self.record_passes_validation_policy(record):
184
- n_skipped += 1
185
- continue
186
- record = self.transform_record_for_file_transfer(record, file)
187
157
  yield stream_data_to_airbyte_message(
188
- self.name, record, is_file_transfer_message=True
158
+ self.name,
159
+ file_record_data.dict(exclude_none=True),
160
+ file_reference=file_reference,
189
161
  )
190
162
  else:
191
163
  for record in parser.parse_records(
@@ -259,6 +231,8 @@ class DefaultFileBasedStream(AbstractFileBasedStream, IncrementalMixin):
259
231
 
260
232
  @cache
261
233
  def get_json_schema(self) -> JsonSchema:
234
+ if self.use_file_transfer:
235
+ return file_transfer_schema
262
236
  extra_fields = {
263
237
  self.ab_last_mod_col: {"type": "string"},
264
238
  self.ab_file_name_col: {"type": "string"},
@@ -282,9 +256,7 @@ class DefaultFileBasedStream(AbstractFileBasedStream, IncrementalMixin):
282
256
  return {"type": "object", "properties": {**extra_fields, **schema["properties"]}}
283
257
 
284
258
  def _get_raw_json_schema(self) -> JsonSchema:
285
- if self.use_file_transfer:
286
- return file_transfer_schema
287
- elif self.config.input_schema:
259
+ if self.config.input_schema:
288
260
  return self.config.get_input_schema() # type: ignore
289
261
  elif self.config.schemaless:
290
262
  return schemaless_schema
@@ -341,6 +313,11 @@ class DefaultFileBasedStream(AbstractFileBasedStream, IncrementalMixin):
341
313
  self.config.globs or [], self.config.legacy_prefix, self.logger
342
314
  )
343
315
 
316
+ def as_airbyte_stream(self) -> AirbyteStream:
317
+ file_stream = super().as_airbyte_stream()
318
+ file_stream.is_file_based = self.use_file_transfer
319
+ return file_stream
320
+
344
321
  def infer_schema(self, files: List[RemoteFile]) -> Mapping[str, Any]:
345
322
  loop = asyncio.get_event_loop()
346
323
  schema = loop.run_until_complete(self._infer_schema(files))
@@ -61,9 +61,7 @@ class PermissionsFileBasedStream(DefaultFileBasedStream):
61
61
  permissions_record = self.transform_record(
62
62
  permissions_record, file, file_datetime_string
63
63
  )
64
- yield stream_data_to_airbyte_message(
65
- self.name, permissions_record, is_file_transfer_message=False
66
- )
64
+ yield stream_data_to_airbyte_message(self.name, permissions_record)
67
65
  except Exception as e:
68
66
  self.logger.error(f"Failed to retrieve permissions for file {file.uri}: {str(e)}")
69
67
  yield AirbyteMessage(
@@ -29,6 +29,7 @@ class DefaultStream(AbstractStream):
29
29
  logger: Logger,
30
30
  cursor: Cursor,
31
31
  namespace: Optional[str] = None,
32
+ supports_file_transfer: bool = False,
32
33
  ) -> None:
33
34
  self._stream_partition_generator = partition_generator
34
35
  self._name = name
@@ -39,6 +40,7 @@ class DefaultStream(AbstractStream):
39
40
  self._logger = logger
40
41
  self._cursor = cursor
41
42
  self._namespace = namespace
43
+ self._supports_file_transfer = supports_file_transfer
42
44
 
43
45
  def generate_partitions(self) -> Iterable[Partition]:
44
46
  yield from self._stream_partition_generator.generate()
@@ -68,6 +70,7 @@ class DefaultStream(AbstractStream):
68
70
  json_schema=dict(self._json_schema),
69
71
  supported_sync_modes=[SyncMode.full_refresh],
70
72
  is_resumable=False,
73
+ is_file_based=self._supports_file_transfer,
71
74
  )
72
75
 
73
76
  if self._namespace:
@@ -130,7 +130,7 @@ class AbstractOauth2Authenticator(AuthBase):
130
130
  headers = self.get_refresh_request_headers()
131
131
  return headers if headers else None
132
132
 
133
- def refresh_access_token(self) -> Tuple[str, Union[str, int]]:
133
+ def refresh_access_token(self) -> Tuple[str, AirbyteDateTime]:
134
134
  """
135
135
  Returns the refresh token and its expiration datetime
136
136
 
@@ -148,6 +148,14 @@ class AbstractOauth2Authenticator(AuthBase):
148
148
  # PRIVATE METHODS
149
149
  # ----------------
150
150
 
151
+ def _default_token_expiry_date(self) -> AirbyteDateTime:
152
+ """
153
+ Returns the default token expiry date
154
+ """
155
+ # 1 hour was chosen as a middle ground to avoid unnecessary frequent refreshes and token expiration
156
+ default_token_expiry_duration_hours = 1 # 1 hour
157
+ return ab_datetime_now() + timedelta(hours=default_token_expiry_duration_hours)
158
+
151
159
  def _wrap_refresh_token_exception(
152
160
  self, exception: requests.exceptions.RequestException
153
161
  ) -> bool:
@@ -257,14 +265,10 @@ class AbstractOauth2Authenticator(AuthBase):
257
265
 
258
266
  def _parse_token_expiration_date(self, value: Union[str, int]) -> AirbyteDateTime:
259
267
  """
260
- Return the expiration datetime of the refresh token
268
+ Parse a string or integer token expiration date into a datetime object
261
269
 
262
270
  :return: expiration datetime
263
271
  """
264
- if not value and not self.token_has_expired():
265
- # No expiry token was provided but the previous one is not expired so it's fine
266
- return self.get_token_expiry_date()
267
-
268
272
  if self.token_expiry_is_time_of_expiration:
269
273
  if not self.token_expiry_date_format:
270
274
  raise ValueError(
@@ -308,17 +312,30 @@ class AbstractOauth2Authenticator(AuthBase):
308
312
  """
309
313
  return self._find_and_get_value_from_response(response_data, self.get_refresh_token_name())
310
314
 
311
- def _extract_token_expiry_date(self, response_data: Mapping[str, Any]) -> Any:
315
+ def _extract_token_expiry_date(self, response_data: Mapping[str, Any]) -> AirbyteDateTime:
312
316
  """
313
317
  Extracts the token_expiry_date, like `expires_in` or `expires_at`, etc from the given response data.
314
318
 
319
+ If the token_expiry_date is not found, it will return an existing token expiry date if set, or a default token expiry date.
320
+
315
321
  Args:
316
322
  response_data (Mapping[str, Any]): The response data from which to extract the token_expiry_date.
317
323
 
318
324
  Returns:
319
- str: The extracted token_expiry_date.
325
+ The parsed token_expiry_date, or a fallback (the existing unexpired expiry date, else a default expiry date) if none is found.
320
326
  """
321
- return self._find_and_get_value_from_response(response_data, self.get_expires_in_name())
327
+ expires_in = self._find_and_get_value_from_response(
328
+ response_data, self.get_expires_in_name()
329
+ )
330
+ if expires_in is not None:
331
+ return self._parse_token_expiration_date(expires_in)
332
+
333
+ # expires_in is None
334
+ existing_expiry_date = self.get_token_expiry_date()
335
+ if existing_expiry_date and not self.token_has_expired():
336
+ return existing_expiry_date
337
+
338
+ return self._default_token_expiry_date()
322
339
 
323
340
  def _find_and_get_value_from_response(
324
341
  self,
@@ -344,7 +361,7 @@ class AbstractOauth2Authenticator(AuthBase):
344
361
  """
345
362
  if current_depth > max_depth:
346
363
  # this is needed to avoid an inf loop, possible with a very deep nesting observed.
347
- message = f"The maximum level of recursion is reached. Couldn't find the speficied `{key_name}` in the response."
364
+ message = f"The maximum level of recursion is reached. Couldn't find the specified `{key_name}` in the response."
348
365
  raise ResponseKeysMaxRecurtionReached(
349
366
  internal_message=message, message=message, failure_type=FailureType.config_error
350
367
  )
@@ -441,7 +458,7 @@ class AbstractOauth2Authenticator(AuthBase):
441
458
  """Expiration date of the access token"""
442
459
 
443
460
  @abstractmethod
444
- def set_token_expiry_date(self, value: Union[str, int]) -> None:
461
+ def set_token_expiry_date(self, value: AirbyteDateTime) -> None:
445
462
  """Setter for access token expiration date"""
446
463
 
447
464
  @abstractmethod
@@ -120,8 +120,8 @@ class Oauth2Authenticator(AbstractOauth2Authenticator):
120
120
  def get_token_expiry_date(self) -> AirbyteDateTime:
121
121
  return self._token_expiry_date
122
122
 
123
- def set_token_expiry_date(self, value: Union[str, int]) -> None:
124
- self._token_expiry_date = self._parse_token_expiration_date(value)
123
+ def set_token_expiry_date(self, value: AirbyteDateTime) -> None:
124
+ self._token_expiry_date = value
125
125
 
126
126
  @property
127
127
  def token_expiry_is_time_of_expiration(self) -> bool:
@@ -316,26 +316,6 @@ class SingleUseRefreshTokenOauth2Authenticator(Oauth2Authenticator):
316
316
  """Returns True if the token is expired"""
317
317
  return ab_datetime_now() > self.get_token_expiry_date()
318
318
 
319
- @staticmethod
320
- def get_new_token_expiry_date(
321
- access_token_expires_in: str,
322
- token_expiry_date_format: str | None = None,
323
- ) -> AirbyteDateTime:
324
- """
325
- Calculate the new token expiry date based on the provided expiration duration or format.
326
-
327
- Args:
328
- access_token_expires_in (str): The duration (in seconds) until the access token expires, or the expiry date in a specific format.
329
- token_expiry_date_format (str | None, optional): The format of the expiry date if provided. Defaults to None.
330
-
331
- Returns:
332
- AirbyteDateTime: The calculated expiry date of the access token.
333
- """
334
- if token_expiry_date_format:
335
- return ab_datetime_parse(access_token_expires_in)
336
- else:
337
- return ab_datetime_now() + timedelta(seconds=int(access_token_expires_in))
338
-
339
319
  def get_access_token(self) -> str:
340
320
  """Retrieve new access and refresh token if the access token has expired.
341
321
  The new refresh token is persisted with the set_refresh_token function
@@ -346,16 +326,13 @@ class SingleUseRefreshTokenOauth2Authenticator(Oauth2Authenticator):
346
326
  new_access_token, access_token_expires_in, new_refresh_token = (
347
327
  self.refresh_access_token()
348
328
  )
349
- new_token_expiry_date: AirbyteDateTime = self.get_new_token_expiry_date(
350
- access_token_expires_in, self._token_expiry_date_format
351
- )
352
329
  self.access_token = new_access_token
353
330
  self.set_refresh_token(new_refresh_token)
354
- self.set_token_expiry_date(new_token_expiry_date)
331
+ self.set_token_expiry_date(access_token_expires_in)
355
332
  self._emit_control_message()
356
333
  return self.access_token
357
334
 
358
- def refresh_access_token(self) -> Tuple[str, str, str]: # type: ignore[override]
335
+ def refresh_access_token(self) -> Tuple[str, AirbyteDateTime, str]: # type: ignore[override]
359
336
  """
360
337
  Refreshes the access token by making a handled request and extracting the necessary token information.
361
338
 
@@ -6,6 +6,7 @@ from __future__ import annotations
6
6
 
7
7
  from typing import Any, ItemsView, Iterator, KeysView, List, Mapping, Optional, ValuesView
8
8
 
9
+ from airbyte_cdk.models import AirbyteRecordMessageFileReference
9
10
  from airbyte_cdk.utils.slice_hasher import SliceHasher
10
11
 
11
12
  # A FieldPointer designates a path to a field inside a mapping. For example, retrieving ["k1", "k1.2"] in the object {"k1" :{"k1.2":
@@ -23,12 +24,12 @@ class Record(Mapping[str, Any]):
23
24
  data: Mapping[str, Any],
24
25
  stream_name: str,
25
26
  associated_slice: Optional[StreamSlice] = None,
26
- is_file_transfer_message: bool = False,
27
+ file_reference: Optional[AirbyteRecordMessageFileReference] = None,
27
28
  ):
28
29
  self._data = data
29
30
  self._associated_slice = associated_slice
30
31
  self.stream_name = stream_name
31
- self.is_file_transfer_message = is_file_transfer_message
32
+ self._file_reference = file_reference
32
33
 
33
34
  @property
34
35
  def data(self) -> Mapping[str, Any]:
@@ -38,6 +39,14 @@ class Record(Mapping[str, Any]):
38
39
  def associated_slice(self) -> Optional[StreamSlice]:
39
40
  return self._associated_slice
40
41
 
42
+ @property
43
+ def file_reference(self) -> AirbyteRecordMessageFileReference:
44
+ return self._file_reference
45
+
46
+ @file_reference.setter
47
+ def file_reference(self, value: AirbyteRecordMessageFileReference) -> None:
48
+ self._file_reference = value
49
+
41
50
  def __repr__(self) -> str:
42
51
  return repr(self._data)
43
52
 
@@ -0,0 +1,15 @@
1
+ #
2
+ # Copyright (c) 2025 Airbyte, Inc., all rights reserved.
3
+ #
4
+ import os
5
+
6
+ AIRBYTE_STAGING_DIRECTORY = os.getenv("AIRBYTE_STAGING_DIRECTORY", "/staging/files")
7
+ DEFAULT_LOCAL_DIRECTORY = "/tmp/airbyte-file-transfer"
8
+
9
+
10
+ def get_files_directory() -> str:
11
+ return (
12
+ AIRBYTE_STAGING_DIRECTORY
13
+ if os.path.exists(AIRBYTE_STAGING_DIRECTORY)
14
+ else DEFAULT_LOCAL_DIRECTORY
15
+ )
@@ -9,10 +9,10 @@ from airbyte_cdk.models import (
9
9
  AirbyteLogMessage,
10
10
  AirbyteMessage,
11
11
  AirbyteRecordMessage,
12
+ AirbyteRecordMessageFileReference,
12
13
  AirbyteTraceMessage,
13
14
  )
14
15
  from airbyte_cdk.models import Type as MessageType
15
- from airbyte_cdk.models.file_transfer_record_message import AirbyteFileTransferRecordMessage
16
16
  from airbyte_cdk.sources.streams.core import StreamData
17
17
  from airbyte_cdk.sources.utils.transform import TransformConfig, TypeTransformer
18
18
 
@@ -22,7 +22,7 @@ def stream_data_to_airbyte_message(
22
22
  data_or_message: StreamData,
23
23
  transformer: TypeTransformer = TypeTransformer(TransformConfig.NoTransform),
24
24
  schema: Optional[Mapping[str, Any]] = None,
25
- is_file_transfer_message: bool = False,
25
+ file_reference: Optional[AirbyteRecordMessageFileReference] = None,
26
26
  ) -> AirbyteMessage:
27
27
  if schema is None:
28
28
  schema = {}
@@ -36,12 +36,12 @@ def stream_data_to_airbyte_message(
36
36
  # taken unless configured. See
37
37
  # docs/connector-development/cdk-python/schemas.md for details.
38
38
  transformer.transform(data, schema)
39
- if is_file_transfer_message:
40
- message = AirbyteFileTransferRecordMessage(
41
- stream=stream_name, file=data, emitted_at=now_millis, data={}
42
- )
43
- else:
44
- message = AirbyteRecordMessage(stream=stream_name, data=data, emitted_at=now_millis)
39
+ message = AirbyteRecordMessage(
40
+ stream=stream_name,
41
+ data=data,
42
+ emitted_at=now_millis,
43
+ file_reference=file_reference,
44
+ )
45
45
  return AirbyteMessage(type=MessageType.RECORD, record=message)
46
46
  case AirbyteTraceMessage():
47
47
  return AirbyteMessage(type=MessageType.TRACE, trace=data_or_message)
@@ -82,6 +82,10 @@ class EntrypointOutput:
82
82
  def state_messages(self) -> List[AirbyteMessage]:
83
83
  return self._get_message_by_types([Type.STATE])
84
84
 
85
+ @property
86
+ def connection_status_messages(self) -> List[AirbyteMessage]:
87
+ return self._get_message_by_types([Type.CONNECTION_STATUS])
88
+
85
89
  @property
86
90
  def most_recent_state(self) -> Any:
87
91
  state_messages = self._get_message_by_types([Type.STATE])
@@ -198,6 +198,14 @@ def find_template(resource: str, execution_folder: str) -> Dict[str, Any]:
198
198
  return json.load(template_file) # type: ignore # we assume the dev correctly set up the resource file
199
199
 
200
200
 
201
+ def find_binary_response(resource: str, execution_folder: str) -> bytes:
202
+ response_filepath = str(
203
+ get_unit_test_folder(execution_folder) / "resource" / "http" / "response" / f"{resource}"
204
+ )
205
+ with open(response_filepath, "rb") as response_file:
206
+ return response_file.read() # type: ignore # we assume the dev correctly set up the resource file
207
+
208
+
201
209
  def create_record_builder(
202
210
  response_template: Dict[str, Any],
203
211
  records_path: Union[FieldPath, NestedPath],
@@ -0,0 +1,46 @@
1
+ # Copyright (c) 2024 Airbyte, Inc., all rights reserved.
2
+ '''FAST Airbyte Standard Tests
3
+
4
+ This module provides a set of base classes for declarative connector test suites.
5
+ The goal of this module is to provide a robust and extensible framework for testing Airbyte
6
+ connectors.
7
+
8
+ Example usage:
9
+
10
+ ```python
11
+ # `test_airbyte_standards.py`
12
+ from airbyte_cdk.test import standard_tests
13
+
14
+ pytest_plugins = [
15
+ "airbyte_cdk.test.standard_tests.pytest_hooks",
16
+ ]
17
+
18
+
19
+ class TestSuiteSourcePokeAPI(standard_tests.DeclarativeSourceTestSuite):
20
+ """Test suite for the source."""
21
+ ```
22
+
23
+ Available test suites base classes:
24
+ - `DeclarativeSourceTestSuite`: A test suite for declarative sources.
25
+ - `SourceTestSuiteBase`: A test suite for sources.
26
+ - `DestinationTestSuiteBase`: A test suite for destinations.
27
+
28
+ '''
29
+
30
+ from airbyte_cdk.test.standard_tests.connector_base import (
31
+ ConnectorTestScenario,
32
+ ConnectorTestSuiteBase,
33
+ )
34
+ from airbyte_cdk.test.standard_tests.declarative_sources import (
35
+ DeclarativeSourceTestSuite,
36
+ )
37
+ from airbyte_cdk.test.standard_tests.destination_base import DestinationTestSuiteBase
38
+ from airbyte_cdk.test.standard_tests.source_base import SourceTestSuiteBase
39
+
40
+ __all__ = [
41
+ "ConnectorTestScenario",
42
+ "ConnectorTestSuiteBase",
43
+ "DeclarativeSourceTestSuite",
44
+ "DestinationTestSuiteBase",
45
+ "SourceTestSuiteBase",
46
+ ]
@@ -0,0 +1,159 @@
1
+ # Copyright (c) 2025 Airbyte, Inc., all rights reserved.
2
+ """Job runner for Airbyte Standard Tests."""
3
+
4
+ import logging
5
+ import tempfile
6
+ import uuid
7
+ from dataclasses import asdict
8
+ from pathlib import Path
9
+ from typing import Any, Callable, Literal
10
+
11
+ import orjson
12
+ from typing_extensions import Protocol, runtime_checkable
13
+
14
+ from airbyte_cdk.models import (
15
+ ConfiguredAirbyteCatalog,
16
+ Status,
17
+ )
18
+ from airbyte_cdk.test import entrypoint_wrapper
19
+ from airbyte_cdk.test.standard_tests.models import (
20
+ ConnectorTestScenario,
21
+ )
22
+
23
+
24
+ def _errors_to_str(
25
+ entrypoint_output: entrypoint_wrapper.EntrypointOutput,
26
+ ) -> str:
27
+ """Convert errors from entrypoint output to a string."""
28
+ if not entrypoint_output.errors:
29
+ # If there are no errors, return an empty string.
30
+ return ""
31
+
32
+ return "\n" + "\n".join(
33
+ [
34
+ str(error.trace.error).replace(
35
+ "\\n",
36
+ "\n",
37
+ )
38
+ for error in entrypoint_output.errors
39
+ if error.trace
40
+ ],
41
+ )
42
+
43
+
44
+ @runtime_checkable
45
+ class IConnector(Protocol):
46
+ """A connector that can be run in a test scenario.
47
+
48
+ Note: We currently use 'spec' to determine if we have a connector object.
49
+ In the future, it would be preferred to leverage a 'launch' method instead,
50
+ directly on the connector (which doesn't yet exist).
51
+ """
52
+
53
+ def spec(self, logger: logging.Logger) -> Any:
54
+ """Connectors should have a `spec` method."""
55
+
56
+
57
+ def run_test_job(
58
+ connector: IConnector | type[IConnector] | Callable[[], IConnector],
59
+ verb: Literal["read", "check", "discover"],
60
+ test_scenario: ConnectorTestScenario,
61
+ *,
62
+ catalog: ConfiguredAirbyteCatalog | dict[str, Any] | None = None,
63
+ ) -> entrypoint_wrapper.EntrypointOutput:
64
+ """Run a test scenario from provided CLI args and return the result."""
65
+ if not connector:
66
+ raise ValueError("Connector is required")
67
+
68
+ if catalog and isinstance(catalog, ConfiguredAirbyteCatalog):
69
+ # Convert the catalog to a dict if it's already a ConfiguredAirbyteCatalog.
70
+ catalog = asdict(catalog)
71
+
72
+ connector_obj: IConnector
73
+ if isinstance(connector, type) or callable(connector):
74
+ # If the connector is a class or a factory lambda, instantiate it.
75
+ connector_obj = connector()
76
+ elif isinstance(connector, IConnector):
77
+ connector_obj = connector
78
+ else:
79
+ raise ValueError(
80
+ f"Invalid connector input: {type(connector)}",
81
+ )
82
+
83
+ args: list[str] = [verb]
84
+ if test_scenario.config_path:
85
+ args += ["--config", str(test_scenario.config_path)]
86
+ elif test_scenario.config_dict:
87
+ config_path = (
88
+ Path(tempfile.gettempdir()) / "airbyte-test" / f"temp_config_{uuid.uuid4().hex}.json"
89
+ )
90
+ config_path.parent.mkdir(parents=True, exist_ok=True)
91
+ config_path.write_text(orjson.dumps(test_scenario.config_dict).decode())
92
+ args += ["--config", str(config_path)]
93
+
94
+ catalog_path: Path | None = None
95
+ if verb not in ["discover", "check"]:
96
+ # We need a catalog for read.
97
+ if catalog:
98
+ # Write the catalog to a temp json file and pass the path to the file as an argument.
99
+ catalog_path = (
100
+ Path(tempfile.gettempdir())
101
+ / "airbyte-test"
102
+ / f"temp_catalog_{uuid.uuid4().hex}.json"
103
+ )
104
+ catalog_path.parent.mkdir(parents=True, exist_ok=True)
105
+ catalog_path.write_text(orjson.dumps(catalog).decode())
106
+ elif test_scenario.configured_catalog_path:
107
+ catalog_path = Path(test_scenario.configured_catalog_path)
108
+
109
+ if catalog_path:
110
+ args += ["--catalog", str(catalog_path)]
111
+
112
+ # This is a bit of a hack because the source needs the catalog early.
113
+ # Because it *also* can fail, we have to redundantly wrap it in a try/except block.
114
+
115
+ result: entrypoint_wrapper.EntrypointOutput = entrypoint_wrapper._run_command( # noqa: SLF001 # Non-public API
116
+ source=connector_obj, # type: ignore [arg-type]
117
+ args=args,
118
+ expecting_exception=test_scenario.expect_exception,
119
+ )
120
+ if result.errors and not test_scenario.expect_exception:
121
+ raise AssertionError(
122
+ f"Expected no errors but got {len(result.errors)}: \n" + _errors_to_str(result)
123
+ )
124
+
125
+ if verb == "check":
126
+ # Check is expected to fail gracefully without an exception.
127
+ # Instead, we assert that we have a CONNECTION_STATUS message with
128
+ # a failure status.
129
+ assert len(result.connection_status_messages) == 1, (
130
+ "Expected exactly one CONNECTION_STATUS message. Got "
131
+ f"{len(result.connection_status_messages)}:\n"
132
+ + "\n".join([str(msg) for msg in result.connection_status_messages])
133
+ + _errors_to_str(result)
134
+ )
135
+ if test_scenario.expect_exception:
136
+ conn_status = result.connection_status_messages[0].connectionStatus
137
+ assert conn_status, (
138
+ "Expected CONNECTION_STATUS message to be present. Got: \n"
139
+ + "\n".join([str(msg) for msg in result.connection_status_messages])
140
+ )
141
+ assert conn_status.status == Status.FAILED, (
142
+ "Expected CONNECTION_STATUS message to be FAILED. Got: \n"
143
+ + "\n".join([str(msg) for msg in result.connection_status_messages])
144
+ )
145
+
146
+ return result
147
+
148
+ # For all other verbs, we assert that an exception is raised (or not).
149
+ if test_scenario.expect_exception:
150
+ if not result.errors:
151
+ raise AssertionError("Expected exception but got none.")
152
+
153
+ return result
154
+
155
+ assert not result.errors, (
156
+ f"Expected no errors but got {len(result.errors)}: \n" + _errors_to_str(result)
157
+ )
158
+
159
+ return result