airbyte-cdk 6.45.4__py3-none-any.whl → 6.45.4.post14.dev14544463167__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- airbyte_cdk/models/__init__.py +1 -0
- airbyte_cdk/models/airbyte_protocol.py +1 -3
- airbyte_cdk/sources/concurrent_source/concurrent_read_processor.py +1 -1
- airbyte_cdk/sources/declarative/auth/oauth.py +2 -2
- airbyte_cdk/sources/declarative/concurrent_declarative_source.py +8 -0
- airbyte_cdk/sources/declarative/declarative_component_schema.yaml +36 -0
- airbyte_cdk/sources/declarative/extractors/record_selector.py +6 -1
- airbyte_cdk/sources/declarative/models/declarative_component_schema.py +31 -0
- airbyte_cdk/sources/declarative/parsers/model_to_component_factory.py +39 -1
- airbyte_cdk/sources/declarative/retrievers/file_uploader.py +93 -0
- airbyte_cdk/sources/declarative/stream_slicers/declarative_partition_generator.py +9 -4
- airbyte_cdk/sources/file_based/file_based_stream_reader.py +38 -16
- airbyte_cdk/sources/file_based/file_record_data.py +23 -0
- airbyte_cdk/sources/file_based/file_types/file_transfer.py +8 -15
- airbyte_cdk/sources/file_based/schema_helpers.py +11 -1
- airbyte_cdk/sources/file_based/stream/concurrent/adapters.py +3 -12
- airbyte_cdk/sources/file_based/stream/default_file_based_stream.py +15 -38
- airbyte_cdk/sources/file_based/stream/permissions_file_based_stream.py +1 -3
- airbyte_cdk/sources/streams/concurrent/default_stream.py +3 -0
- airbyte_cdk/sources/streams/http/requests_native_auth/abstract_oauth.py +28 -11
- airbyte_cdk/sources/streams/http/requests_native_auth/oauth.py +4 -27
- airbyte_cdk/sources/types.py +11 -2
- airbyte_cdk/sources/utils/files_directory.py +15 -0
- airbyte_cdk/sources/utils/record_helper.py +8 -8
- airbyte_cdk/test/entrypoint_wrapper.py +4 -0
- airbyte_cdk/test/mock_http/response_builder.py +8 -0
- airbyte_cdk/test/standard_tests/__init__.py +46 -0
- airbyte_cdk/test/standard_tests/_job_runner.py +159 -0
- airbyte_cdk/test/standard_tests/connector_base.py +148 -0
- airbyte_cdk/test/standard_tests/declarative_sources.py +92 -0
- airbyte_cdk/test/standard_tests/destination_base.py +16 -0
- airbyte_cdk/test/standard_tests/models/__init__.py +7 -0
- airbyte_cdk/test/standard_tests/models/scenario.py +74 -0
- airbyte_cdk/test/standard_tests/pytest_hooks.py +61 -0
- airbyte_cdk/test/standard_tests/source_base.py +140 -0
- {airbyte_cdk-6.45.4.dist-info → airbyte_cdk-6.45.4.post14.dev14544463167.dist-info}/METADATA +3 -2
- {airbyte_cdk-6.45.4.dist-info → airbyte_cdk-6.45.4.post14.dev14544463167.dist-info}/RECORD +41 -30
- airbyte_cdk/models/file_transfer_record_message.py +0 -13
- {airbyte_cdk-6.45.4.dist-info → airbyte_cdk-6.45.4.post14.dev14544463167.dist-info}/LICENSE.txt +0 -0
- {airbyte_cdk-6.45.4.dist-info → airbyte_cdk-6.45.4.post14.dev14544463167.dist-info}/LICENSE_SHORT +0 -0
- {airbyte_cdk-6.45.4.dist-info → airbyte_cdk-6.45.4.post14.dev14544463167.dist-info}/WHEEL +0 -0
- {airbyte_cdk-6.45.4.dist-info → airbyte_cdk-6.45.4.post14.dev14544463167.dist-info}/entry_points.txt +0 -0
@@ -4,7 +4,7 @@
|
|
4
4
|
|
5
5
|
import copy
|
6
6
|
import logging
|
7
|
-
from functools import
|
7
|
+
from functools import lru_cache
|
8
8
|
from typing import TYPE_CHECKING, Any, Iterable, List, Mapping, MutableMapping, Optional, Union
|
9
9
|
|
10
10
|
from typing_extensions import deprecated
|
@@ -258,19 +258,14 @@ class FileBasedStreamPartition(Partition):
|
|
258
258
|
and record_data.record is not None
|
259
259
|
):
|
260
260
|
# `AirbyteMessage`s of type `Record` should also be yielded so they are enqueued
|
261
|
-
|
262
|
-
record_message_data = (
|
263
|
-
record_data.record.file
|
264
|
-
if self._use_file_transfer()
|
265
|
-
else record_data.record.data
|
266
|
-
)
|
261
|
+
record_message_data = record_data.record.data
|
267
262
|
if not record_message_data:
|
268
263
|
raise ExceptionWithDisplayMessage("A record without data was found")
|
269
264
|
else:
|
270
265
|
yield Record(
|
271
266
|
data=record_message_data,
|
272
267
|
stream_name=self.stream_name(),
|
273
|
-
|
268
|
+
file_reference=record_data.record.file_reference,
|
274
269
|
)
|
275
270
|
else:
|
276
271
|
self._message_repository.emit_message(record_data)
|
@@ -306,10 +301,6 @@ class FileBasedStreamPartition(Partition):
|
|
306
301
|
def stream_name(self) -> str:
|
307
302
|
return self._stream.name
|
308
303
|
|
309
|
-
@cache
|
310
|
-
def _use_file_transfer(self) -> bool:
|
311
|
-
return hasattr(self._stream, "use_file_transfer") and self._stream.use_file_transfer
|
312
|
-
|
313
304
|
def __repr__(self) -> str:
|
314
305
|
return f"FileBasedStreamPartition({self._stream.name}, {self._slice})"
|
315
306
|
|
@@ -11,7 +11,7 @@ from functools import cache
|
|
11
11
|
from os import path
|
12
12
|
from typing import Any, Dict, Iterable, List, Mapping, MutableMapping, Optional, Set, Tuple, Union
|
13
13
|
|
14
|
-
from airbyte_cdk.models import AirbyteLogMessage, AirbyteMessage, FailureType, Level
|
14
|
+
from airbyte_cdk.models import AirbyteLogMessage, AirbyteMessage, AirbyteStream, FailureType, Level
|
15
15
|
from airbyte_cdk.models import Type as MessageType
|
16
16
|
from airbyte_cdk.sources.file_based.config.file_based_stream_config import PrimaryKeyType
|
17
17
|
from airbyte_cdk.sources.file_based.exceptions import (
|
@@ -56,6 +56,7 @@ class DefaultFileBasedStream(AbstractFileBasedStream, IncrementalMixin):
|
|
56
56
|
airbyte_columns = [ab_last_mod_col, ab_file_name_col]
|
57
57
|
use_file_transfer = False
|
58
58
|
preserve_directory_structure = True
|
59
|
+
_file_transfer = FileTransfer()
|
59
60
|
|
60
61
|
def __init__(self, **kwargs: Any):
|
61
62
|
if self.FILE_TRANSFER_KW in kwargs:
|
@@ -93,21 +94,6 @@ class DefaultFileBasedStream(AbstractFileBasedStream, IncrementalMixin):
|
|
93
94
|
self.config
|
94
95
|
)
|
95
96
|
|
96
|
-
def _filter_schema_invalid_properties(
|
97
|
-
self, configured_catalog_json_schema: Dict[str, Any]
|
98
|
-
) -> Dict[str, Any]:
|
99
|
-
if self.use_file_transfer:
|
100
|
-
return {
|
101
|
-
"type": "object",
|
102
|
-
"properties": {
|
103
|
-
"file_path": {"type": "string"},
|
104
|
-
"file_size": {"type": "string"},
|
105
|
-
self.ab_file_name_col: {"type": "string"},
|
106
|
-
},
|
107
|
-
}
|
108
|
-
else:
|
109
|
-
return super()._filter_schema_invalid_properties(configured_catalog_json_schema)
|
110
|
-
|
111
97
|
def _duplicated_files_names(
|
112
98
|
self, slices: List[dict[str, List[RemoteFile]]]
|
113
99
|
) -> List[dict[str, List[str]]]:
|
@@ -145,14 +131,6 @@ class DefaultFileBasedStream(AbstractFileBasedStream, IncrementalMixin):
|
|
145
131
|
record[self.ab_file_name_col] = file.uri
|
146
132
|
return record
|
147
133
|
|
148
|
-
def transform_record_for_file_transfer(
|
149
|
-
self, record: dict[str, Any], file: RemoteFile
|
150
|
-
) -> dict[str, Any]:
|
151
|
-
# timstamp() returns a float representing the number of seconds since the unix epoch
|
152
|
-
record[self.modified] = int(file.last_modified.timestamp()) * 1000
|
153
|
-
record[self.source_file_url] = file.uri
|
154
|
-
return record
|
155
|
-
|
156
134
|
def read_records_from_slice(self, stream_slice: StreamSlice) -> Iterable[AirbyteMessage]:
|
157
135
|
"""
|
158
136
|
Yield all records from all remote files in `list_files_for_this_sync`.
|
@@ -173,19 +151,13 @@ class DefaultFileBasedStream(AbstractFileBasedStream, IncrementalMixin):
|
|
173
151
|
|
174
152
|
try:
|
175
153
|
if self.use_file_transfer:
|
176
|
-
self.
|
177
|
-
|
178
|
-
file_transfer = FileTransfer()
|
179
|
-
for record in file_transfer.get_file(
|
180
|
-
self.config, file, self.stream_reader, self.logger
|
154
|
+
for file_record_data, file_reference in self._file_transfer.upload(
|
155
|
+
file=file, stream_reader=self.stream_reader, logger=self.logger
|
181
156
|
):
|
182
|
-
line_no += 1
|
183
|
-
if not self.record_passes_validation_policy(record):
|
184
|
-
n_skipped += 1
|
185
|
-
continue
|
186
|
-
record = self.transform_record_for_file_transfer(record, file)
|
187
157
|
yield stream_data_to_airbyte_message(
|
188
|
-
self.name,
|
158
|
+
self.name,
|
159
|
+
file_record_data.dict(exclude_none=True),
|
160
|
+
file_reference=file_reference,
|
189
161
|
)
|
190
162
|
else:
|
191
163
|
for record in parser.parse_records(
|
@@ -259,6 +231,8 @@ class DefaultFileBasedStream(AbstractFileBasedStream, IncrementalMixin):
|
|
259
231
|
|
260
232
|
@cache
|
261
233
|
def get_json_schema(self) -> JsonSchema:
|
234
|
+
if self.use_file_transfer:
|
235
|
+
return file_transfer_schema
|
262
236
|
extra_fields = {
|
263
237
|
self.ab_last_mod_col: {"type": "string"},
|
264
238
|
self.ab_file_name_col: {"type": "string"},
|
@@ -282,9 +256,7 @@ class DefaultFileBasedStream(AbstractFileBasedStream, IncrementalMixin):
|
|
282
256
|
return {"type": "object", "properties": {**extra_fields, **schema["properties"]}}
|
283
257
|
|
284
258
|
def _get_raw_json_schema(self) -> JsonSchema:
|
285
|
-
if self.
|
286
|
-
return file_transfer_schema
|
287
|
-
elif self.config.input_schema:
|
259
|
+
if self.config.input_schema:
|
288
260
|
return self.config.get_input_schema() # type: ignore
|
289
261
|
elif self.config.schemaless:
|
290
262
|
return schemaless_schema
|
@@ -341,6 +313,11 @@ class DefaultFileBasedStream(AbstractFileBasedStream, IncrementalMixin):
|
|
341
313
|
self.config.globs or [], self.config.legacy_prefix, self.logger
|
342
314
|
)
|
343
315
|
|
316
|
+
def as_airbyte_stream(self) -> AirbyteStream:
|
317
|
+
file_stream = super().as_airbyte_stream()
|
318
|
+
file_stream.is_file_based = self.use_file_transfer
|
319
|
+
return file_stream
|
320
|
+
|
344
321
|
def infer_schema(self, files: List[RemoteFile]) -> Mapping[str, Any]:
|
345
322
|
loop = asyncio.get_event_loop()
|
346
323
|
schema = loop.run_until_complete(self._infer_schema(files))
|
@@ -61,9 +61,7 @@ class PermissionsFileBasedStream(DefaultFileBasedStream):
|
|
61
61
|
permissions_record = self.transform_record(
|
62
62
|
permissions_record, file, file_datetime_string
|
63
63
|
)
|
64
|
-
yield stream_data_to_airbyte_message(
|
65
|
-
self.name, permissions_record, is_file_transfer_message=False
|
66
|
-
)
|
64
|
+
yield stream_data_to_airbyte_message(self.name, permissions_record)
|
67
65
|
except Exception as e:
|
68
66
|
self.logger.error(f"Failed to retrieve permissions for file {file.uri}: {str(e)}")
|
69
67
|
yield AirbyteMessage(
|
@@ -29,6 +29,7 @@ class DefaultStream(AbstractStream):
|
|
29
29
|
logger: Logger,
|
30
30
|
cursor: Cursor,
|
31
31
|
namespace: Optional[str] = None,
|
32
|
+
supports_file_transfer: bool = False,
|
32
33
|
) -> None:
|
33
34
|
self._stream_partition_generator = partition_generator
|
34
35
|
self._name = name
|
@@ -39,6 +40,7 @@ class DefaultStream(AbstractStream):
|
|
39
40
|
self._logger = logger
|
40
41
|
self._cursor = cursor
|
41
42
|
self._namespace = namespace
|
43
|
+
self._supports_file_transfer = supports_file_transfer
|
42
44
|
|
43
45
|
def generate_partitions(self) -> Iterable[Partition]:
|
44
46
|
yield from self._stream_partition_generator.generate()
|
@@ -68,6 +70,7 @@ class DefaultStream(AbstractStream):
|
|
68
70
|
json_schema=dict(self._json_schema),
|
69
71
|
supported_sync_modes=[SyncMode.full_refresh],
|
70
72
|
is_resumable=False,
|
73
|
+
is_file_based=self._supports_file_transfer,
|
71
74
|
)
|
72
75
|
|
73
76
|
if self._namespace:
|
@@ -130,7 +130,7 @@ class AbstractOauth2Authenticator(AuthBase):
|
|
130
130
|
headers = self.get_refresh_request_headers()
|
131
131
|
return headers if headers else None
|
132
132
|
|
133
|
-
def refresh_access_token(self) -> Tuple[str,
|
133
|
+
def refresh_access_token(self) -> Tuple[str, AirbyteDateTime]:
|
134
134
|
"""
|
135
135
|
Returns the refresh token and its expiration datetime
|
136
136
|
|
@@ -148,6 +148,14 @@ class AbstractOauth2Authenticator(AuthBase):
|
|
148
148
|
# PRIVATE METHODS
|
149
149
|
# ----------------
|
150
150
|
|
151
|
+
def _default_token_expiry_date(self) -> AirbyteDateTime:
|
152
|
+
"""
|
153
|
+
Returns the default token expiry date
|
154
|
+
"""
|
155
|
+
# 1 hour was chosen as a middle ground to avoid unnecessary frequent refreshes and token expiration
|
156
|
+
default_token_expiry_duration_hours = 1 # 1 hour
|
157
|
+
return ab_datetime_now() + timedelta(hours=default_token_expiry_duration_hours)
|
158
|
+
|
151
159
|
def _wrap_refresh_token_exception(
|
152
160
|
self, exception: requests.exceptions.RequestException
|
153
161
|
) -> bool:
|
@@ -257,14 +265,10 @@ class AbstractOauth2Authenticator(AuthBase):
|
|
257
265
|
|
258
266
|
def _parse_token_expiration_date(self, value: Union[str, int]) -> AirbyteDateTime:
|
259
267
|
"""
|
260
|
-
|
268
|
+
Parse a string or integer token expiration date into a datetime object
|
261
269
|
|
262
270
|
:return: expiration datetime
|
263
271
|
"""
|
264
|
-
if not value and not self.token_has_expired():
|
265
|
-
# No expiry token was provided but the previous one is not expired so it's fine
|
266
|
-
return self.get_token_expiry_date()
|
267
|
-
|
268
272
|
if self.token_expiry_is_time_of_expiration:
|
269
273
|
if not self.token_expiry_date_format:
|
270
274
|
raise ValueError(
|
@@ -308,17 +312,30 @@ class AbstractOauth2Authenticator(AuthBase):
|
|
308
312
|
"""
|
309
313
|
return self._find_and_get_value_from_response(response_data, self.get_refresh_token_name())
|
310
314
|
|
311
|
-
def _extract_token_expiry_date(self, response_data: Mapping[str, Any]) ->
|
315
|
+
def _extract_token_expiry_date(self, response_data: Mapping[str, Any]) -> AirbyteDateTime:
|
312
316
|
"""
|
313
317
|
Extracts the token_expiry_date, like `expires_in` or `expires_at`, etc from the given response data.
|
314
318
|
|
319
|
+
If the token_expiry_date is not found, it will return an existing token expiry date if set, or a default token expiry date.
|
320
|
+
|
315
321
|
Args:
|
316
322
|
response_data (Mapping[str, Any]): The response data from which to extract the token_expiry_date.
|
317
323
|
|
318
324
|
Returns:
|
319
|
-
|
325
|
+
The extracted token_expiry_date or None if not found.
|
320
326
|
"""
|
321
|
-
|
327
|
+
expires_in = self._find_and_get_value_from_response(
|
328
|
+
response_data, self.get_expires_in_name()
|
329
|
+
)
|
330
|
+
if expires_in is not None:
|
331
|
+
return self._parse_token_expiration_date(expires_in)
|
332
|
+
|
333
|
+
# expires_in is None
|
334
|
+
existing_expiry_date = self.get_token_expiry_date()
|
335
|
+
if existing_expiry_date and not self.token_has_expired():
|
336
|
+
return existing_expiry_date
|
337
|
+
|
338
|
+
return self._default_token_expiry_date()
|
322
339
|
|
323
340
|
def _find_and_get_value_from_response(
|
324
341
|
self,
|
@@ -344,7 +361,7 @@ class AbstractOauth2Authenticator(AuthBase):
|
|
344
361
|
"""
|
345
362
|
if current_depth > max_depth:
|
346
363
|
# this is needed to avoid an inf loop, possible with a very deep nesting observed.
|
347
|
-
message = f"The maximum level of recursion is reached. Couldn't find the
|
364
|
+
message = f"The maximum level of recursion is reached. Couldn't find the specified `{key_name}` in the response."
|
348
365
|
raise ResponseKeysMaxRecurtionReached(
|
349
366
|
internal_message=message, message=message, failure_type=FailureType.config_error
|
350
367
|
)
|
@@ -441,7 +458,7 @@ class AbstractOauth2Authenticator(AuthBase):
|
|
441
458
|
"""Expiration date of the access token"""
|
442
459
|
|
443
460
|
@abstractmethod
|
444
|
-
def set_token_expiry_date(self, value:
|
461
|
+
def set_token_expiry_date(self, value: AirbyteDateTime) -> None:
|
445
462
|
"""Setter for access token expiration date"""
|
446
463
|
|
447
464
|
@abstractmethod
|
@@ -120,8 +120,8 @@ class Oauth2Authenticator(AbstractOauth2Authenticator):
|
|
120
120
|
def get_token_expiry_date(self) -> AirbyteDateTime:
|
121
121
|
return self._token_expiry_date
|
122
122
|
|
123
|
-
def set_token_expiry_date(self, value:
|
124
|
-
self._token_expiry_date =
|
123
|
+
def set_token_expiry_date(self, value: AirbyteDateTime) -> None:
|
124
|
+
self._token_expiry_date = value
|
125
125
|
|
126
126
|
@property
|
127
127
|
def token_expiry_is_time_of_expiration(self) -> bool:
|
@@ -316,26 +316,6 @@ class SingleUseRefreshTokenOauth2Authenticator(Oauth2Authenticator):
|
|
316
316
|
"""Returns True if the token is expired"""
|
317
317
|
return ab_datetime_now() > self.get_token_expiry_date()
|
318
318
|
|
319
|
-
@staticmethod
|
320
|
-
def get_new_token_expiry_date(
|
321
|
-
access_token_expires_in: str,
|
322
|
-
token_expiry_date_format: str | None = None,
|
323
|
-
) -> AirbyteDateTime:
|
324
|
-
"""
|
325
|
-
Calculate the new token expiry date based on the provided expiration duration or format.
|
326
|
-
|
327
|
-
Args:
|
328
|
-
access_token_expires_in (str): The duration (in seconds) until the access token expires, or the expiry date in a specific format.
|
329
|
-
token_expiry_date_format (str | None, optional): The format of the expiry date if provided. Defaults to None.
|
330
|
-
|
331
|
-
Returns:
|
332
|
-
AirbyteDateTime: The calculated expiry date of the access token.
|
333
|
-
"""
|
334
|
-
if token_expiry_date_format:
|
335
|
-
return ab_datetime_parse(access_token_expires_in)
|
336
|
-
else:
|
337
|
-
return ab_datetime_now() + timedelta(seconds=int(access_token_expires_in))
|
338
|
-
|
339
319
|
def get_access_token(self) -> str:
|
340
320
|
"""Retrieve new access and refresh token if the access token has expired.
|
341
321
|
The new refresh token is persisted with the set_refresh_token function
|
@@ -346,16 +326,13 @@ class SingleUseRefreshTokenOauth2Authenticator(Oauth2Authenticator):
|
|
346
326
|
new_access_token, access_token_expires_in, new_refresh_token = (
|
347
327
|
self.refresh_access_token()
|
348
328
|
)
|
349
|
-
new_token_expiry_date: AirbyteDateTime = self.get_new_token_expiry_date(
|
350
|
-
access_token_expires_in, self._token_expiry_date_format
|
351
|
-
)
|
352
329
|
self.access_token = new_access_token
|
353
330
|
self.set_refresh_token(new_refresh_token)
|
354
|
-
self.set_token_expiry_date(
|
331
|
+
self.set_token_expiry_date(access_token_expires_in)
|
355
332
|
self._emit_control_message()
|
356
333
|
return self.access_token
|
357
334
|
|
358
|
-
def refresh_access_token(self) -> Tuple[str,
|
335
|
+
def refresh_access_token(self) -> Tuple[str, AirbyteDateTime, str]: # type: ignore[override]
|
359
336
|
"""
|
360
337
|
Refreshes the access token by making a handled request and extracting the necessary token information.
|
361
338
|
|
airbyte_cdk/sources/types.py
CHANGED
@@ -6,6 +6,7 @@ from __future__ import annotations
|
|
6
6
|
|
7
7
|
from typing import Any, ItemsView, Iterator, KeysView, List, Mapping, Optional, ValuesView
|
8
8
|
|
9
|
+
from airbyte_cdk.models import AirbyteRecordMessageFileReference
|
9
10
|
from airbyte_cdk.utils.slice_hasher import SliceHasher
|
10
11
|
|
11
12
|
# A FieldPointer designates a path to a field inside a mapping. For example, retrieving ["k1", "k1.2"] in the object {"k1" :{"k1.2":
|
@@ -23,12 +24,12 @@ class Record(Mapping[str, Any]):
|
|
23
24
|
data: Mapping[str, Any],
|
24
25
|
stream_name: str,
|
25
26
|
associated_slice: Optional[StreamSlice] = None,
|
26
|
-
|
27
|
+
file_reference: Optional[AirbyteRecordMessageFileReference] = None,
|
27
28
|
):
|
28
29
|
self._data = data
|
29
30
|
self._associated_slice = associated_slice
|
30
31
|
self.stream_name = stream_name
|
31
|
-
self.
|
32
|
+
self._file_reference = file_reference
|
32
33
|
|
33
34
|
@property
|
34
35
|
def data(self) -> Mapping[str, Any]:
|
@@ -38,6 +39,14 @@ class Record(Mapping[str, Any]):
|
|
38
39
|
def associated_slice(self) -> Optional[StreamSlice]:
|
39
40
|
return self._associated_slice
|
40
41
|
|
42
|
+
@property
|
43
|
+
def file_reference(self) -> AirbyteRecordMessageFileReference:
|
44
|
+
return self._file_reference
|
45
|
+
|
46
|
+
@file_reference.setter
|
47
|
+
def file_reference(self, value: AirbyteRecordMessageFileReference) -> None:
|
48
|
+
self._file_reference = value
|
49
|
+
|
41
50
|
def __repr__(self) -> str:
|
42
51
|
return repr(self._data)
|
43
52
|
|
@@ -0,0 +1,15 @@
|
|
1
|
+
#
|
2
|
+
# Copyright (c) 2025 Airbyte, Inc., all rights reserved.
|
3
|
+
#
|
4
|
+
import os
|
5
|
+
|
6
|
+
AIRBYTE_STAGING_DIRECTORY = os.getenv("AIRBYTE_STAGING_DIRECTORY", "/staging/files")
|
7
|
+
DEFAULT_LOCAL_DIRECTORY = "/tmp/airbyte-file-transfer"
|
8
|
+
|
9
|
+
|
10
|
+
def get_files_directory() -> str:
|
11
|
+
return (
|
12
|
+
AIRBYTE_STAGING_DIRECTORY
|
13
|
+
if os.path.exists(AIRBYTE_STAGING_DIRECTORY)
|
14
|
+
else DEFAULT_LOCAL_DIRECTORY
|
15
|
+
)
|
@@ -9,10 +9,10 @@ from airbyte_cdk.models import (
|
|
9
9
|
AirbyteLogMessage,
|
10
10
|
AirbyteMessage,
|
11
11
|
AirbyteRecordMessage,
|
12
|
+
AirbyteRecordMessageFileReference,
|
12
13
|
AirbyteTraceMessage,
|
13
14
|
)
|
14
15
|
from airbyte_cdk.models import Type as MessageType
|
15
|
-
from airbyte_cdk.models.file_transfer_record_message import AirbyteFileTransferRecordMessage
|
16
16
|
from airbyte_cdk.sources.streams.core import StreamData
|
17
17
|
from airbyte_cdk.sources.utils.transform import TransformConfig, TypeTransformer
|
18
18
|
|
@@ -22,7 +22,7 @@ def stream_data_to_airbyte_message(
|
|
22
22
|
data_or_message: StreamData,
|
23
23
|
transformer: TypeTransformer = TypeTransformer(TransformConfig.NoTransform),
|
24
24
|
schema: Optional[Mapping[str, Any]] = None,
|
25
|
-
|
25
|
+
file_reference: Optional[AirbyteRecordMessageFileReference] = None,
|
26
26
|
) -> AirbyteMessage:
|
27
27
|
if schema is None:
|
28
28
|
schema = {}
|
@@ -36,12 +36,12 @@ def stream_data_to_airbyte_message(
|
|
36
36
|
# taken unless configured. See
|
37
37
|
# docs/connector-development/cdk-python/schemas.md for details.
|
38
38
|
transformer.transform(data, schema)
|
39
|
-
|
40
|
-
|
41
|
-
|
42
|
-
|
43
|
-
|
44
|
-
|
39
|
+
message = AirbyteRecordMessage(
|
40
|
+
stream=stream_name,
|
41
|
+
data=data,
|
42
|
+
emitted_at=now_millis,
|
43
|
+
file_reference=file_reference,
|
44
|
+
)
|
45
45
|
return AirbyteMessage(type=MessageType.RECORD, record=message)
|
46
46
|
case AirbyteTraceMessage():
|
47
47
|
return AirbyteMessage(type=MessageType.TRACE, trace=data_or_message)
|
@@ -82,6 +82,10 @@ class EntrypointOutput:
|
|
82
82
|
def state_messages(self) -> List[AirbyteMessage]:
|
83
83
|
return self._get_message_by_types([Type.STATE])
|
84
84
|
|
85
|
+
@property
|
86
|
+
def connection_status_messages(self) -> List[AirbyteMessage]:
|
87
|
+
return self._get_message_by_types([Type.CONNECTION_STATUS])
|
88
|
+
|
85
89
|
@property
|
86
90
|
def most_recent_state(self) -> Any:
|
87
91
|
state_messages = self._get_message_by_types([Type.STATE])
|
@@ -198,6 +198,14 @@ def find_template(resource: str, execution_folder: str) -> Dict[str, Any]:
|
|
198
198
|
return json.load(template_file) # type: ignore # we assume the dev correctly set up the resource file
|
199
199
|
|
200
200
|
|
201
|
+
def find_binary_response(resource: str, execution_folder: str) -> bytes:
|
202
|
+
response_filepath = str(
|
203
|
+
get_unit_test_folder(execution_folder) / "resource" / "http" / "response" / f"{resource}"
|
204
|
+
)
|
205
|
+
with open(response_filepath, "rb") as response_file:
|
206
|
+
return response_file.read() # type: ignore # we assume the dev correctly set up the resource file
|
207
|
+
|
208
|
+
|
201
209
|
def create_record_builder(
|
202
210
|
response_template: Dict[str, Any],
|
203
211
|
records_path: Union[FieldPath, NestedPath],
|
@@ -0,0 +1,46 @@
|
|
1
|
+
# Copyright (c) 2024 Airbyte, Inc., all rights reserved.
|
2
|
+
'''FAST Airbyte Standard Tests
|
3
|
+
|
4
|
+
This module provides a set of base classes for declarative connector test suites.
|
5
|
+
The goal of this module is to provide a robust and extensible framework for testing Airbyte
|
6
|
+
connectors.
|
7
|
+
|
8
|
+
Example usage:
|
9
|
+
|
10
|
+
```python
|
11
|
+
# `test_airbyte_standards.py`
|
12
|
+
from airbyte_cdk.test import standard_tests
|
13
|
+
|
14
|
+
pytest_plugins = [
|
15
|
+
"airbyte_cdk.test.standard_tests.pytest_hooks",
|
16
|
+
]
|
17
|
+
|
18
|
+
|
19
|
+
class TestSuiteSourcePokeAPI(standard_tests.DeclarativeSourceTestSuite):
|
20
|
+
"""Test suite for the source."""
|
21
|
+
```
|
22
|
+
|
23
|
+
Available test suites base classes:
|
24
|
+
- `DeclarativeSourceTestSuite`: A test suite for declarative sources.
|
25
|
+
- `SourceTestSuiteBase`: A test suite for sources.
|
26
|
+
- `DestinationTestSuiteBase`: A test suite for destinations.
|
27
|
+
|
28
|
+
'''
|
29
|
+
|
30
|
+
from airbyte_cdk.test.standard_tests.connector_base import (
|
31
|
+
ConnectorTestScenario,
|
32
|
+
ConnectorTestSuiteBase,
|
33
|
+
)
|
34
|
+
from airbyte_cdk.test.standard_tests.declarative_sources import (
|
35
|
+
DeclarativeSourceTestSuite,
|
36
|
+
)
|
37
|
+
from airbyte_cdk.test.standard_tests.destination_base import DestinationTestSuiteBase
|
38
|
+
from airbyte_cdk.test.standard_tests.source_base import SourceTestSuiteBase
|
39
|
+
|
40
|
+
__all__ = [
|
41
|
+
"ConnectorTestScenario",
|
42
|
+
"ConnectorTestSuiteBase",
|
43
|
+
"DeclarativeSourceTestSuite",
|
44
|
+
"DestinationTestSuiteBase",
|
45
|
+
"SourceTestSuiteBase",
|
46
|
+
]
|
@@ -0,0 +1,159 @@
|
|
1
|
+
# Copyright (c) 2025 Airbyte, Inc., all rights reserved.
|
2
|
+
"""Job runner for Airbyte Standard Tests."""
|
3
|
+
|
4
|
+
import logging
|
5
|
+
import tempfile
|
6
|
+
import uuid
|
7
|
+
from dataclasses import asdict
|
8
|
+
from pathlib import Path
|
9
|
+
from typing import Any, Callable, Literal
|
10
|
+
|
11
|
+
import orjson
|
12
|
+
from typing_extensions import Protocol, runtime_checkable
|
13
|
+
|
14
|
+
from airbyte_cdk.models import (
|
15
|
+
ConfiguredAirbyteCatalog,
|
16
|
+
Status,
|
17
|
+
)
|
18
|
+
from airbyte_cdk.test import entrypoint_wrapper
|
19
|
+
from airbyte_cdk.test.standard_tests.models import (
|
20
|
+
ConnectorTestScenario,
|
21
|
+
)
|
22
|
+
|
23
|
+
|
24
|
+
def _errors_to_str(
|
25
|
+
entrypoint_output: entrypoint_wrapper.EntrypointOutput,
|
26
|
+
) -> str:
|
27
|
+
"""Convert errors from entrypoint output to a string."""
|
28
|
+
if not entrypoint_output.errors:
|
29
|
+
# If there are no errors, return an empty string.
|
30
|
+
return ""
|
31
|
+
|
32
|
+
return "\n" + "\n".join(
|
33
|
+
[
|
34
|
+
str(error.trace.error).replace(
|
35
|
+
"\\n",
|
36
|
+
"\n",
|
37
|
+
)
|
38
|
+
for error in entrypoint_output.errors
|
39
|
+
if error.trace
|
40
|
+
],
|
41
|
+
)
|
42
|
+
|
43
|
+
|
44
|
+
@runtime_checkable
|
45
|
+
class IConnector(Protocol):
|
46
|
+
"""A connector that can be run in a test scenario.
|
47
|
+
|
48
|
+
Note: We currently use 'spec' to determine if we have a connector object.
|
49
|
+
In the future, it would be preferred to leverage a 'launch' method instead,
|
50
|
+
directly on the connector (which doesn't yet exist).
|
51
|
+
"""
|
52
|
+
|
53
|
+
def spec(self, logger: logging.Logger) -> Any:
|
54
|
+
"""Connectors should have a `spec` method."""
|
55
|
+
|
56
|
+
|
57
|
+
def run_test_job(
|
58
|
+
connector: IConnector | type[IConnector] | Callable[[], IConnector],
|
59
|
+
verb: Literal["read", "check", "discover"],
|
60
|
+
test_scenario: ConnectorTestScenario,
|
61
|
+
*,
|
62
|
+
catalog: ConfiguredAirbyteCatalog | dict[str, Any] | None = None,
|
63
|
+
) -> entrypoint_wrapper.EntrypointOutput:
|
64
|
+
"""Run a test scenario from provided CLI args and return the result."""
|
65
|
+
if not connector:
|
66
|
+
raise ValueError("Connector is required")
|
67
|
+
|
68
|
+
if catalog and isinstance(catalog, ConfiguredAirbyteCatalog):
|
69
|
+
# Convert the catalog to a dict if it's already a ConfiguredAirbyteCatalog.
|
70
|
+
catalog = asdict(catalog)
|
71
|
+
|
72
|
+
connector_obj: IConnector
|
73
|
+
if isinstance(connector, type) or callable(connector):
|
74
|
+
# If the connector is a class or a factory lambda, instantiate it.
|
75
|
+
connector_obj = connector()
|
76
|
+
elif isinstance(connector, IConnector):
|
77
|
+
connector_obj = connector
|
78
|
+
else:
|
79
|
+
raise ValueError(
|
80
|
+
f"Invalid connector input: {type(connector)}",
|
81
|
+
)
|
82
|
+
|
83
|
+
args: list[str] = [verb]
|
84
|
+
if test_scenario.config_path:
|
85
|
+
args += ["--config", str(test_scenario.config_path)]
|
86
|
+
elif test_scenario.config_dict:
|
87
|
+
config_path = (
|
88
|
+
Path(tempfile.gettempdir()) / "airbyte-test" / f"temp_config_{uuid.uuid4().hex}.json"
|
89
|
+
)
|
90
|
+
config_path.parent.mkdir(parents=True, exist_ok=True)
|
91
|
+
config_path.write_text(orjson.dumps(test_scenario.config_dict).decode())
|
92
|
+
args += ["--config", str(config_path)]
|
93
|
+
|
94
|
+
catalog_path: Path | None = None
|
95
|
+
if verb not in ["discover", "check"]:
|
96
|
+
# We need a catalog for read.
|
97
|
+
if catalog:
|
98
|
+
# Write the catalog to a temp json file and pass the path to the file as an argument.
|
99
|
+
catalog_path = (
|
100
|
+
Path(tempfile.gettempdir())
|
101
|
+
/ "airbyte-test"
|
102
|
+
/ f"temp_catalog_{uuid.uuid4().hex}.json"
|
103
|
+
)
|
104
|
+
catalog_path.parent.mkdir(parents=True, exist_ok=True)
|
105
|
+
catalog_path.write_text(orjson.dumps(catalog).decode())
|
106
|
+
elif test_scenario.configured_catalog_path:
|
107
|
+
catalog_path = Path(test_scenario.configured_catalog_path)
|
108
|
+
|
109
|
+
if catalog_path:
|
110
|
+
args += ["--catalog", str(catalog_path)]
|
111
|
+
|
112
|
+
# This is a bit of a hack because the source needs the catalog early.
|
113
|
+
# Because it *also* can fail, we have to redundantly wrap it in a try/except block.
|
114
|
+
|
115
|
+
result: entrypoint_wrapper.EntrypointOutput = entrypoint_wrapper._run_command( # noqa: SLF001 # Non-public API
|
116
|
+
source=connector_obj, # type: ignore [arg-type]
|
117
|
+
args=args,
|
118
|
+
expecting_exception=test_scenario.expect_exception,
|
119
|
+
)
|
120
|
+
if result.errors and not test_scenario.expect_exception:
|
121
|
+
raise AssertionError(
|
122
|
+
f"Expected no errors but got {len(result.errors)}: \n" + _errors_to_str(result)
|
123
|
+
)
|
124
|
+
|
125
|
+
if verb == "check":
|
126
|
+
# Check is expected to fail gracefully without an exception.
|
127
|
+
# Instead, we assert that we have a CONNECTION_STATUS message with
|
128
|
+
# a failure status.
|
129
|
+
assert len(result.connection_status_messages) == 1, (
|
130
|
+
"Expected exactly one CONNECTION_STATUS message. Got "
|
131
|
+
f"{len(result.connection_status_messages)}:\n"
|
132
|
+
+ "\n".join([str(msg) for msg in result.connection_status_messages])
|
133
|
+
+ _errors_to_str(result)
|
134
|
+
)
|
135
|
+
if test_scenario.expect_exception:
|
136
|
+
conn_status = result.connection_status_messages[0].connectionStatus
|
137
|
+
assert conn_status, (
|
138
|
+
"Expected CONNECTION_STATUS message to be present. Got: \n"
|
139
|
+
+ "\n".join([str(msg) for msg in result.connection_status_messages])
|
140
|
+
)
|
141
|
+
assert conn_status.status == Status.FAILED, (
|
142
|
+
"Expected CONNECTION_STATUS message to be FAILED. Got: \n"
|
143
|
+
+ "\n".join([str(msg) for msg in result.connection_status_messages])
|
144
|
+
)
|
145
|
+
|
146
|
+
return result
|
147
|
+
|
148
|
+
# For all other verbs, we assert check that an exception is raised (or not).
|
149
|
+
if test_scenario.expect_exception:
|
150
|
+
if not result.errors:
|
151
|
+
raise AssertionError("Expected exception but got none.")
|
152
|
+
|
153
|
+
return result
|
154
|
+
|
155
|
+
assert not result.errors, (
|
156
|
+
f"Expected no errors but got {len(result.errors)}: \n" + _errors_to_str(result)
|
157
|
+
)
|
158
|
+
|
159
|
+
return result
|