airbyte-cdk 0.53.9__py3-none-any.whl → 0.55.0__py3-none-any.whl

This diff compares the contents of two publicly released versions of the package as they appear in their public registry. It is provided for informational purposes only.
Files changed (41)
  1. airbyte_cdk/sources/concurrent_source/__init__.py +3 -0
  2. airbyte_cdk/sources/concurrent_source/concurrent_read_processor.py +190 -0
  3. airbyte_cdk/sources/concurrent_source/concurrent_source.py +161 -0
  4. airbyte_cdk/sources/concurrent_source/concurrent_source_adapter.py +63 -0
  5. airbyte_cdk/sources/concurrent_source/partition_generation_completed_sentinel.py +17 -0
  6. airbyte_cdk/sources/concurrent_source/thread_pool_manager.py +97 -0
  7. airbyte_cdk/sources/declarative/parsers/model_to_component_factory.py +16 -4
  8. airbyte_cdk/sources/declarative/requesters/paginators/default_paginator.py +14 -14
  9. airbyte_cdk/sources/declarative/requesters/request_options/request_options_provider.py +2 -2
  10. airbyte_cdk/sources/streams/concurrent/abstract_stream.py +4 -4
  11. airbyte_cdk/sources/streams/concurrent/adapters.py +34 -12
  12. airbyte_cdk/sources/streams/concurrent/default_stream.py +79 -0
  13. airbyte_cdk/sources/streams/concurrent/partition_enqueuer.py +7 -7
  14. airbyte_cdk/sources/streams/concurrent/partitions/partition.py +23 -0
  15. airbyte_cdk/sources/streams/concurrent/partitions/record.py +4 -3
  16. airbyte_cdk/sources/streams/concurrent/partitions/types.py +2 -3
  17. airbyte_cdk/sources/utils/slice_logger.py +5 -0
  18. {airbyte_cdk-0.53.9.dist-info → airbyte_cdk-0.55.0.dist-info}/METADATA +1 -1
  19. {airbyte_cdk-0.53.9.dist-info → airbyte_cdk-0.55.0.dist-info}/RECORD +40 -28
  20. {airbyte_cdk-0.53.9.dist-info → airbyte_cdk-0.55.0.dist-info}/WHEEL +1 -1
  21. unit_tests/sources/concurrent_source/__init__.py +3 -0
  22. unit_tests/sources/concurrent_source/test_concurrent_source_adapter.py +105 -0
  23. unit_tests/sources/declarative/parsers/test_model_to_component_factory.py +33 -0
  24. unit_tests/sources/declarative/requesters/paginators/test_default_paginator.py +9 -2
  25. unit_tests/sources/streams/concurrent/scenarios/stream_facade_builder.py +14 -7
  26. unit_tests/sources/streams/concurrent/scenarios/stream_facade_scenarios.py +2 -3
  27. unit_tests/sources/streams/concurrent/scenarios/thread_based_concurrent_stream_scenarios.py +44 -55
  28. unit_tests/sources/streams/concurrent/scenarios/thread_based_concurrent_stream_source_builder.py +24 -15
  29. unit_tests/sources/streams/concurrent/test_adapters.py +52 -32
  30. unit_tests/sources/streams/concurrent/test_concurrent_partition_generator.py +6 -5
  31. unit_tests/sources/streams/concurrent/test_concurrent_read_processor.py +604 -0
  32. unit_tests/sources/streams/concurrent/test_cursor.py +1 -1
  33. unit_tests/sources/streams/concurrent/{test_thread_based_concurrent_stream.py → test_default_stream.py} +7 -144
  34. unit_tests/sources/streams/concurrent/test_partition_reader.py +2 -2
  35. unit_tests/sources/streams/concurrent/test_thread_pool_manager.py +98 -0
  36. unit_tests/sources/streams/test_stream_read.py +1 -2
  37. unit_tests/sources/test_concurrent_source.py +105 -0
  38. unit_tests/sources/test_source_read.py +461 -0
  39. airbyte_cdk/sources/streams/concurrent/thread_based_concurrent_stream.py +0 -221
  40. {airbyte_cdk-0.53.9.dist-info → airbyte_cdk-0.55.0.dist-info}/LICENSE.txt +0 -0
  41. {airbyte_cdk-0.53.9.dist-info → airbyte_cdk-0.55.0.dist-info}/top_level.txt +0 -0
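
The headline change in this release pair is the new airbyte_cdk/sources/concurrent_source/ package: per-stream concurrency (the removed ThreadBasedConcurrentStream, last hunk below) moves to the source level, split across ConcurrentSource, ConcurrentReadProcessor, and ThreadPoolManager, with ConcurrentSourceAdapter bridging the existing AbstractSource interface. As a rough sketch of what opting in looks like, patterned on the _MockConcurrentSource and _init_source helpers in the new test file below (MyStream, the worker counts, and the use of DebugSliceLogger are illustrative assumptions, not part of this diff):

    import logging
    from typing import Any, Iterable, List, Mapping, Optional, Tuple

    from airbyte_cdk.sources.concurrent_source.concurrent_source import ConcurrentSource
    from airbyte_cdk.sources.concurrent_source.concurrent_source_adapter import ConcurrentSourceAdapter
    from airbyte_cdk.sources.message import InMemoryMessageRepository
    from airbyte_cdk.sources.streams import Stream
    from airbyte_cdk.sources.streams.concurrent.adapters import StreamFacade
    from airbyte_cdk.sources.streams.concurrent.cursor import NoopCursor
    from airbyte_cdk.sources.utils.slice_logger import DebugSliceLogger


    class MyStream(Stream):
        # Hypothetical stand-in for a real synchronous stream.
        name = "my_stream"
        primary_key = None

        def get_json_schema(self) -> Mapping[str, Any]:
            return {}

        def read_records(self, sync_mode, cursor_field=None, stream_slice=None, stream_state=None) -> Iterable[Mapping[str, Any]]:
            yield {"id": 1}


    class MyConcurrentSource(ConcurrentSourceAdapter):
        message_repository = InMemoryMessageRepository()

        def __init__(self, logger: logging.Logger):
            self._logger = logger
            # Positional arguments mirror the test helper below:
            # (worker count, initial partitions to generate, logger, slice logger, message repository).
            super().__init__(ConcurrentSource.create(2, 1, logger, DebugSliceLogger(), self.message_repository))

        def check_connection(self, logger: logging.Logger, config: Mapping[str, Any]) -> Tuple[bool, Optional[Any]]:
            return True, None

        def streams(self, config: Mapping[str, Any]) -> List[Stream]:
            # StreamFacade.create_from_stream wraps an ordinary Stream for the concurrent
            # runtime, as _init_source does in the tests below (state=None, no-op cursor).
            return [StreamFacade.create_from_stream(MyStream(), self, self._logger, None, NoopCursor())]

Both source flavors are then driven through the same read(logger, config, catalog, state) entry point, which is exactly the equivalence the new tests below assert.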
unit_tests/sources/test_source_read.py
@@ -0,0 +1,461 @@
+ #
+ # Copyright (c) 2023 Airbyte, Inc., all rights reserved.
+ #
+ import logging
+ from typing import Any, Iterable, List, Mapping, Optional, Tuple, Union
+ from unittest.mock import Mock
+
+ import freezegun
+ from airbyte_cdk.models import (
+     AirbyteMessage,
+     AirbyteRecordMessage,
+     AirbyteStream,
+     AirbyteStreamStatus,
+     AirbyteStreamStatusTraceMessage,
+     AirbyteTraceMessage,
+     ConfiguredAirbyteCatalog,
+     ConfiguredAirbyteStream,
+     DestinationSyncMode,
+     StreamDescriptor,
+     SyncMode,
+     TraceType,
+ )
+ from airbyte_cdk.models import Type as MessageType
+ from airbyte_cdk.sources import AbstractSource
+ from airbyte_cdk.sources.concurrent_source.concurrent_source import ConcurrentSource
+ from airbyte_cdk.sources.concurrent_source.concurrent_source_adapter import ConcurrentSourceAdapter
+ from airbyte_cdk.sources.message import InMemoryMessageRepository
+ from airbyte_cdk.sources.streams import Stream
+ from airbyte_cdk.sources.streams.concurrent.adapters import StreamFacade
+ from airbyte_cdk.sources.streams.concurrent.cursor import NoopCursor
+ from airbyte_cdk.sources.streams.core import StreamData
+ from airbyte_cdk.utils import AirbyteTracedException
+ from unit_tests.sources.streams.concurrent.scenarios.thread_based_concurrent_stream_source_builder import NeverLogSliceLogger
+
+
+ class _MockStream(Stream):
+     def __init__(self, slice_to_records: Mapping[str, List[Mapping[str, Any]]], name: str):
+         self._slice_to_records = slice_to_records
+         self._name = name
+
+     @property
+     def name(self) -> str:
+         return self._name
+
+     @property
+     def primary_key(self) -> Optional[Union[str, List[str], List[List[str]]]]:
+         return None
+
+     def stream_slices(
+         self, *, sync_mode: SyncMode, cursor_field: Optional[List[str]] = None, stream_state: Optional[Mapping[str, Any]] = None
+     ) -> Iterable[Optional[Mapping[str, Any]]]:
+         for partition in self._slice_to_records.keys():
+             yield {"partition": partition}
+
+     def read_records(
+         self,
+         sync_mode: SyncMode,
+         cursor_field: Optional[List[str]] = None,
+         stream_slice: Optional[Mapping[str, Any]] = None,
+         stream_state: Optional[Mapping[str, Any]] = None,
+     ) -> Iterable[StreamData]:
+         for record_or_exception in self._slice_to_records[stream_slice["partition"]]:
+             if isinstance(record_or_exception, Exception):
+                 raise record_or_exception
+             else:
+                 yield record_or_exception
+
+     def get_json_schema(self) -> Mapping[str, Any]:
+         return {}
+
+
+ class _MockSource(AbstractSource):
+     message_repository = InMemoryMessageRepository()
+
+     def check_connection(self, logger: logging.Logger, config: Mapping[str, Any]) -> Tuple[bool, Optional[Any]]:
+         pass
+
+     def set_streams(self, streams):
+         self._streams = streams
+
+     def streams(self, config: Mapping[str, Any]) -> List[Stream]:
+         return self._streams
+
+
+ class _MockConcurrentSource(ConcurrentSourceAdapter):
+     message_repository = InMemoryMessageRepository()
+
+     def __init__(self, logger):
+         concurrent_source = ConcurrentSource.create(1, 1, logger, NeverLogSliceLogger(), self.message_repository)
+         super().__init__(concurrent_source)
+
+     def check_connection(self, logger: logging.Logger, config: Mapping[str, Any]) -> Tuple[bool, Optional[Any]]:
+         pass
+
+     def set_streams(self, streams):
+         self._streams = streams
+
+     def streams(self, config: Mapping[str, Any]) -> List[Stream]:
+         return self._streams
+
+
+ @freezegun.freeze_time("2020-01-01T00:00:00")
+ def test_concurrent_source_yields_the_same_messages_as_abstract_source_when_no_exceptions_are_raised():
+     records_stream_1_partition_1 = [
+         {"id": 1, "partition": "1"},
+         {"id": 2, "partition": "1"},
+     ]
+     records_stream_1_partition_2 = [
+         {"id": 3, "partition": "2"},
+         {"id": 4, "partition": "2"},
+     ]
+     records_stream_2_partition_1 = [
+         {"id": 100, "partition": "A"},
+         {"id": 200, "partition": "A"},
+     ]
+     records_stream_2_partition_2 = [
+         {"id": 300, "partition": "B"},
+         {"id": 400, "partition": "B"},
+     ]
+     stream_1_slice_to_partition = {"1": records_stream_1_partition_1, "2": records_stream_1_partition_2}
+     stream_2_slice_to_partition = {"A": records_stream_2_partition_1, "B": records_stream_2_partition_2}
+     state = None
+     logger = _init_logger()
+
+     source, concurrent_source = _init_sources([stream_1_slice_to_partition, stream_2_slice_to_partition], state, logger)
+
+     config = {}
+     catalog = _create_configured_catalog(source._streams)
+     messages_from_abstract_source = _read_from_source(source, logger, config, catalog, state, None)
+     messages_from_concurrent_source = _read_from_source(concurrent_source, logger, config, catalog, state, None)
+
+     expected_messages = [
+         AirbyteMessage(
+             type=MessageType.TRACE,
+             trace=AirbyteTraceMessage(
+                 type=TraceType.STREAM_STATUS,
+                 emitted_at=1577836800000.0,
+                 error=None,
+                 estimate=None,
+                 stream_status=AirbyteStreamStatusTraceMessage(
+                     stream_descriptor=StreamDescriptor(name="stream0"), status=AirbyteStreamStatus(AirbyteStreamStatus.STARTED)
+                 ),
+             ),
+         ),
+         AirbyteMessage(
+             type=MessageType.TRACE,
+             trace=AirbyteTraceMessage(
+                 type=TraceType.STREAM_STATUS,
+                 emitted_at=1577836800000.0,
+                 error=None,
+                 estimate=None,
+                 stream_status=AirbyteStreamStatusTraceMessage(
+                     stream_descriptor=StreamDescriptor(name="stream0"), status=AirbyteStreamStatus(AirbyteStreamStatus.RUNNING)
+                 ),
+             ),
+         ),
+         AirbyteMessage(
+             type=MessageType.RECORD,
+             record=AirbyteRecordMessage(
+                 stream="stream0",
+                 data=records_stream_1_partition_1[0],
+                 emitted_at=1577836800000,
+             ),
+         ),
+         AirbyteMessage(
+             type=MessageType.RECORD,
+             record=AirbyteRecordMessage(
+                 stream="stream0",
+                 data=records_stream_1_partition_1[1],
+                 emitted_at=1577836800000,
+             ),
+         ),
+         AirbyteMessage(
+             type=MessageType.RECORD,
+             record=AirbyteRecordMessage(
+                 stream="stream0",
+                 data=records_stream_1_partition_2[0],
+                 emitted_at=1577836800000,
+             ),
+         ),
+         AirbyteMessage(
+             type=MessageType.RECORD,
+             record=AirbyteRecordMessage(
+                 stream="stream0",
+                 data=records_stream_1_partition_2[1],
+                 emitted_at=1577836800000,
+             ),
+         ),
+         AirbyteMessage(
+             type=MessageType.TRACE,
+             trace=AirbyteTraceMessage(
+                 type=TraceType.STREAM_STATUS,
+                 emitted_at=1577836800000.0,
+                 error=None,
+                 estimate=None,
+                 stream_status=AirbyteStreamStatusTraceMessage(
+                     stream_descriptor=StreamDescriptor(name="stream0"), status=AirbyteStreamStatus(AirbyteStreamStatus.COMPLETE)
+                 ),
+             ),
+         ),
+         AirbyteMessage(
+             type=MessageType.TRACE,
+             trace=AirbyteTraceMessage(
+                 type=TraceType.STREAM_STATUS,
+                 emitted_at=1577836800000.0,
+                 error=None,
+                 estimate=None,
+                 stream_status=AirbyteStreamStatusTraceMessage(
+                     stream_descriptor=StreamDescriptor(name="stream1"), status=AirbyteStreamStatus(AirbyteStreamStatus.STARTED)
+                 ),
+             ),
+         ),
+         AirbyteMessage(
+             type=MessageType.TRACE,
+             trace=AirbyteTraceMessage(
+                 type=TraceType.STREAM_STATUS,
+                 emitted_at=1577836800000.0,
+                 error=None,
+                 estimate=None,
+                 stream_status=AirbyteStreamStatusTraceMessage(
+                     stream_descriptor=StreamDescriptor(name="stream1"), status=AirbyteStreamStatus(AirbyteStreamStatus.RUNNING)
+                 ),
+             ),
+         ),
+         AirbyteMessage(
+             type=MessageType.RECORD,
+             record=AirbyteRecordMessage(
+                 stream="stream1",
+                 data=records_stream_2_partition_1[0],
+                 emitted_at=1577836800000,
+             ),
+         ),
+         AirbyteMessage(
+             type=MessageType.RECORD,
+             record=AirbyteRecordMessage(
+                 stream="stream1",
+                 data=records_stream_2_partition_1[1],
+                 emitted_at=1577836800000,
+             ),
+         ),
+         AirbyteMessage(
+             type=MessageType.RECORD,
+             record=AirbyteRecordMessage(
+                 stream="stream1",
+                 data=records_stream_2_partition_2[0],
+                 emitted_at=1577836800000,
+             ),
+         ),
+         AirbyteMessage(
+             type=MessageType.RECORD,
+             record=AirbyteRecordMessage(
+                 stream="stream1",
+                 data=records_stream_2_partition_2[1],
+                 emitted_at=1577836800000,
+             ),
+         ),
+         AirbyteMessage(
+             type=MessageType.TRACE,
+             trace=AirbyteTraceMessage(
+                 type=TraceType.STREAM_STATUS,
+                 emitted_at=1577836800000.0,
+                 error=None,
+                 estimate=None,
+                 stream_status=AirbyteStreamStatusTraceMessage(
+                     stream_descriptor=StreamDescriptor(name="stream1"), status=AirbyteStreamStatus(AirbyteStreamStatus.COMPLETE)
+                 ),
+             ),
+         ),
+     ]
+     _verify_messages(expected_messages, messages_from_abstract_source, messages_from_concurrent_source)
+
+
+ @freezegun.freeze_time("2020-01-01T00:00:00")
+ def test_concurrent_source_yields_the_same_messages_as_abstract_source_when_a_traced_exception_is_raised():
+     records = [{"id": 1, "partition": "1"}, AirbyteTracedException()]
+     stream_slice_to_partition = {"1": records}
+
+     logger = _init_logger()
+     state = None
+     source, concurrent_source = _init_sources([stream_slice_to_partition], state, logger)
+     config = {}
+     catalog = _create_configured_catalog(source._streams)
+     messages_from_abstract_source = _read_from_source(source, logger, config, catalog, state, AirbyteTracedException)
+     messages_from_concurrent_source = _read_from_source(concurrent_source, logger, config, catalog, state, AirbyteTracedException)
+
+     expected_messages = [
+         AirbyteMessage(
+             type=MessageType.TRACE,
+             trace=AirbyteTraceMessage(
+                 type=TraceType.STREAM_STATUS,
+                 emitted_at=1577836800000.0,
+                 error=None,
+                 estimate=None,
+                 stream_status=AirbyteStreamStatusTraceMessage(
+                     stream_descriptor=StreamDescriptor(name="stream0"), status=AirbyteStreamStatus(AirbyteStreamStatus.STARTED)
+                 ),
+             ),
+         ),
+         AirbyteMessage(
+             type=MessageType.TRACE,
+             trace=AirbyteTraceMessage(
+                 type=TraceType.STREAM_STATUS,
+                 emitted_at=1577836800000.0,
+                 error=None,
+                 estimate=None,
+                 stream_status=AirbyteStreamStatusTraceMessage(
+                     stream_descriptor=StreamDescriptor(name="stream0"), status=AirbyteStreamStatus(AirbyteStreamStatus.RUNNING)
+                 ),
+             ),
+         ),
+         AirbyteMessage(
+             type=MessageType.RECORD,
+             record=AirbyteRecordMessage(
+                 stream="stream0",
+                 data=records[0],
+                 emitted_at=1577836800000,
+             ),
+         ),
+         AirbyteMessage(
+             type=MessageType.TRACE,
+             trace=AirbyteTraceMessage(
+                 type=TraceType.STREAM_STATUS,
+                 emitted_at=1577836800000.0,
+                 error=None,
+                 estimate=None,
+                 stream_status=AirbyteStreamStatusTraceMessage(
+                     stream_descriptor=StreamDescriptor(name="stream0"), status=AirbyteStreamStatus(AirbyteStreamStatus.INCOMPLETE)
+                 ),
+             ),
+         ),
+     ]
+     _verify_messages(expected_messages, messages_from_abstract_source, messages_from_concurrent_source)
+
+
+ @freezegun.freeze_time("2020-01-01T00:00:00")
+ def test_concurrent_source_yields_the_same_messages_as_abstract_source_when_an_exception_is_raised():
+     records = [{"id": 1, "partition": "1"}, RuntimeError()]
+     stream_slice_to_partition = {"1": records}
+     logger = _init_logger()
+
+     state = None
+
+     source, concurrent_source = _init_sources([stream_slice_to_partition], state, logger)
+     config = {}
+     catalog = _create_configured_catalog(source._streams)
+     messages_from_abstract_source = _read_from_source(source, logger, config, catalog, state, RuntimeError)
+     messages_from_concurrent_source = _read_from_source(concurrent_source, logger, config, catalog, state, RuntimeError)
+
+     expected_messages = [
+         AirbyteMessage(
+             type=MessageType.TRACE,
+             trace=AirbyteTraceMessage(
+                 type=TraceType.STREAM_STATUS,
+                 emitted_at=1577836800000.0,
+                 error=None,
+                 estimate=None,
+                 stream_status=AirbyteStreamStatusTraceMessage(
+                     stream_descriptor=StreamDescriptor(name="stream0"), status=AirbyteStreamStatus(AirbyteStreamStatus.STARTED)
+                 ),
+             ),
+         ),
+         AirbyteMessage(
+             type=MessageType.TRACE,
+             trace=AirbyteTraceMessage(
+                 type=TraceType.STREAM_STATUS,
+                 emitted_at=1577836800000.0,
+                 error=None,
+                 estimate=None,
+                 stream_status=AirbyteStreamStatusTraceMessage(
+                     stream_descriptor=StreamDescriptor(name="stream0"), status=AirbyteStreamStatus(AirbyteStreamStatus.RUNNING)
+                 ),
+             ),
+         ),
+         AirbyteMessage(
+             type=MessageType.RECORD,
+             record=AirbyteRecordMessage(
+                 stream="stream0",
+                 data=records[0],
+                 emitted_at=1577836800000,
+             ),
+         ),
+         AirbyteMessage(
+             type=MessageType.TRACE,
+             trace=AirbyteTraceMessage(
+                 type=TraceType.STREAM_STATUS,
+                 emitted_at=1577836800000.0,
+                 error=None,
+                 estimate=None,
+                 stream_status=AirbyteStreamStatusTraceMessage(
+                     stream_descriptor=StreamDescriptor(name="stream0"), status=AirbyteStreamStatus(AirbyteStreamStatus.INCOMPLETE)
+                 ),
+             ),
+         ),
+     ]
+     _verify_messages(expected_messages, messages_from_abstract_source, messages_from_concurrent_source)
+
+
+ def _init_logger():
+     logger = Mock()
+     logger.level = logging.INFO
+     logger.isEnabledFor.return_value = False
+     return logger
+
+
+ def _init_sources(stream_slice_to_partitions, state, logger):
+     source = _init_source(stream_slice_to_partitions, state, logger, _MockSource())
+     concurrent_source = _init_source(stream_slice_to_partitions, state, logger, _MockConcurrentSource(logger))
+     return source, concurrent_source
+
+
+ def _init_source(stream_slice_to_partitions, state, logger, source):
+     cursor = NoopCursor()
+     streams = [
+         StreamFacade.create_from_stream(_MockStream(stream_slices, f"stream{i}"), source, logger, state, cursor)
+         for i, stream_slices in enumerate(stream_slice_to_partitions)
+     ]
+     source.set_streams(streams)
+     return source
+
+
+ def _create_configured_catalog(streams):
+     return ConfiguredAirbyteCatalog(
+         streams=[
+             ConfiguredAirbyteStream(
+                 stream=AirbyteStream(name=s.name, json_schema={}, supported_sync_modes=[SyncMode.full_refresh]),
+                 sync_mode=SyncMode.full_refresh,
+                 cursor_field=None,
+                 destination_sync_mode=DestinationSyncMode.overwrite,
+             )
+             for s in streams
+         ]
+     )
+
+
+ def _read_from_source(source, logger, config, catalog, state, expected_exception):
+     messages = []
+     try:
+         for m in source.read(logger, config, catalog, state):
+             messages.append(m)
+     except Exception as e:
+         if expected_exception:
+             assert isinstance(e, expected_exception)
+     return messages
+
+
+ def _verify_messages(expected_messages, messages_from_abstract_source, messages_from_concurrent_source):
+     assert _compare(expected_messages, messages_from_concurrent_source)
+
+
+ def _compare(s, t):
+     # Use a compare method that does not require ordering or hashing the elements
+     # We can't rely on the ordering because of the multithreading
+     # AirbyteMessage does not implement __eq__ and __hash__
+     t = list(t)
+     try:
+         for elem in s:
+             t.remove(elem)
+     except ValueError:
+         print(f"ValueError: {elem}")
+         return False
+     return not t
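
A note on the comparison helper above: because worker scheduling makes message order nondeterministic, and AirbyteMessage values cannot be hashed or sorted, _compare checks multiset equality by destructively removing elements. A quick illustration of its semantics, with placeholder integer values:

    assert _compare([1, 2, 2], [2, 1, 2])    # same elements in any order -> True
    assert not _compare([1, 2], [1, 2, 2])   # an element of t is left over -> False
    assert not _compare([1, 3], [1, 2])      # 3 missing from t -> ValueError branch -> False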
airbyte_cdk/sources/streams/concurrent/thread_based_concurrent_stream.py
@@ -1,221 +0,0 @@
- #
- # Copyright (c) 2023 Airbyte, Inc., all rights reserved.
- #
-
- import concurrent
- import time
- from concurrent.futures import Future
- from functools import lru_cache
- from logging import Logger
- from queue import Queue
- from typing import Any, Callable, Dict, Iterable, List, Mapping, Optional
-
- from airbyte_cdk.models import AirbyteStream, SyncMode
- from airbyte_cdk.sources.message import MessageRepository
- from airbyte_cdk.sources.streams.concurrent.abstract_stream import AbstractStream
- from airbyte_cdk.sources.streams.concurrent.availability_strategy import AbstractAvailabilityStrategy, StreamAvailability
- from airbyte_cdk.sources.streams.concurrent.cursor import Cursor, NoopCursor
- from airbyte_cdk.sources.streams.concurrent.partition_enqueuer import PartitionEnqueuer
- from airbyte_cdk.sources.streams.concurrent.partition_reader import PartitionReader
- from airbyte_cdk.sources.streams.concurrent.partitions.partition import Partition
- from airbyte_cdk.sources.streams.concurrent.partitions.partition_generator import PartitionGenerator
- from airbyte_cdk.sources.streams.concurrent.partitions.record import Record
- from airbyte_cdk.sources.streams.concurrent.partitions.types import PARTITIONS_GENERATED_SENTINEL, PartitionCompleteSentinel, QueueItem
- from airbyte_cdk.sources.utils.slice_logger import SliceLogger
-
-
- class ThreadBasedConcurrentStream(AbstractStream):
-
-     DEFAULT_TIMEOUT_SECONDS = 900
-     DEFAULT_MAX_QUEUE_SIZE = 10_000
-     DEFAULT_SLEEP_TIME = 0.1
-
-     def __init__(
-         self,
-         partition_generator: PartitionGenerator,
-         max_workers: int,
-         name: str,
-         json_schema: Mapping[str, Any],
-         availability_strategy: AbstractAvailabilityStrategy,
-         primary_key: List[str],
-         cursor_field: Optional[str],
-         slice_logger: SliceLogger,
-         logger: Logger,
-         message_repository: MessageRepository,
-         timeout_seconds: int = DEFAULT_TIMEOUT_SECONDS,
-         max_concurrent_tasks: int = DEFAULT_MAX_QUEUE_SIZE,
-         sleep_time: float = DEFAULT_SLEEP_TIME,
-         cursor: Cursor = NoopCursor(),
-         namespace: Optional[str] = None,
-     ):
-         self._stream_partition_generator = partition_generator
-         self._max_workers = max_workers
-         self._threadpool = concurrent.futures.ThreadPoolExecutor(max_workers=self._max_workers, thread_name_prefix="workerpool")
-         self._name = name
-         self._json_schema = json_schema
-         self._availability_strategy = availability_strategy
-         self._primary_key = primary_key
-         self._cursor_field = cursor_field
-         self._slice_logger = slice_logger
-         self._logger = logger
-         self._message_repository = message_repository
-         self._timeout_seconds = timeout_seconds
-         self._max_concurrent_tasks = max_concurrent_tasks
-         self._sleep_time = sleep_time
-         self._cursor = cursor
-         self._namespace = namespace
-
-     def read(self) -> Iterable[Record]:
-         """
-         Read all data from the stream (only full-refresh is supported at the moment)
-
-         Algorithm:
-         1. Submit a future to generate the stream's partition to process.
-           - This has to be done asynchronously because we sometimes need to submit requests to the API to generate all partitions (eg for substreams).
-           - The future will add the partitions to process on a work queue.
-         2. Continuously poll work from the work queue until all partitions are generated and processed
-           - If the next work item is an Exception, stop the threadpool and raise it.
-           - If the next work item is a partition, submit a future to process it.
-             - The future will add the records to emit on the work queue.
-             - Add the partitions to the partitions_to_done dict so we know it needs to complete for the sync to succeed.
-           - If the next work item is a record, yield the record.
-           - If the next work item is PARTITIONS_GENERATED_SENTINEL, all the partitions were generated.
-           - If the next work item is a PartitionCompleteSentinel, a partition is done processing.
-             - Update the value in partitions_to_done to True so we know the partition is completed.
-         """
-         self._logger.debug(f"Processing stream slices for {self.name}")
-         futures: List[Future[Any]] = []
-         queue: Queue[QueueItem] = Queue()
-         partition_generator = PartitionEnqueuer(queue, PARTITIONS_GENERATED_SENTINEL)
-         partition_reader = PartitionReader(queue)
-
-         self._submit_task(futures, partition_generator.generate_partitions, self._stream_partition_generator)
-
-         # True -> partition is done
-         # False -> partition is not done
-         partitions_to_done: Dict[Partition, bool] = {}
-
-         finished_partitions = False
-         while record_or_partition_or_exception := queue.get(block=True, timeout=self._timeout_seconds):
-             if isinstance(record_or_partition_or_exception, Exception):
-                 # An exception was raised while processing the stream
-                 # Stop the threadpool and raise it
-                 self._stop_and_raise_exception(record_or_partition_or_exception)
-             elif record_or_partition_or_exception == PARTITIONS_GENERATED_SENTINEL:
-                 # All partitions were generated
-                 finished_partitions = True
-             elif isinstance(record_or_partition_or_exception, PartitionCompleteSentinel):
-                 # All records for a partition were generated
-                 if record_or_partition_or_exception.partition not in partitions_to_done:
-                     raise RuntimeError(
-                         f"Received sentinel for partition {record_or_partition_or_exception.partition} that was not in partitions. This is indicative of a bug in the CDK. Please contact support.partitions:\n{partitions_to_done}"
-                     )
-                 partitions_to_done[record_or_partition_or_exception.partition] = True
-                 self._cursor.close_partition(record_or_partition_or_exception.partition)
-             elif isinstance(record_or_partition_or_exception, Record):
-                 # Emit records
-                 yield record_or_partition_or_exception
-                 self._cursor.observe(record_or_partition_or_exception)
-             elif isinstance(record_or_partition_or_exception, Partition):
-                 # A new partition was generated and must be processed
-                 partitions_to_done[record_or_partition_or_exception] = False
-                 if self._slice_logger.should_log_slice_message(self._logger):
-                     self._message_repository.emit_message(
-                         self._slice_logger.create_slice_log_message(record_or_partition_or_exception.to_slice())
-                     )
-                 self._submit_task(futures, partition_reader.process_partition, record_or_partition_or_exception)
-             if finished_partitions and all(partitions_to_done.values()):
-                 # All partitions were generated and process. We're done here
-                 break
-
-         self._check_for_errors(futures)
-
-     def _submit_task(self, futures: List[Future[Any]], function: Callable[..., Any], *args: Any) -> None:
-         # Submit a task to the threadpool, waiting if there are too many pending tasks
-         self._wait_while_too_many_pending_futures(futures)
-         futures.append(self._threadpool.submit(function, *args))
-
-     def _wait_while_too_many_pending_futures(self, futures: List[Future[Any]]) -> None:
-         # Wait until the number of pending tasks is < self._max_concurrent_tasks
-         while True:
-             self._prune_futures(futures)
-             if len(futures) < self._max_concurrent_tasks:
-                 break
-             self._logger.info("Main thread is sleeping because the task queue is full...")
-             time.sleep(self._sleep_time)
-
-     def _prune_futures(self, futures: List[Future[Any]]) -> None:
-         """
-         Take a list in input and remove the futures that are completed. If a future has an exception, it'll raise and kill the stream
-         operation.
-
-         Pruning this list safely relies on the assumptions that only the main thread can modify the list of futures.
-         """
-         if len(futures) < self._max_concurrent_tasks:
-             return
-
-         for index in reversed(range(len(futures))):
-             future = futures[index]
-             optional_exception = future.exception()
-             if optional_exception:
-                 exception = RuntimeError(f"Failed reading from stream {self.name} with error: {optional_exception}")
-                 self._stop_and_raise_exception(exception)
-
-             if future.done():
-                 futures.pop(index)
-
-     def _check_for_errors(self, futures: List[Future[Any]]) -> None:
-         exceptions_from_futures = [f for f in [future.exception() for future in futures] if f is not None]
-         if exceptions_from_futures:
-             exception = RuntimeError(f"Failed reading from stream {self.name} with errors: {exceptions_from_futures}")
-             self._stop_and_raise_exception(exception)
-         else:
-             futures_not_done = [f for f in futures if not f.done()]
-             if futures_not_done:
-                 exception = RuntimeError(f"Failed reading from stream {self.name} with futures not done: {futures_not_done}")
-                 self._stop_and_raise_exception(exception)
-
-     def _stop_and_raise_exception(self, exception: BaseException) -> None:
-         self._threadpool.shutdown(wait=False, cancel_futures=True)
-         raise exception
-
-     @property
-     def name(self) -> str:
-         return self._name
-
-     def check_availability(self) -> StreamAvailability:
-         return self._availability_strategy.check_availability(self._logger)
-
-     @property
-     def cursor_field(self) -> Optional[str]:
-         return self._cursor_field
-
-     @lru_cache(maxsize=None)
-     def get_json_schema(self) -> Mapping[str, Any]:
-         return self._json_schema
-
-     def as_airbyte_stream(self) -> AirbyteStream:
-         stream = AirbyteStream(name=self.name, json_schema=dict(self._json_schema), supported_sync_modes=[SyncMode.full_refresh])
-
-         if self._namespace:
-             stream.namespace = self._namespace
-
-         if self._cursor_field:
-             stream.source_defined_cursor = True
-             stream.supported_sync_modes.append(SyncMode.incremental)
-             stream.default_cursor_field = [self._cursor_field]
-
-         keys = self._primary_key
-         if keys and len(keys) > 0:
-             stream.source_defined_primary_key = [keys]
-
-         return stream
-
-     def log_stream_sync_configuration(self) -> None:
-         self._logger.debug(
-             f"Syncing stream instance: {self.name}",
-             extra={
-                 "primary_key": self._primary_key,
-                 "cursor_field": self.cursor_field,
-             },
-         )
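
For orientation on where the removed logic went: the sentinel-driven work queue described in the read() docstring above survives the refactor, generalized from one stream to the whole source (see concurrent_read_processor.py, thread_pool_manager.py, and partition_generation_completed_sentinel.py in the file list). A minimal, framework-free sketch of that pattern follows; every name in it is invented for illustration and none of it is CDK API:

    import queue
    import threading
    from dataclasses import dataclass
    from typing import Any, Dict, Iterable, List


    @dataclass(frozen=True)
    class PartitionsGenerated:
        """Sentinel: the producer will enqueue no more partitions."""


    @dataclass(frozen=True)
    class PartitionDone:
        """Sentinel: all records of one partition have been enqueued."""
        partition: str


    def read_all(partitions: Dict[str, List[Any]]) -> Iterable[Any]:
        work: "queue.Queue[Any]" = queue.Queue()

        def generate() -> None:
            # Producer: enqueue each partition, then signal that generation is done.
            for name in partitions:
                work.put(("partition", name))
            work.put(PartitionsGenerated())

        def process(name: str) -> None:
            # Worker: enqueue every record of the partition, then its completion sentinel.
            for record in partitions[name]:
                work.put(("record", record))
            work.put(PartitionDone(name))

        threading.Thread(target=generate).start()
        generation_finished = False
        partition_done: Dict[str, bool] = {}
        while True:
            item = work.get()
            if isinstance(item, PartitionsGenerated):
                generation_finished = True
            elif isinstance(item, PartitionDone):
                partition_done[item.partition] = True
            elif item[0] == "partition":
                partition_done[item[1]] = False
                threading.Thread(target=process, args=(item[1],)).start()
            else:  # ("record", value)
                yield item[1]
            if generation_finished and all(partition_done.values()):
                break


    # Records from different partitions may interleave in any order:
    print(sorted(read_all({"a": [1, 2], "b": [3]})))

The real implementation layers onto this what the sketch omits: exception items that shut the pool down and re-raise, throttling when too many futures are pending, slice logging, and per-partition cursor bookkeeping.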