airbyte-cdk: airbyte_cdk-0.52.6-py3-none-any.whl → airbyte_cdk-0.52.8-py3-none-any.whl

Sign up to get free protection for your applications and to get access to all the features.
Files changed (32)
  1. airbyte_cdk/destinations/vector_db_based/config.py +1 -0
  2. airbyte_cdk/sources/abstract_source.py +12 -61
  3. airbyte_cdk/sources/file_based/config/unstructured_format.py +1 -1
  4. airbyte_cdk/sources/file_based/file_types/unstructured_parser.py +1 -2
  5. airbyte_cdk/sources/message/repository.py +0 -6
  6. airbyte_cdk/sources/source.py +14 -13
  7. airbyte_cdk/sources/streams/concurrent/adapters.py +94 -21
  8. airbyte_cdk/sources/streams/concurrent/cursor.py +148 -0
  9. airbyte_cdk/sources/streams/concurrent/partition_enqueuer.py +2 -3
  10. airbyte_cdk/sources/streams/concurrent/partitions/partition.py +3 -0
  11. airbyte_cdk/sources/streams/concurrent/partitions/partition_generator.py +1 -3
  12. airbyte_cdk/sources/streams/concurrent/thread_based_concurrent_stream.py +7 -3
  13. airbyte_cdk/sources/streams/core.py +71 -1
  14. {airbyte_cdk-0.52.6.dist-info → airbyte_cdk-0.52.8.dist-info}/METADATA +3 -3
  15. {airbyte_cdk-0.52.6.dist-info → airbyte_cdk-0.52.8.dist-info}/RECORD +32 -30
  16. {airbyte_cdk-0.52.6.dist-info → airbyte_cdk-0.52.8.dist-info}/WHEEL +1 -1
  17. unit_tests/sources/file_based/file_types/test_unstructured_parser.py +5 -0
  18. unit_tests/sources/file_based/scenarios/csv_scenarios.py +1 -1
  19. unit_tests/sources/file_based/scenarios/unstructured_scenarios.py +16 -0
  20. unit_tests/sources/message/test_repository.py +7 -20
  21. unit_tests/sources/streams/concurrent/scenarios/stream_facade_builder.py +46 -5
  22. unit_tests/sources/streams/concurrent/scenarios/stream_facade_scenarios.py +154 -37
  23. unit_tests/sources/streams/concurrent/scenarios/test_concurrent_scenarios.py +6 -0
  24. unit_tests/sources/streams/concurrent/scenarios/thread_based_concurrent_stream_source_builder.py +19 -3
  25. unit_tests/sources/streams/concurrent/test_adapters.py +48 -22
  26. unit_tests/sources/streams/concurrent/test_concurrent_partition_generator.py +5 -4
  27. unit_tests/sources/streams/concurrent/test_cursor.py +130 -0
  28. unit_tests/sources/streams/concurrent/test_thread_based_concurrent_stream.py +14 -10
  29. unit_tests/sources/streams/test_stream_read.py +3 -1
  30. unit_tests/sources/test_abstract_source.py +12 -9
  31. {airbyte_cdk-0.52.6.dist-info → airbyte_cdk-0.52.8.dist-info}/LICENSE.txt +0 -0
  32. {airbyte_cdk-0.52.6.dist-info → airbyte_cdk-0.52.8.dist-info}/top_level.txt +0 -0
@@ -85,6 +85,7 @@ class ProcessingConfigModel(BaseModel):
85
85
  ...,
86
86
  title="Chunk size",
87
87
  maximum=8191,
88
+ minimum=1,
88
89
  description="Size of chunks in tokens to store in vector store (make sure it is not too big for the context if your LLM)",
89
90
  )
