airbyte-cdk 0.53.9__py3-none-any.whl → 0.55.0__py3-none-any.whl

Files changed (41)
  1. airbyte_cdk/sources/concurrent_source/__init__.py +3 -0
  2. airbyte_cdk/sources/concurrent_source/concurrent_read_processor.py +190 -0
  3. airbyte_cdk/sources/concurrent_source/concurrent_source.py +161 -0
  4. airbyte_cdk/sources/concurrent_source/concurrent_source_adapter.py +63 -0
  5. airbyte_cdk/sources/concurrent_source/partition_generation_completed_sentinel.py +17 -0
  6. airbyte_cdk/sources/concurrent_source/thread_pool_manager.py +97 -0
  7. airbyte_cdk/sources/declarative/parsers/model_to_component_factory.py +16 -4
  8. airbyte_cdk/sources/declarative/requesters/paginators/default_paginator.py +14 -14
  9. airbyte_cdk/sources/declarative/requesters/request_options/request_options_provider.py +2 -2
  10. airbyte_cdk/sources/streams/concurrent/abstract_stream.py +4 -4
  11. airbyte_cdk/sources/streams/concurrent/adapters.py +34 -12
  12. airbyte_cdk/sources/streams/concurrent/default_stream.py +79 -0
  13. airbyte_cdk/sources/streams/concurrent/partition_enqueuer.py +7 -7
  14. airbyte_cdk/sources/streams/concurrent/partitions/partition.py +23 -0
  15. airbyte_cdk/sources/streams/concurrent/partitions/record.py +4 -3
  16. airbyte_cdk/sources/streams/concurrent/partitions/types.py +2 -3
  17. airbyte_cdk/sources/utils/slice_logger.py +5 -0
  18. {airbyte_cdk-0.53.9.dist-info → airbyte_cdk-0.55.0.dist-info}/METADATA +1 -1
  19. {airbyte_cdk-0.53.9.dist-info → airbyte_cdk-0.55.0.dist-info}/RECORD +40 -28
  20. {airbyte_cdk-0.53.9.dist-info → airbyte_cdk-0.55.0.dist-info}/WHEEL +1 -1
  21. unit_tests/sources/concurrent_source/__init__.py +3 -0
  22. unit_tests/sources/concurrent_source/test_concurrent_source_adapter.py +105 -0
  23. unit_tests/sources/declarative/parsers/test_model_to_component_factory.py +33 -0
  24. unit_tests/sources/declarative/requesters/paginators/test_default_paginator.py +9 -2
  25. unit_tests/sources/streams/concurrent/scenarios/stream_facade_builder.py +14 -7
  26. unit_tests/sources/streams/concurrent/scenarios/stream_facade_scenarios.py +2 -3
  27. unit_tests/sources/streams/concurrent/scenarios/thread_based_concurrent_stream_scenarios.py +44 -55
  28. unit_tests/sources/streams/concurrent/scenarios/thread_based_concurrent_stream_source_builder.py +24 -15
  29. unit_tests/sources/streams/concurrent/test_adapters.py +52 -32
  30. unit_tests/sources/streams/concurrent/test_concurrent_partition_generator.py +6 -5
  31. unit_tests/sources/streams/concurrent/test_concurrent_read_processor.py +604 -0
  32. unit_tests/sources/streams/concurrent/test_cursor.py +1 -1
  33. unit_tests/sources/streams/concurrent/{test_thread_based_concurrent_stream.py → test_default_stream.py} +7 -144
  34. unit_tests/sources/streams/concurrent/test_partition_reader.py +2 -2
  35. unit_tests/sources/streams/concurrent/test_thread_pool_manager.py +98 -0
  36. unit_tests/sources/streams/test_stream_read.py +1 -2
  37. unit_tests/sources/test_concurrent_source.py +105 -0
  38. unit_tests/sources/test_source_read.py +461 -0
  39. airbyte_cdk/sources/streams/concurrent/thread_based_concurrent_stream.py +0 -221
  40. {airbyte_cdk-0.53.9.dist-info → airbyte_cdk-0.55.0.dist-info}/LICENSE.txt +0 -0
  41. {airbyte_cdk-0.53.9.dist-info → airbyte_cdk-0.55.0.dist-info}/top_level.txt +0 -0
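
The headline change in 0.55.0 is the new airbyte_cdk.sources.concurrent_source package, which moves concurrency from the per-stream ThreadBasedConcurrentStream (deleted in this release, see the second hunk below) up to the source level. The sketch below shows how a source might opt in; the wiring mirrors the _MockConcurrentSource and _init_source test fixtures in the first hunk, so treat argument meanings as illustrative rather than documented API. The fixture's NeverLogSliceLogger is a test helper; here a slice_logger is simply passed in.

import logging

from airbyte_cdk.sources.concurrent_source.concurrent_source import ConcurrentSource
from airbyte_cdk.sources.concurrent_source.concurrent_source_adapter import ConcurrentSourceAdapter
from airbyte_cdk.sources.message import InMemoryMessageRepository
from airbyte_cdk.sources.streams.concurrent.adapters import StreamFacade
from airbyte_cdk.sources.streams.concurrent.cursor import NoopCursor


# Sketch only: modeled on _MockConcurrentSource / _init_source from the new
# unit_tests/sources/test_source_read.py shown below.
class MyConcurrentSource(ConcurrentSourceAdapter):
    message_repository = InMemoryMessageRepository()

    def __init__(self, logger: logging.Logger, slice_logger):
        # The fixture passes (1, 1, ...): our reading is "number of workers" and
        # "partitions to generate up front", inferred from the fixture, not documented API.
        super().__init__(ConcurrentSource.create(1, 1, logger, slice_logger, self.message_repository))
        self._streams = []

    def check_connection(self, logger, config):
        return True, None

    def streams(self, config):
        return self._streams

    def set_streams(self, streams):
        # Wrap ordinary synchronous Stream instances so the concurrent engine can
        # run them, exactly as the test helper _init_source does.
        self._streams = [
            StreamFacade.create_from_stream(stream, self, logging.getLogger("airbyte"), None, NoopCursor())
            for stream in streams
        ]
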
--- /dev/null
+++ b/unit_tests/sources/test_source_read.py
@@ -0,0 +1,461 @@
+ #
+ # Copyright (c) 2023 Airbyte, Inc., all rights reserved.
+ #
+ import logging
+ from typing import Any, Iterable, List, Mapping, Optional, Tuple, Union
+ from unittest.mock import Mock
+
+ import freezegun
+ from airbyte_cdk.models import (
+     AirbyteMessage,
+     AirbyteRecordMessage,
+     AirbyteStream,
+     AirbyteStreamStatus,
+     AirbyteStreamStatusTraceMessage,
+     AirbyteTraceMessage,
+     ConfiguredAirbyteCatalog,
+     ConfiguredAirbyteStream,
+     DestinationSyncMode,
+     StreamDescriptor,
+     SyncMode,
+     TraceType,
+ )
+ from airbyte_cdk.models import Type as MessageType
+ from airbyte_cdk.sources import AbstractSource
+ from airbyte_cdk.sources.concurrent_source.concurrent_source import ConcurrentSource
+ from airbyte_cdk.sources.concurrent_source.concurrent_source_adapter import ConcurrentSourceAdapter
+ from airbyte_cdk.sources.message import InMemoryMessageRepository
+ from airbyte_cdk.sources.streams import Stream
+ from airbyte_cdk.sources.streams.concurrent.adapters import StreamFacade
+ from airbyte_cdk.sources.streams.concurrent.cursor import NoopCursor
+ from airbyte_cdk.sources.streams.core import StreamData
+ from airbyte_cdk.utils import AirbyteTracedException
+ from unit_tests.sources.streams.concurrent.scenarios.thread_based_concurrent_stream_source_builder import NeverLogSliceLogger
+
+
+ class _MockStream(Stream):
+     def __init__(self, slice_to_records: Mapping[str, List[Mapping[str, Any]]], name: str):
+         self._slice_to_records = slice_to_records
+         self._name = name
+
+     @property
+     def name(self) -> str:
+         return self._name
+
+     @property
+     def primary_key(self) -> Optional[Union[str, List[str], List[List[str]]]]:
+         return None
+
+     def stream_slices(
+         self, *, sync_mode: SyncMode, cursor_field: Optional[List[str]] = None, stream_state: Optional[Mapping[str, Any]] = None
+     ) -> Iterable[Optional[Mapping[str, Any]]]:
+         for partition in self._slice_to_records.keys():
+             yield {"partition": partition}
+
+     def read_records(
+         self,
+         sync_mode: SyncMode,
+         cursor_field: Optional[List[str]] = None,
+         stream_slice: Optional[Mapping[str, Any]] = None,
+         stream_state: Optional[Mapping[str, Any]] = None,
+     ) -> Iterable[StreamData]:
+         for record_or_exception in self._slice_to_records[stream_slice["partition"]]:
+             if isinstance(record_or_exception, Exception):
+                 raise record_or_exception
+             else:
+                 yield record_or_exception
+
+     def get_json_schema(self) -> Mapping[str, Any]:
+         return {}
+
+
+ class _MockSource(AbstractSource):
+     message_repository = InMemoryMessageRepository()
+
+     def check_connection(self, logger: logging.Logger, config: Mapping[str, Any]) -> Tuple[bool, Optional[Any]]:
+         pass
+
+     def set_streams(self, streams):
+         self._streams = streams
+
+     def streams(self, config: Mapping[str, Any]) -> List[Stream]:
+         return self._streams
+
+
+ class _MockConcurrentSource(ConcurrentSourceAdapter):
+     message_repository = InMemoryMessageRepository()
+
+     def __init__(self, logger):
+         concurrent_source = ConcurrentSource.create(1, 1, logger, NeverLogSliceLogger(), self.message_repository)
+         super().__init__(concurrent_source)
+
+     def check_connection(self, logger: logging.Logger, config: Mapping[str, Any]) -> Tuple[bool, Optional[Any]]:
+         pass
+
+     def set_streams(self, streams):
+         self._streams = streams
+
+     def streams(self, config: Mapping[str, Any]) -> List[Stream]:
+         return self._streams
+
+
+ @freezegun.freeze_time("2020-01-01T00:00:00")
+ def test_concurrent_source_yields_the_same_messages_as_abstract_source_when_no_exceptions_are_raised():
+     records_stream_1_partition_1 = [
+         {"id": 1, "partition": "1"},
+         {"id": 2, "partition": "1"},
+     ]
+     records_stream_1_partition_2 = [
+         {"id": 3, "partition": "2"},
+         {"id": 4, "partition": "2"},
+     ]
+     records_stream_2_partition_1 = [
+         {"id": 100, "partition": "A"},
+         {"id": 200, "partition": "A"},
+     ]
+     records_stream_2_partition_2 = [
+         {"id": 300, "partition": "B"},
+         {"id": 400, "partition": "B"},
+     ]
+     stream_1_slice_to_partition = {"1": records_stream_1_partition_1, "2": records_stream_1_partition_2}
+     stream_2_slice_to_partition = {"A": records_stream_2_partition_1, "B": records_stream_2_partition_2}
+     state = None
+     logger = _init_logger()
+
+     source, concurrent_source = _init_sources([stream_1_slice_to_partition, stream_2_slice_to_partition], state, logger)
+
+     config = {}
+     catalog = _create_configured_catalog(source._streams)
+     messages_from_abstract_source = _read_from_source(source, logger, config, catalog, state, None)
+     messages_from_concurrent_source = _read_from_source(concurrent_source, logger, config, catalog, state, None)
+
+     expected_messages = [
+         AirbyteMessage(
+             type=MessageType.TRACE,
+             trace=AirbyteTraceMessage(
+                 type=TraceType.STREAM_STATUS,
+                 emitted_at=1577836800000.0,
+                 error=None,
+                 estimate=None,
+                 stream_status=AirbyteStreamStatusTraceMessage(
+                     stream_descriptor=StreamDescriptor(name="stream0"), status=AirbyteStreamStatus(AirbyteStreamStatus.STARTED)
+                 ),
+             ),
+         ),
+         AirbyteMessage(
+             type=MessageType.TRACE,
+             trace=AirbyteTraceMessage(
+                 type=TraceType.STREAM_STATUS,
+                 emitted_at=1577836800000.0,
+                 error=None,
+                 estimate=None,
+                 stream_status=AirbyteStreamStatusTraceMessage(
+                     stream_descriptor=StreamDescriptor(name="stream0"), status=AirbyteStreamStatus(AirbyteStreamStatus.RUNNING)
+                 ),
+             ),
+         ),
+         AirbyteMessage(
+             type=MessageType.RECORD,
+             record=AirbyteRecordMessage(
+                 stream="stream0",
+                 data=records_stream_1_partition_1[0],
+                 emitted_at=1577836800000,
+             ),
+         ),
+         AirbyteMessage(
+             type=MessageType.RECORD,
+             record=AirbyteRecordMessage(
+                 stream="stream0",
+                 data=records_stream_1_partition_1[1],
+                 emitted_at=1577836800000,
+             ),
+         ),
+         AirbyteMessage(
+             type=MessageType.RECORD,
+             record=AirbyteRecordMessage(
+                 stream="stream0",
+                 data=records_stream_1_partition_2[0],
+                 emitted_at=1577836800000,
+             ),
+         ),
+         AirbyteMessage(
+             type=MessageType.RECORD,
+             record=AirbyteRecordMessage(
+                 stream="stream0",
+                 data=records_stream_1_partition_2[1],
+                 emitted_at=1577836800000,
+             ),
+         ),
+         AirbyteMessage(
+             type=MessageType.TRACE,
+             trace=AirbyteTraceMessage(
+                 type=TraceType.STREAM_STATUS,
+                 emitted_at=1577836800000.0,
+                 error=None,
+                 estimate=None,
+                 stream_status=AirbyteStreamStatusTraceMessage(
+                     stream_descriptor=StreamDescriptor(name="stream0"), status=AirbyteStreamStatus(AirbyteStreamStatus.COMPLETE)
+                 ),
+             ),
+         ),
+         AirbyteMessage(
+             type=MessageType.TRACE,
+             trace=AirbyteTraceMessage(
+                 type=TraceType.STREAM_STATUS,
+                 emitted_at=1577836800000.0,
+                 error=None,
+                 estimate=None,
+                 stream_status=AirbyteStreamStatusTraceMessage(
+                     stream_descriptor=StreamDescriptor(name="stream1"), status=AirbyteStreamStatus(AirbyteStreamStatus.STARTED)
+                 ),
+             ),
+         ),
+         AirbyteMessage(
+             type=MessageType.TRACE,
+             trace=AirbyteTraceMessage(
+                 type=TraceType.STREAM_STATUS,
+                 emitted_at=1577836800000.0,
+                 error=None,
+                 estimate=None,
+                 stream_status=AirbyteStreamStatusTraceMessage(
+                     stream_descriptor=StreamDescriptor(name="stream1"), status=AirbyteStreamStatus(AirbyteStreamStatus.RUNNING)
+                 ),
+             ),
+         ),
+         AirbyteMessage(
+             type=MessageType.RECORD,
+             record=AirbyteRecordMessage(
+                 stream="stream1",
+                 data=records_stream_2_partition_1[0],
+                 emitted_at=1577836800000,
+             ),
+         ),
+         AirbyteMessage(
+             type=MessageType.RECORD,
+             record=AirbyteRecordMessage(
+                 stream="stream1",
+                 data=records_stream_2_partition_1[1],
+                 emitted_at=1577836800000,
+             ),
+         ),
+         AirbyteMessage(
+             type=MessageType.RECORD,
+             record=AirbyteRecordMessage(
+                 stream="stream1",
+                 data=records_stream_2_partition_2[0],
+                 emitted_at=1577836800000,
+             ),
+         ),
+         AirbyteMessage(
+             type=MessageType.RECORD,
+             record=AirbyteRecordMessage(
+                 stream="stream1",
+                 data=records_stream_2_partition_2[1],
+                 emitted_at=1577836800000,
+             ),
+         ),
+         AirbyteMessage(
+             type=MessageType.TRACE,
+             trace=AirbyteTraceMessage(
+                 type=TraceType.STREAM_STATUS,
+                 emitted_at=1577836800000.0,
+                 error=None,
+                 estimate=None,
+                 stream_status=AirbyteStreamStatusTraceMessage(
+                     stream_descriptor=StreamDescriptor(name="stream1"), status=AirbyteStreamStatus(AirbyteStreamStatus.COMPLETE)
+                 ),
+             ),
+         ),
+     ]
+     _verify_messages(expected_messages, messages_from_abstract_source, messages_from_concurrent_source)
+
+
+ @freezegun.freeze_time("2020-01-01T00:00:00")
+ def test_concurrent_source_yields_the_same_messages_as_abstract_source_when_a_traced_exception_is_raised():
+     records = [{"id": 1, "partition": "1"}, AirbyteTracedException()]
+     stream_slice_to_partition = {"1": records}
+
+     logger = _init_logger()
+     state = None
+     source, concurrent_source = _init_sources([stream_slice_to_partition], state, logger)
+     config = {}
+     catalog = _create_configured_catalog(source._streams)
+     messages_from_abstract_source = _read_from_source(source, logger, config, catalog, state, AirbyteTracedException)
+     messages_from_concurrent_source = _read_from_source(concurrent_source, logger, config, catalog, state, AirbyteTracedException)
+
+     expected_messages = [
+         AirbyteMessage(
+             type=MessageType.TRACE,
+             trace=AirbyteTraceMessage(
+                 type=TraceType.STREAM_STATUS,
+                 emitted_at=1577836800000.0,
+                 error=None,
+                 estimate=None,
+                 stream_status=AirbyteStreamStatusTraceMessage(
+                     stream_descriptor=StreamDescriptor(name="stream0"), status=AirbyteStreamStatus(AirbyteStreamStatus.STARTED)
+                 ),
+             ),
+         ),
+         AirbyteMessage(
+             type=MessageType.TRACE,
+             trace=AirbyteTraceMessage(
+                 type=TraceType.STREAM_STATUS,
+                 emitted_at=1577836800000.0,
+                 error=None,
+                 estimate=None,
+                 stream_status=AirbyteStreamStatusTraceMessage(
+                     stream_descriptor=StreamDescriptor(name="stream0"), status=AirbyteStreamStatus(AirbyteStreamStatus.RUNNING)
+                 ),
+             ),
+         ),
+         AirbyteMessage(
+             type=MessageType.RECORD,
+             record=AirbyteRecordMessage(
+                 stream="stream0",
+                 data=records[0],
+                 emitted_at=1577836800000,
+             ),
+         ),
+         AirbyteMessage(
+             type=MessageType.TRACE,
+             trace=AirbyteTraceMessage(
+                 type=TraceType.STREAM_STATUS,
+                 emitted_at=1577836800000.0,
+                 error=None,
+                 estimate=None,
+                 stream_status=AirbyteStreamStatusTraceMessage(
+                     stream_descriptor=StreamDescriptor(name="stream0"), status=AirbyteStreamStatus(AirbyteStreamStatus.INCOMPLETE)
+                 ),
+             ),
+         ),
+     ]
+     _verify_messages(expected_messages, messages_from_abstract_source, messages_from_concurrent_source)
+
+
+ @freezegun.freeze_time("2020-01-01T00:00:00")
+ def test_concurrent_source_yields_the_same_messages_as_abstract_source_when_an_exception_is_raised():
+     records = [{"id": 1, "partition": "1"}, RuntimeError()]
+     stream_slice_to_partition = {"1": records}
+     logger = _init_logger()
+
+     state = None
+
+     source, concurrent_source = _init_sources([stream_slice_to_partition], state, logger)
+     config = {}
+     catalog = _create_configured_catalog(source._streams)
+     messages_from_abstract_source = _read_from_source(source, logger, config, catalog, state, RuntimeError)
+     messages_from_concurrent_source = _read_from_source(concurrent_source, logger, config, catalog, state, RuntimeError)
+
+     expected_messages = [
+         AirbyteMessage(
+             type=MessageType.TRACE,
+             trace=AirbyteTraceMessage(
+                 type=TraceType.STREAM_STATUS,
+                 emitted_at=1577836800000.0,
+                 error=None,
+                 estimate=None,
+                 stream_status=AirbyteStreamStatusTraceMessage(
+                     stream_descriptor=StreamDescriptor(name="stream0"), status=AirbyteStreamStatus(AirbyteStreamStatus.STARTED)
+                 ),
+             ),
+         ),
+         AirbyteMessage(
+             type=MessageType.TRACE,
+             trace=AirbyteTraceMessage(
+                 type=TraceType.STREAM_STATUS,
+                 emitted_at=1577836800000.0,
+                 error=None,
+                 estimate=None,
+                 stream_status=AirbyteStreamStatusTraceMessage(
+                     stream_descriptor=StreamDescriptor(name="stream0"), status=AirbyteStreamStatus(AirbyteStreamStatus.RUNNING)
+                 ),
+             ),
+         ),
+         AirbyteMessage(
+             type=MessageType.RECORD,
+             record=AirbyteRecordMessage(
+                 stream="stream0",
+                 data=records[0],
+                 emitted_at=1577836800000,
+             ),
+         ),
+         AirbyteMessage(
+             type=MessageType.TRACE,
+             trace=AirbyteTraceMessage(
+                 type=TraceType.STREAM_STATUS,
+                 emitted_at=1577836800000.0,
+                 error=None,
+                 estimate=None,
+                 stream_status=AirbyteStreamStatusTraceMessage(
+                     stream_descriptor=StreamDescriptor(name="stream0"), status=AirbyteStreamStatus(AirbyteStreamStatus.INCOMPLETE)
+                 ),
+             ),
+         ),
+     ]
+     _verify_messages(expected_messages, messages_from_abstract_source, messages_from_concurrent_source)
+
+
+ def _init_logger():
+     logger = Mock()
+     logger.level = logging.INFO
+     logger.isEnabledFor.return_value = False
+     return logger
+
+
+ def _init_sources(stream_slice_to_partitions, state, logger):
+     source = _init_source(stream_slice_to_partitions, state, logger, _MockSource())
+     concurrent_source = _init_source(stream_slice_to_partitions, state, logger, _MockConcurrentSource(logger))
+     return source, concurrent_source
+
+
+ def _init_source(stream_slice_to_partitions, state, logger, source):
+     cursor = NoopCursor()
+     streams = [
+         StreamFacade.create_from_stream(_MockStream(stream_slices, f"stream{i}"), source, logger, state, cursor)
+         for i, stream_slices in enumerate(stream_slice_to_partitions)
+     ]
+     source.set_streams(streams)
+     return source
+
+
+ def _create_configured_catalog(streams):
+     return ConfiguredAirbyteCatalog(
+         streams=[
+             ConfiguredAirbyteStream(
+                 stream=AirbyteStream(name=s.name, json_schema={}, supported_sync_modes=[SyncMode.full_refresh]),
+                 sync_mode=SyncMode.full_refresh,
+                 cursor_field=None,
+                 destination_sync_mode=DestinationSyncMode.overwrite,
+             )
+             for s in streams
+         ]
+     )
+
+
+ def _read_from_source(source, logger, config, catalog, state, expected_exception):
+     messages = []
+     try:
+         for m in source.read(logger, config, catalog, state):
+             messages.append(m)
+     except Exception as e:
+         if expected_exception:
+             assert isinstance(e, expected_exception)
+     return messages
+
+
+ def _verify_messages(expected_messages, messages_from_abstract_source, messages_from_concurrent_source):
+     assert _compare(expected_messages, messages_from_concurrent_source)
+
+
+ def _compare(s, t):
+     # Use a compare method that does not require ordering or hashing the elements
+     # We can't rely on the ordering because of the multithreading
+     # AirbyteMessage does not implement __eq__ and __hash__
+     t = list(t)
+     try:
+         for elem in s:
+             t.remove(elem)
+     except ValueError:
+         print(f"ValueError: {elem}")
+         return False
+     return not t
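
Note on _compare above: it implements order-insensitive multiset equality without hashing, because record ordering is nondeterministic under concurrency and AirbyteMessage is not hashable. A quick illustration of its semantics with plain integers (hypothetical values, not part of the test):

_compare([1, 2, 2], [2, 1, 2])   # True: same elements, any order
_compare([1, 2], [1, 2, 2])      # False: t has a leftover element
_compare([1, 2, 2], [1, 2])      # False: the second 2 is missing from t, so remove() raises ValueError
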
--- a/airbyte_cdk/sources/streams/concurrent/thread_based_concurrent_stream.py
+++ /dev/null
@@ -1,221 +0,0 @@
- #
- # Copyright (c) 2023 Airbyte, Inc., all rights reserved.
- #
-
- import concurrent
- import time
- from concurrent.futures import Future
- from functools import lru_cache
- from logging import Logger
- from queue import Queue
- from typing import Any, Callable, Dict, Iterable, List, Mapping, Optional
-
- from airbyte_cdk.models import AirbyteStream, SyncMode
- from airbyte_cdk.sources.message import MessageRepository
- from airbyte_cdk.sources.streams.concurrent.abstract_stream import AbstractStream
- from airbyte_cdk.sources.streams.concurrent.availability_strategy import AbstractAvailabilityStrategy, StreamAvailability
- from airbyte_cdk.sources.streams.concurrent.cursor import Cursor, NoopCursor
- from airbyte_cdk.sources.streams.concurrent.partition_enqueuer import PartitionEnqueuer
- from airbyte_cdk.sources.streams.concurrent.partition_reader import PartitionReader
- from airbyte_cdk.sources.streams.concurrent.partitions.partition import Partition
- from airbyte_cdk.sources.streams.concurrent.partitions.partition_generator import PartitionGenerator
- from airbyte_cdk.sources.streams.concurrent.partitions.record import Record
- from airbyte_cdk.sources.streams.concurrent.partitions.types import PARTITIONS_GENERATED_SENTINEL, PartitionCompleteSentinel, QueueItem
- from airbyte_cdk.sources.utils.slice_logger import SliceLogger
-
-
- class ThreadBasedConcurrentStream(AbstractStream):
-
-     DEFAULT_TIMEOUT_SECONDS = 900
-     DEFAULT_MAX_QUEUE_SIZE = 10_000
-     DEFAULT_SLEEP_TIME = 0.1
-
-     def __init__(
-         self,
-         partition_generator: PartitionGenerator,
-         max_workers: int,
-         name: str,
-         json_schema: Mapping[str, Any],
-         availability_strategy: AbstractAvailabilityStrategy,
-         primary_key: List[str],
-         cursor_field: Optional[str],
-         slice_logger: SliceLogger,
-         logger: Logger,
-         message_repository: MessageRepository,
-         timeout_seconds: int = DEFAULT_TIMEOUT_SECONDS,
-         max_concurrent_tasks: int = DEFAULT_MAX_QUEUE_SIZE,
-         sleep_time: float = DEFAULT_SLEEP_TIME,
-         cursor: Cursor = NoopCursor(),
-         namespace: Optional[str] = None,
-     ):
-         self._stream_partition_generator = partition_generator
-         self._max_workers = max_workers
-         self._threadpool = concurrent.futures.ThreadPoolExecutor(max_workers=self._max_workers, thread_name_prefix="workerpool")
-         self._name = name
-         self._json_schema = json_schema
-         self._availability_strategy = availability_strategy
-         self._primary_key = primary_key
-         self._cursor_field = cursor_field
-         self._slice_logger = slice_logger
-         self._logger = logger
-         self._message_repository = message_repository
-         self._timeout_seconds = timeout_seconds
-         self._max_concurrent_tasks = max_concurrent_tasks
-         self._sleep_time = sleep_time
-         self._cursor = cursor
-         self._namespace = namespace
-
-     def read(self) -> Iterable[Record]:
-         """
-         Read all data from the stream (only full-refresh is supported at the moment)
-
-         Algorithm:
-         1. Submit a future to generate the stream's partitions to process.
-            - This has to be done asynchronously because we sometimes need to submit requests to the API to generate all partitions (eg for substreams).
-            - The future will add the partitions to process on a work queue.
-         2. Continuously poll work from the work queue until all partitions are generated and processed.
-            - If the next work item is an Exception, stop the threadpool and raise it.
-            - If the next work item is a partition, submit a future to process it.
-              - The future will add the records to emit on the work queue.
-              - Add the partition to the partitions_to_done dict so we know it needs to complete for the sync to succeed.
-            - If the next work item is a record, yield the record.
-            - If the next work item is PARTITIONS_GENERATED_SENTINEL, all the partitions were generated.
-            - If the next work item is a PartitionCompleteSentinel, a partition is done processing.
-              - Update the value in partitions_to_done to True so we know the partition is completed.
-         """
-         self._logger.debug(f"Processing stream slices for {self.name}")
-         futures: List[Future[Any]] = []
-         queue: Queue[QueueItem] = Queue()
-         partition_generator = PartitionEnqueuer(queue, PARTITIONS_GENERATED_SENTINEL)
-         partition_reader = PartitionReader(queue)
-
-         self._submit_task(futures, partition_generator.generate_partitions, self._stream_partition_generator)
-
-         # True -> partition is done
-         # False -> partition is not done
-         partitions_to_done: Dict[Partition, bool] = {}
-
-         finished_partitions = False
-         while record_or_partition_or_exception := queue.get(block=True, timeout=self._timeout_seconds):
-             if isinstance(record_or_partition_or_exception, Exception):
-                 # An exception was raised while processing the stream
-                 # Stop the threadpool and raise it
-                 self._stop_and_raise_exception(record_or_partition_or_exception)
-             elif record_or_partition_or_exception == PARTITIONS_GENERATED_SENTINEL:
-                 # All partitions were generated
-                 finished_partitions = True
-             elif isinstance(record_or_partition_or_exception, PartitionCompleteSentinel):
-                 # All records for a partition were generated
-                 if record_or_partition_or_exception.partition not in partitions_to_done:
-                     raise RuntimeError(
-                         f"Received sentinel for partition {record_or_partition_or_exception.partition} that was not in partitions. This is indicative of a bug in the CDK. Please contact support. partitions:\n{partitions_to_done}"
-                     )
-                 partitions_to_done[record_or_partition_or_exception.partition] = True
-                 self._cursor.close_partition(record_or_partition_or_exception.partition)
-             elif isinstance(record_or_partition_or_exception, Record):
-                 # Emit records
-                 yield record_or_partition_or_exception
-                 self._cursor.observe(record_or_partition_or_exception)
-             elif isinstance(record_or_partition_or_exception, Partition):
-                 # A new partition was generated and must be processed
-                 partitions_to_done[record_or_partition_or_exception] = False
-                 if self._slice_logger.should_log_slice_message(self._logger):
-                     self._message_repository.emit_message(
-                         self._slice_logger.create_slice_log_message(record_or_partition_or_exception.to_slice())
-                     )
-                 self._submit_task(futures, partition_reader.process_partition, record_or_partition_or_exception)
-             if finished_partitions and all(partitions_to_done.values()):
-                 # All partitions were generated and processed. We're done here
-                 break
-
-         self._check_for_errors(futures)
-
-     def _submit_task(self, futures: List[Future[Any]], function: Callable[..., Any], *args: Any) -> None:
-         # Submit a task to the threadpool, waiting if there are too many pending tasks
-         self._wait_while_too_many_pending_futures(futures)
-         futures.append(self._threadpool.submit(function, *args))
-
-     def _wait_while_too_many_pending_futures(self, futures: List[Future[Any]]) -> None:
-         # Wait until the number of pending tasks is < self._max_concurrent_tasks
-         while True:
-             self._prune_futures(futures)
-             if len(futures) < self._max_concurrent_tasks:
-                 break
-             self._logger.info("Main thread is sleeping because the task queue is full...")
-             time.sleep(self._sleep_time)
-
-     def _prune_futures(self, futures: List[Future[Any]]) -> None:
-         """
-         Take a list as input and remove the futures that are completed. If a future has an exception, it'll raise and kill the stream
-         operation.
-
-         Pruning this list safely relies on the assumption that only the main thread can modify the list of futures.
-         """
-         if len(futures) < self._max_concurrent_tasks:
-             return
-
-         for index in reversed(range(len(futures))):
-             future = futures[index]
-             optional_exception = future.exception()
-             if optional_exception:
-                 exception = RuntimeError(f"Failed reading from stream {self.name} with error: {optional_exception}")
-                 self._stop_and_raise_exception(exception)
-
-             if future.done():
-                 futures.pop(index)
-
-     def _check_for_errors(self, futures: List[Future[Any]]) -> None:
-         exceptions_from_futures = [f for f in [future.exception() for future in futures] if f is not None]
-         if exceptions_from_futures:
-             exception = RuntimeError(f"Failed reading from stream {self.name} with errors: {exceptions_from_futures}")
-             self._stop_and_raise_exception(exception)
-         else:
-             futures_not_done = [f for f in futures if not f.done()]
-             if futures_not_done:
-                 exception = RuntimeError(f"Failed reading from stream {self.name} with futures not done: {futures_not_done}")
-                 self._stop_and_raise_exception(exception)
-
-     def _stop_and_raise_exception(self, exception: BaseException) -> None:
-         self._threadpool.shutdown(wait=False, cancel_futures=True)
-         raise exception
-
-     @property
-     def name(self) -> str:
-         return self._name
-
-     def check_availability(self) -> StreamAvailability:
-         return self._availability_strategy.check_availability(self._logger)
-
-     @property
-     def cursor_field(self) -> Optional[str]:
-         return self._cursor_field
-
-     @lru_cache(maxsize=None)
-     def get_json_schema(self) -> Mapping[str, Any]:
-         return self._json_schema
-
-     def as_airbyte_stream(self) -> AirbyteStream:
-         stream = AirbyteStream(name=self.name, json_schema=dict(self._json_schema), supported_sync_modes=[SyncMode.full_refresh])
-
-         if self._namespace:
-             stream.namespace = self._namespace
-
-         if self._cursor_field:
-             stream.source_defined_cursor = True
-             stream.supported_sync_modes.append(SyncMode.incremental)
-             stream.default_cursor_field = [self._cursor_field]
-
-         keys = self._primary_key
-         if keys and len(keys) > 0:
-             stream.source_defined_primary_key = [keys]
-
-         return stream
-
-     def log_stream_sync_configuration(self) -> None:
-         self._logger.debug(
-             f"Syncing stream instance: {self.name}",
-             extra={
-                 "primary_key": self._primary_key,
-                 "cursor_field": self.cursor_field,
-             },
-         )
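
The read() loop deleted above is the pattern that the new concurrent_source modules (concurrent_read_processor.py and thread_pool_manager.py in the file list) generalize across streams: a thread pool feeds one queue with partitions, records, and completion sentinels, and the main thread drains it. A stripped-down, self-contained sketch of that producer/consumer shape, using only the standard library (toy types, not the CDK classes):

import queue
from concurrent.futures import ThreadPoolExecutor

PARTITIONS_GENERATED = object()  # sentinel: no more partitions will be enqueued


class PartitionComplete:
    def __init__(self, partition):
        self.partition = partition


def read(partitions, fetch_records):
    # One queue carries partitions, records, and sentinels, like QueueItem in the CDK.
    q = queue.Queue()
    pool = ThreadPoolExecutor(max_workers=2, thread_name_prefix="workerpool")

    def generate():
        for p in partitions:
            q.put(p)  # enqueue each partition for a worker to process
        q.put(PARTITIONS_GENERATED)

    def process(partition):
        for record in fetch_records(partition):
            q.put(record)  # records travel back on the same queue
        q.put(PartitionComplete(partition))

    pool.submit(generate)
    partition_done, generation_finished = {}, False
    while True:
        item = q.get(block=True, timeout=900)
        if item is PARTITIONS_GENERATED:
            generation_finished = True
        elif isinstance(item, PartitionComplete):
            partition_done[item.partition] = True
        elif isinstance(item, str):  # partitions are plain strings in this toy
            partition_done[item] = False
            pool.submit(process, item)
        else:
            yield item  # anything else is a record
        if generation_finished and all(partition_done.values()):
            break
    pool.shutdown()


# Two partitions with one record each; output order is nondeterministic.
data = {"p1": [{"id": 1}], "p2": [{"id": 2}]}
print(list(read(data, lambda p: data[p])))
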