90
91
  chunk_overlap: int = Field(
@@ -225,56 +225,18 @@ class AbstractSource(Source, ABC):
225
225
 
226
226
  if stream_state and "state" in dir(stream_instance):
227
227
  stream_instance.state = stream_state # type: ignore # we check that state in the dir(stream_instance)
228
- logger.info(f"Setting state of {stream_name} stream to {stream_state}")
229
-
230
- slices = stream_instance.stream_slices(
231
- cursor_field=configured_stream.cursor_field,
232
- sync_mode=SyncMode.incremental,
233
- stream_state=stream_state,
234
- )
235
- logger.debug(f"Processing stream slices for {stream_name} (sync_mode: incremental)", extra={"stream_slices": slices})
236
-
237
- total_records_counter = 0
238
- has_slices = False
239
- for _slice in slices:
240
- has_slices = True
241
- if self._slice_logger.should_log_slice_message(logger):
242
- yield self._slice_logger.create_slice_log_message(_slice)
243
- records = stream_instance.read_records(
244
- sync_mode=SyncMode.incremental,
245
- stream_slice=_slice,
246
- stream_state=stream_state,
247
- cursor_field=configured_stream.cursor_field or None,
248
- )
249
- record_counter = 0
250
- for message_counter, record_data_or_message in enumerate(records, start=1):
251
- message = self._get_message(record_data_or_message, stream_instance)
252
- yield from self._emit_queued_messages()
253
- yield message
254
- if message.type == MessageType.RECORD:
255
- record = message.record
256
- stream_state = stream_instance.get_updated_state(stream_state, record.data)
257
- checkpoint_interval = stream_instance.state_checkpoint_interval
258
- record_counter += 1
259
- if checkpoint_interval and record_counter % checkpoint_interval == 0:
260
- yield self._checkpoint_state(stream_instance, stream_state, state_manager)
261
-
262
- total_records_counter += 1
263
- # This functionality should ideally live outside of this method
264
- # but since state is managed inside this method, we keep track
265
- # of it here.
266
- if internal_config.is_limit_reached(total_records_counter):
267
- # Break from slice loop to save state and exit from _read_incremental function.
268
- break
269
-
270
- yield self._checkpoint_state(stream_instance, stream_state, state_manager)
271
- if internal_config.is_limit_reached(total_records_counter):
272
- return
273
-
274
- if not has_slices:
275
- # Safety net to ensure we always emit at least one state message even if there are no slices
276
- checkpoint = self._checkpoint_state(stream_instance, stream_state, state_manager)
277
- yield checkpoint
228
+ logger.info(f"Setting state of {self.name} stream to {stream_state}")
229
+
230
+ for record_data_or_message in stream_instance.read_incremental(
231
+ configured_stream.cursor_field,
232
+ logger,
233
+ self._slice_logger,
234
+ stream_state,
235
+ state_manager,
236
+ self.per_stream_state_enabled,
237
+ internal_config,
238
+ ):
239
+ yield self._get_message(record_data_or_message, stream_instance)
278
240
 
279
241
  def _emit_queued_messages(self) -> Iterable[AirbyteMessage]:
280
242
  if self.message_repository:
@@ -297,17 +259,6 @@ class AbstractSource(Source, ABC):
297
259
  if internal_config.is_limit_reached(total_records_counter):
298
260
  return
299
261
 
300
- def _checkpoint_state(self, stream: Stream, stream_state: Mapping[str, Any], state_manager: ConnectorStateManager) -> AirbyteMessage:
301
- # First attempt to retrieve the current state using the stream's state property. We receive an AttributeError if the state
302
- # property is not implemented by the stream instance and as a fallback, use the stream_state retrieved from the stream
303
- # instance's deprecated get_updated_state() method.
304
- try:
305
- state_manager.update_state_for_stream(stream.name, stream.namespace, stream.state) # type: ignore # we know the field might not exist...
306
-
307
- except AttributeError:
308
- state_manager.update_state_for_stream(stream.name, stream.namespace, stream_state)
309
- return state_manager.create_state_message(stream.name, stream.namespace, send_per_stream_state=self.per_stream_state_enabled)
310
-
311
262
  @staticmethod
312
263
  def _apply_log_level_to_stream_logger(logger: logging.Logger, stream_instance: Stream) -> None:
313
264
  """
@@ -8,7 +8,7 @@ from pydantic import BaseModel, Field
8
8
  class UnstructuredFormat(BaseModel):
9
9
  class Config:
10
10
  title = "Document File Type Format (Experimental)"
11
- schema_extra = {"description": "Extract text from document formats (.pdf, .docx, .md) and emit as one record per file."}
11
+ schema_extra = {"description": "Extract text from document formats (.pdf, .docx, .md, .pptx) and emit as one record per file."}
12
12
 
13
13
  filetype: str = Field(
14
14
  "unstructured",
@@ -103,7 +103,6 @@ class UnstructuredParser(FileTypeParser):
103
103
  return self._render_markdown(elements)
104
104
 
105
105
  def _get_filetype(self, file: IOBase, file_name: str) -> Any:
106
-
107
106
  # set name to none, otherwise unstructured will try to get the modified date from the local file system
108
107
  if hasattr(file, "name"):
109
108
  file.name = None
@@ -114,7 +113,7 @@ class UnstructuredParser(FileTypeParser):
114
113
  )
115
114
 
116
115
  def _supported_file_types(self) -> List[Any]:
117
- return [FileType.MD, FileType.PDF, FileType.DOCX]
116
+ return [FileType.MD, FileType.PDF, FileType.DOCX, FileType.PPTX]
118
117
 
119
118
  def _render_markdown(self, elements: List[Any]) -> str:
120
119
  return "\n\n".join((self._convert_to_markdown(el) for el in elements))
@@ -75,12 +75,6 @@ class InMemoryMessageRepository(MessageRepository):
75
75
  self._log_level = log_level
76
76
 
77
77
  def emit_message(self, message: AirbyteMessage) -> None:
78
- """
79
- :param message: As of today, only AirbyteControlMessages are supported given that supporting other types of message will need more
80
- work and therefore this work has been postponed
81
- """
82
- if message.type not in _SUPPORTED_MESSAGE_TYPES:
83
- raise ValueError(f"As of today, only {_SUPPORTED_MESSAGE_TYPES} are supported as part of the InMemoryMessageRepository")
84
78
  self._message_queue.append(message)
85
79
 
86
80
  def log_message(self, level: Level, message_provider: Callable[[], LogMessage]) -> None:
@@ -6,7 +6,7 @@
6
6
  import logging
7
7
  from abc import ABC, abstractmethod
8
8
  from collections import defaultdict
9
- from typing import Any, Generic, Iterable, List, Mapping, MutableMapping, TypeVar, Union
9
+ from typing import Any, Dict, Generic, Iterable, List, Mapping, MutableMapping, Optional, TypeVar, Union
10
10
 
11
11
  from airbyte_cdk.connector import BaseConnector, DefaultConnectorMixin, TConfig
12
12
  from airbyte_cdk.models import AirbyteCatalog, AirbyteMessage, AirbyteStateMessage, AirbyteStateType, ConfiguredAirbyteCatalog
@@ -25,7 +25,7 @@ class BaseSource(BaseConnector[TConfig], ABC, Generic[TConfig, TState, TCatalog]
25
25
  ...
26
26
 
27
27
  @abstractmethod
28
- def read(self, logger: logging.Logger, config: TConfig, catalog: TCatalog, state: TState = None) -> Iterable[AirbyteMessage]:
28
+ def read(self, logger: logging.Logger, config: TConfig, catalog: TCatalog, state: Optional[TState] = None) -> Iterable[AirbyteMessage]:
29
29
  """
30
30
  Returns a generator of the AirbyteMessages generated by reading the source with the given configuration, catalog, and state.
31
31
  """
@@ -43,8 +43,9 @@ class Source(
43
43
  BaseSource[Mapping[str, Any], Union[List[AirbyteStateMessage], MutableMapping[str, Any]], ConfiguredAirbyteCatalog],
44
44
  ABC,
45
45
  ):
46
- # can be overridden to change an input state
47
- def read_state(self, state_path: str) -> Union[List[AirbyteStateMessage], MutableMapping[str, Any]]:
46
+ # can be overridden to change an input state.
47
+ @classmethod
48
+ def read_state(cls, state_path: str) -> Union[List[AirbyteStateMessage], MutableMapping[str, Any]]:
48
49
  """
49
50
  Retrieves the input state of a sync by reading from the specified JSON file. Incoming state can be deserialized into either
50
51
  a JSON object for legacy state input or as a list of AirbyteStateMessages for the per-stream state format. Regardless of the
@@ -53,30 +54,30 @@ class Source(
53
54
  :return: The complete stream state based on the connector's previous sync
54
55
  """
55
56
  if state_path:
56
- state_obj = self._read_json_file(state_path)
57
+ state_obj = BaseConnector._read_json_file(state_path)
57
58
  if not state_obj:
58
- return self._emit_legacy_state_format({})
59
- is_per_stream_state = isinstance(state_obj, List)
60
- if is_per_stream_state:
59
+ return cls._emit_legacy_state_format({})
60
+ if isinstance(state_obj, List):
61
61
  parsed_state_messages = []
62
- for state in state_obj:
62
+ for state in state_obj: # type: ignore # `isinstance(state_obj, List)` ensures that this is a list
63
63
  parsed_message = AirbyteStateMessage.parse_obj(state)
64
64
  if not parsed_message.stream and not parsed_message.data and not parsed_message.global_:
65
65
  raise ValueError("AirbyteStateMessage should contain either a stream, global, or state field")
66
66
  parsed_state_messages.append(parsed_message)
67
67
  return parsed_state_messages
68
68
  else:
69
- return self._emit_legacy_state_format(state_obj)
70
- return self._emit_legacy_state_format({})
69
+ return cls._emit_legacy_state_format(state_obj) # type: ignore # assuming it is a dict
70
+ return cls._emit_legacy_state_format({})
71
71
 
72
- def _emit_legacy_state_format(self, state_obj) -> Union[List[AirbyteStateMessage], MutableMapping[str, Any]]:
72
+ @classmethod
73
+ def _emit_legacy_state_format(cls, state_obj: Dict[str, Any]) -> Union[List[AirbyteStateMessage], MutableMapping[str, Any]]:
73
74
  """
74
75
  Existing connectors that override read() might not be able to interpret the new state format. We temporarily
75
76
  send state in the old format for these connectors, but once all have been upgraded, this method can be removed,
76
77
  and we can then emit state in the list format.
77
78
  """
78
79
  # vars(self.__class__) checks if the current class directly overrides the read() function
79
- if "read" in vars(self.__class__):
80
+ if "read" in vars(cls):
80
81
  return defaultdict(dict, state_obj)
81
82
  else:
82
83
  if state_obj:
@@ -6,10 +6,11 @@ import copy
6
6
  import json
7
7
  import logging
8
8
  from functools import lru_cache
9
- from typing import Any, Iterable, List, Mapping, Optional, Tuple, Union
9
+ from typing import Any, Iterable, List, Mapping, MutableMapping, Optional, Tuple, Union
10
10
 
11
11
  from airbyte_cdk.models import AirbyteStream, SyncMode
12
12
  from airbyte_cdk.sources import AbstractSource, Source
13
+ from airbyte_cdk.sources.connector_state_manager import ConnectorStateManager
13
14
  from airbyte_cdk.sources.message import MessageRepository
14
15
  from airbyte_cdk.sources.streams import Stream
15
16
  from airbyte_cdk.sources.streams.availability_strategy import AvailabilityStrategy
@@ -20,12 +21,14 @@ from airbyte_cdk.sources.streams.concurrent.availability_strategy import (
20
21
  StreamAvailable,
21
22
  StreamUnavailable,
22
23
  )
24
+ from airbyte_cdk.sources.streams.concurrent.cursor import Cursor, NoopCursor
23
25
  from airbyte_cdk.sources.streams.concurrent.exceptions import ExceptionWithDisplayMessage
24
26
  from airbyte_cdk.sources.streams.concurrent.partitions.partition import Partition
25
27
  from airbyte_cdk.sources.streams.concurrent.partitions.partition_generator import PartitionGenerator
26
28
  from airbyte_cdk.sources.streams.concurrent.partitions.record import Record
27
29
  from airbyte_cdk.sources.streams.concurrent.thread_based_concurrent_stream import ThreadBasedConcurrentStream
28
30
  from airbyte_cdk.sources.streams.core import StreamData
31
+ from airbyte_cdk.sources.utils.schema_helpers import InternalConfig
29
32
  from airbyte_cdk.sources.utils.slice_logger import SliceLogger
30
33
  from deprecated.classic import deprecated
31
34
 
@@ -44,7 +47,15 @@ class StreamFacade(Stream):
44
47
  """
45
48
 
46
49
  @classmethod
47
- def create_from_stream(cls, stream: Stream, source: AbstractSource, logger: logging.Logger, max_workers: int) -> Stream:
50
+ def create_from_stream(
51
+ cls,
52
+ stream: Stream,
53
+ source: AbstractSource,
54
+ logger: logging.Logger,
55
+ max_workers: int,
56
+ state: Optional[MutableMapping[str, Any]],
57
+ cursor: Cursor,
58
+ ) -> Stream:
48
59
  """
49
60
  Create a ConcurrentStream from a Stream object.
50
61
  :param source: The source
@@ -63,9 +74,16 @@ class StreamFacade(Stream):
63
74
  message_repository = source.message_repository
64
75
  return StreamFacade(
65
76
  ThreadBasedConcurrentStream(
66
- partition_generator=StreamPartitionGenerator(stream, message_repository),
77
+ partition_generator=StreamPartitionGenerator(
78
+ stream,
79
+ message_repository,
80
+ SyncMode.full_refresh if isinstance(cursor, NoopCursor) else SyncMode.incremental,
81
+ [cursor_field] if cursor_field is not None else None,
82
+ state,
83
+ ),
67
84
  max_workers=max_workers,
68
85
  name=stream.name,
86
+ namespace=stream.namespace,
69
87
  json_schema=stream.get_json_schema(),
70
88
  availability_strategy=StreamAvailabilityStrategy(stream, source),
71
89
  primary_key=pk,
@@ -73,9 +91,21 @@ class StreamFacade(Stream):
73
91
  slice_logger=source._slice_logger,
74
92
  message_repository=message_repository,
75
93
  logger=logger,
76
- )
94
+ cursor=cursor,
95
+ ),
96
+ stream,
97
+ cursor,
77
98
  )
78
99
 
100
+ @property
101
+ def state(self) -> MutableMapping[str, Any]:
102
+ raise NotImplementedError("This should not be called as part of the Concurrent CDK code. Please report the problem to Airbyte")
103
+
104
+ @state.setter
105
+ def state(self, value: Mapping[str, Any]) -> None:
106
+ if "state" in dir(self._legacy_stream):
107
+ self._legacy_stream.state = value # type: ignore # validating `state` is attribute of stream using `if` above
108
+
79
109
  @classmethod
80
110
  def _get_primary_key_from_stream(cls, stream_primary_key: Optional[Union[str, List[str], List[List[str]]]]) -> List[str]:
81
111
  if stream_primary_key is None:
@@ -102,11 +132,13 @@ class StreamFacade(Stream):
102
132
  else:
103
133
  return stream.cursor_field
104
134
 
105
- def __init__(self, stream: AbstractStream):
135
+ def __init__(self, stream: AbstractStream, legacy_stream: Stream, cursor: Cursor):
106
136
  """
107
137
  :param stream: The underlying AbstractStream
108
138
  """
109
139
  self._abstract_stream = stream
140
+ self._legacy_stream = legacy_stream
141
+ self._cursor = cursor
110
142
 
111
143
  def read_full_refresh(
112
144
  self,
@@ -121,8 +153,19 @@ class StreamFacade(Stream):
121
153
  :param slice_logger: (ignored)
122
154
  :return: Iterable of StreamData
123
155
  """
124
- for record in self._abstract_stream.read():
125
- yield record.data
156
+ yield from self._read_records()
157
+
158
+ def read_incremental(
159
+ self,
160
+ cursor_field: Optional[List[str]],
161
+ logger: logging.Logger,
162
+ slice_logger: SliceLogger,
163
+ stream_state: MutableMapping[str, Any],
164
+ state_manager: ConnectorStateManager,
165
+ per_stream_state_enabled: bool,
166
+ internal_config: InternalConfig,
167
+ ) -> Iterable[StreamData]:
168
+ yield from self._read_records()
126
169
 
127
170
  def read_records(
128
171
  self,
@@ -131,12 +174,11 @@ class StreamFacade(Stream):
131
174
  stream_slice: Optional[Mapping[str, Any]] = None,
132
175
  stream_state: Optional[Mapping[str, Any]] = None,
133
176
  ) -> Iterable[StreamData]:
134
- if sync_mode == SyncMode.full_refresh:
135
- for record in self._abstract_stream.read():
136
- yield record.data
137
- else:
138
- # Incremental reads are not supported
139
- raise NotImplementedError
177
+ yield from self._read_records()
178
+
179
+ def _read_records(self) -> Iterable[StreamData]:
180
+ for record in self._abstract_stream.read():
181
+ yield record.data
140
182
 
141
183
  @property
142
184
  def name(self) -> str:
@@ -165,8 +207,7 @@ class StreamFacade(Stream):
165
207
 
166
208
  @property
167
209
  def supports_incremental(self) -> bool:
168
- # Only full refresh is supported
169
- return False
210
+ return self._legacy_stream.supports_incremental
170
211
 
171
212
  def check_availability(self, logger: logging.Logger, source: Optional["Source"] = None) -> Tuple[bool, Optional[str]]:
172
213
  """
@@ -210,7 +251,15 @@ class StreamPartition(Partition):
210
251
  In the long-run, it would be preferable to update the connectors, but we don't have the tooling or need to justify the effort at this time.
211
252
  """
212
253
 
213
- def __init__(self, stream: Stream, _slice: Optional[Mapping[str, Any]], message_repository: MessageRepository):
254
+ def __init__(
255
+ self,
256
+ stream: Stream,
257
+ _slice: Optional[Mapping[str, Any]],
258
+ message_repository: MessageRepository,
259
+ sync_mode: SyncMode,
260
+ cursor_field: Optional[List[str]],
261
+ state: Optional[MutableMapping[str, Any]],
262
+ ):
214
263
  """
215
264
  :param stream: The stream to delegate to
216
265
  :param _slice: The partition's stream_slice
@@ -219,6 +268,9 @@ class StreamPartition(Partition):
219
268
  self._stream = stream
220
269
  self._slice = _slice
221
270
  self._message_repository = message_repository
271
+ self._sync_mode = sync_mode
272
+ self._cursor_field = cursor_field
273
+ self._state = state
222
274
 
223
275
  def read(self) -> Iterable[Record]:
224
276
  """
@@ -227,7 +279,18 @@ class StreamPartition(Partition):
227
279
  Otherwise, the message will be emitted on the message repository.
228
280
  """
229
281
  try:
230
- for record_data in self._stream.read_records(sync_mode=SyncMode.full_refresh, stream_slice=copy.deepcopy(self._slice)):
282
+ # using `stream_state=self._state` have a very different behavior than the current one as today the state is updated slice
283
+ # by slice incrementally. We don't have this guarantee with Concurrent CDK. For HttpStream, `stream_state` is passed to:
284
+ # * fetch_next_page
285
+ # * parse_response
286
+ # Both are not used for Stripe so we should be good for the first iteration of Concurrent CDK. However, Stripe still do
287
+ # `if not stream_state` to know if it calls the Event stream or not
288
+ for record_data in self._stream.read_records(
289
+ cursor_field=self._cursor_field,
290
+ sync_mode=SyncMode.full_refresh,
291
+ stream_slice=copy.deepcopy(self._slice),
292
+ stream_state=self._state,
293
+ ):
231
294
  if isinstance(record_data, Mapping):
232
295
  data_to_return = dict(record_data)
233
296
  self._stream.transformer.transform(data_to_return, self._stream.get_json_schema())
@@ -264,17 +327,27 @@ class StreamPartitionGenerator(PartitionGenerator):
264
327
  In the long-run, it would be preferable to update the connectors, but we don't have the tooling or need to justify the effort at this time.
265
328
  """
266
329
 
267
- def __init__(self, stream: Stream, message_repository: MessageRepository):
330
+ def __init__(
331
+ self,
332
+ stream: Stream,
333
+ message_repository: MessageRepository,
334
+ sync_mode: SyncMode,
335
+ cursor_field: Optional[List[str]],
336
+ state: Optional[MutableMapping[str, Any]],
337
+ ):
268
338
  """
269
339
  :param stream: The stream to delegate to
270
340
  :param message_repository: The message repository to use to emit non-record messages
271
341
  """
272
342
  self.message_repository = message_repository
273
343
  self._stream = stream
344
+ self._sync_mode = sync_mode
345
+ self._cursor_field = cursor_field
346
+ self._state = state
274
347
 
275
- def generate(self, sync_mode: SyncMode) -> Iterable[Partition]:
276
- for s in self._stream.stream_slices(sync_mode=sync_mode):
277
- yield StreamPartition(self._stream, copy.deepcopy(s), self.message_repository)
348
+ def generate(self) -> Iterable[Partition]:
349
+ for s in self._stream.stream_slices(sync_mode=self._sync_mode, cursor_field=self._cursor_field, stream_state=self._state):
350
+ yield StreamPartition(self._stream, copy.deepcopy(s), self.message_repository, self._sync_mode, self._cursor_field, self._state)
278
351
 
279
352
 
280
353
  @deprecated("This class is experimental. Use at your own risk.")
@@ -0,0 +1,148 @@
1
+ import functools
2
+ from abc import ABC, abstractmethod
3
+ from typing import Any, List, Mapping, Optional, Protocol, Tuple
4
+
5
+ from airbyte_cdk.sources.connector_state_manager import ConnectorStateManager
6
+ from airbyte_cdk.sources.message import MessageRepository
7
+ from airbyte_cdk.sources.streams.concurrent.partitions.partition import Partition
8
+ from airbyte_cdk.sources.streams.concurrent.partitions.record import Record
9
+
10
+
11
+ def _extract_value(mapping: Mapping[str, Any], path: List[str]) -> Any:
12
+ return functools.reduce(lambda a, b: a[b], path, mapping)
13
+
14
+
15
+ class Comparable(Protocol):
16
+ """Protocol for annotating comparable types."""
17
+
18
+ @abstractmethod
19
+ def __lt__(self: "Comparable", other: "Comparable") -> bool:
20
+ pass
21
+
22
+
23
+ class CursorField:
24
+ def __init__(self, cursor_field_key: str) -> None:
25
+ self._cursor_field_key = cursor_field_key
26
+
27
+ def extract_value(self, record: Record) -> Comparable:
28
+ cursor_value = record.data.get(self._cursor_field_key)
29
+ if cursor_value is None:
30
+ raise ValueError(f"Could not find cursor field {self._cursor_field_key} in record")
31
+ return cursor_value # type: ignore # we assume that the value the path points at is a comparable
32
+
33
+
34
+ class Cursor(ABC):
35
+ @abstractmethod
36
+ def observe(self, record: Record) -> None:
37
+ """
38
+ Indicate to the cursor that the record has been emitted
39
+ """
40
+ raise NotImplementedError()
41
+
42
+ @abstractmethod
43
+ def close_partition(self, partition: Partition) -> None:
44
+ """
45
+ Indicate to the cursor that the partition has been successfully processed
46
+ """
47
+ raise NotImplementedError()
48
+
49
+
50
+ class NoopCursor(Cursor):
51
+ def observe(self, record: Record) -> None:
52
+ pass
53
+
54
+ def close_partition(self, partition: Partition) -> None:
55
+ pass
56
+
57
+
58
+ class ConcurrentCursor(Cursor):
59
+ _START_BOUNDARY = 0
60
+ _END_BOUNDARY = 1
61
+
62
+ def __init__(
63
+ self,
64
+ stream_name: str,
65
+ stream_namespace: Optional[str],
66
+ stream_state: Any,
67
+ message_repository: MessageRepository,
68
+ connector_state_manager: ConnectorStateManager,
69
+ cursor_field: CursorField,
70
+ slice_boundary_fields: Optional[Tuple[str, str]],
71
+ ) -> None:
72
+ self._stream_name = stream_name
73
+ self._stream_namespace = stream_namespace
74
+ self._message_repository = message_repository
75
+ self._connector_state_manager = connector_state_manager
76
+ self._cursor_field = cursor_field
77
+ # To see some example where the slice boundaries might not be defined, check https://github.com/airbytehq/airbyte/blob/1ce84d6396e446e1ac2377362446e3fb94509461/airbyte-integrations/connectors/source-stripe/source_stripe/streams.py#L363-L379
78
+ self._slice_boundary_fields = slice_boundary_fields if slice_boundary_fields else tuple()
79
+ self._most_recent_record: Optional[Record] = None
80
+ self._has_closed_at_least_one_slice = False
81
+
82
+ # TODO to migrate state. The migration should probably be outside of this class. Impact of not having this:
83
+ # * Given a sync that emits no records, the emitted state message will be empty
84
+ self._state = { # type: ignore # waiting for when we implement the state migration
85
+ "slices": [] # empty for now but should look like `{start: 1, end: 10, parent_id: "id1"}`
86
+ }
87
+
88
+ def observe(self, record: Record) -> None:
89
+ if self._slice_boundary_fields:
90
+ # Given that slicing is done using the cursor field, we don't need to observe the record as we assume slices will describe what
91
+ # has been emitted. Assuming there is a chance that records might not be yet populated for the most recent slice, use a lookback
92
+ # window
93
+ return
94
+
95
+ if not self._most_recent_record or self._extract_cursor_value(self._most_recent_record) < self._extract_cursor_value(record):
96
+ self._most_recent_record = record
97
+
98
+ def _extract_cursor_value(self, record: Record) -> Comparable:
99
+ return self._cursor_field.extract_value(record)
100
+
101
+ def close_partition(self, partition: Partition) -> None:
102
+ slice_count_before = len(self._state["slices"])
103
+ self._add_slice_to_state(partition)
104
+ if slice_count_before < len(self._state["slices"]):
105
+ self._merge_partitions()
106
+ self._emit_state_message()
107
+ self._has_closed_at_least_one_slice = True
108
+
109
+ def _add_slice_to_state(self, partition: Partition) -> None:
110
+ if self._slice_boundary_fields:
111
+ self._state["slices"].append(
112
+ {
113
+ "start": self._extract_from_slice(partition, self._slice_boundary_fields[self._START_BOUNDARY]),
114
+ "end": self._extract_from_slice(partition, self._slice_boundary_fields[self._END_BOUNDARY]),
115
+ }
116
+ )
117
+ elif self._most_recent_record:
118
+ if self._has_closed_at_least_one_slice:
119
+ raise ValueError(
120
+ "Given that slice_boundary_fields is not defined and that per-partition state is not supported, only one slice is "
121
+ "expected."
122
+ )
123
+
124
+ self._state["slices"].append(
125
+ {
126
+ "start": 0, # FIXME this only works with int datetime
127
+ "end": self._extract_cursor_value(self._most_recent_record),
128
+ }
129
+ )
130
+
131
+ def _emit_state_message(self) -> None:
132
+ self._connector_state_manager.update_state_for_stream(self._stream_name, self._stream_namespace, self._state)
133
+ state_message = self._connector_state_manager.create_state_message(
134
+ self._stream_name, self._stream_namespace, send_per_stream_state=True
135
+ )
136
+ self._message_repository.emit_message(state_message)
137
+
138
+ def _merge_partitions(self) -> None:
139
+ pass # TODO eventually
140
+
141
+ def _extract_from_slice(self, partition: Partition, key: str) -> Comparable:
142
+ try:
143
+ _slice = partition.to_slice()
144
+ if not _slice:
145
+ raise KeyError(f"Could not find key `{key}` in empty slice")
146
+ return _slice[key] # type: ignore # we expect the devs to specify a key that would return a Comparable
147
+ except KeyError as exception:
148
+ raise KeyError(f"Partition is expected to have key `{key}` but could not be found") from exception
@@ -4,7 +4,6 @@
4
4
 
5
5
  from queue import Queue
6
6
 
7
- from airbyte_cdk.models import SyncMode
8
7
  from airbyte_cdk.sources.streams.concurrent.partitions.partition_generator import PartitionGenerator
9
8
  from airbyte_cdk.sources.streams.concurrent.partitions.types import PARTITIONS_GENERATED_SENTINEL, QueueItem
10
9
 
@@ -22,7 +21,7 @@ class PartitionEnqueuer:
22
21
  self._queue = queue
23
22
  self._sentinel = sentinel
24
23
 
25
- def generate_partitions(self, partition_generator: PartitionGenerator, sync_mode: SyncMode) -> None:
24
+ def generate_partitions(self, partition_generator: PartitionGenerator) -> None:
26
25
  """
27
26
  Generate partitions from a partition generator and put them in a queue.
28
27
  When all the partitions are added to the queue, a sentinel is added to the queue to indicate that all the partitions have been generated.
@@ -35,7 +34,7 @@ class PartitionEnqueuer:
35
34
  :return:
36
35
  """
37
36
  try:
38
- for partition in partition_generator.generate(sync_mode=sync_mode):
37
+ for partition in partition_generator.generate():
39
38
  self._queue.put(partition)
40
39
  self._queue.put(self._sentinel)
41
40
  except Exception as e:
@@ -25,6 +25,9 @@ class Partition(ABC):
25
25
  def to_slice(self) -> Optional[Mapping[str, Any]]:
26
26
  """
27
27
  Converts the partition to a slice that can be serialized and deserialized.
28
+
29
+ Note: it would have been interesting to have a type of `Mapping[str, Comparable]` to simplify typing but some slices can have nested
30
+ values ([example](https://github.com/airbytehq/airbyte/blob/1ce84d6396e446e1ac2377362446e3fb94509461/airbyte-integrations/connectors/source-stripe/source_stripe/streams.py#L584-L596))
28
31
  :return: A mapping representing a slice
29
32
  """
30
33
  pass
@@ -5,16 +5,14 @@
5
5
  from abc import ABC, abstractmethod
6
6
  from typing import Iterable
7
7
 
8
- from airbyte_cdk.models import SyncMode
9
8
  from airbyte_cdk.sources.streams.concurrent.partitions.partition import Partition
10
9
 
11
10
 
12
11
  class PartitionGenerator(ABC):
13
12
  @abstractmethod
14
- def generate(self, sync_mode: SyncMode) -> Iterable[Partition]:
13
+ def generate(self) -> Iterable[Partition]:
15
14
  """
16
15
  Generates partitions for a given sync mode.
17
- :param sync_mode: SyncMode
18
16
  :return: An iterable of partitions
19
17
  """
20
18
  pass