airbyte-cdk 6.21.1.dev0__py3-none-any.whl → 6.26.0.dev4103__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (39):
  1. airbyte_cdk/cli/source_declarative_manifest/_run.py +6 -0
  2. airbyte_cdk/connector_builder/connector_builder_handler.py +1 -0
  3. airbyte_cdk/sources/declarative/auth/oauth.py +68 -11
  4. airbyte_cdk/sources/declarative/concurrent_declarative_source.py +81 -16
  5. airbyte_cdk/sources/declarative/declarative_component_schema.yaml +58 -2
  6. airbyte_cdk/sources/declarative/decoders/__init__.py +9 -1
  7. airbyte_cdk/sources/declarative/decoders/zipfile_decoder.py +59 -0
  8. airbyte_cdk/sources/declarative/extractors/record_filter.py +3 -5
  9. airbyte_cdk/sources/declarative/incremental/__init__.py +6 -0
  10. airbyte_cdk/sources/declarative/incremental/concurrent_partition_cursor.py +334 -0
  11. airbyte_cdk/sources/declarative/incremental/global_substream_cursor.py +3 -0
  12. airbyte_cdk/sources/declarative/incremental/per_partition_cursor.py +35 -3
  13. airbyte_cdk/sources/declarative/manifest_declarative_source.py +15 -4
  14. airbyte_cdk/sources/declarative/models/declarative_component_schema.py +50 -14
  15. airbyte_cdk/sources/declarative/parsers/custom_code_compiler.py +143 -0
  16. airbyte_cdk/sources/declarative/parsers/model_to_component_factory.py +220 -22
  17. airbyte_cdk/sources/declarative/partition_routers/substream_partition_router.py +6 -2
  18. airbyte_cdk/sources/declarative/requesters/error_handlers/composite_error_handler.py +22 -0
  19. airbyte_cdk/sources/declarative/retrievers/simple_retriever.py +1 -1
  20. airbyte_cdk/sources/file_based/config/abstract_file_based_spec.py +15 -0
  21. airbyte_cdk/sources/file_based/config/identities_based_stream_config.py +8 -0
  22. airbyte_cdk/sources/file_based/config/permissions.py +34 -0
  23. airbyte_cdk/sources/file_based/file_based_source.py +65 -1
  24. airbyte_cdk/sources/file_based/file_based_stream_reader.py +33 -0
  25. airbyte_cdk/sources/file_based/schema_helpers.py +25 -0
  26. airbyte_cdk/sources/file_based/stream/__init__.py +2 -1
  27. airbyte_cdk/sources/file_based/stream/default_file_based_stream.py +29 -0
  28. airbyte_cdk/sources/file_based/stream/identities_stream.py +99 -0
  29. airbyte_cdk/sources/http_logger.py +1 -1
  30. airbyte_cdk/sources/streams/concurrent/clamping.py +99 -0
  31. airbyte_cdk/sources/streams/concurrent/cursor.py +51 -57
  32. airbyte_cdk/sources/streams/concurrent/cursor_types.py +32 -0
  33. airbyte_cdk/sources/streams/http/requests_native_auth/oauth.py +20 -20
  34. airbyte_cdk/test/utils/manifest_only_fixtures.py +1 -2
  35. {airbyte_cdk-6.21.1.dev0.dist-info → airbyte_cdk-6.26.0.dev4103.dist-info}/METADATA +3 -3
  36. {airbyte_cdk-6.21.1.dev0.dist-info → airbyte_cdk-6.26.0.dev4103.dist-info}/RECORD +39 -31
  37. {airbyte_cdk-6.21.1.dev0.dist-info → airbyte_cdk-6.26.0.dev4103.dist-info}/LICENSE.txt +0 -0
  38. {airbyte_cdk-6.21.1.dev0.dist-info → airbyte_cdk-6.26.0.dev4103.dist-info}/WHEEL +0 -0
  39. {airbyte_cdk-6.21.1.dev0.dist-info → airbyte_cdk-6.26.0.dev4103.dist-info}/entry_points.txt +0 -0
@@ -33,6 +33,9 @@ from airbyte_cdk.sources.file_based.config.file_based_stream_config import (
33
33
  FileBasedStreamConfig,
34
34
  ValidationPolicy,
35
35
  )
36
+ from airbyte_cdk.sources.file_based.config.identities_based_stream_config import (
37
+ IdentitiesStreamConfig,
38
+ )
36
39
  from airbyte_cdk.sources.file_based.discovery_policy import (
37
40
  AbstractDiscoveryPolicy,
38
41
  DefaultDiscoveryPolicy,
@@ -49,7 +52,11 @@ from airbyte_cdk.sources.file_based.schema_validation_policies import (
49
52
  DEFAULT_SCHEMA_VALIDATION_POLICIES,
50
53
  AbstractSchemaValidationPolicy,
51
54
  )
52
- from airbyte_cdk.sources.file_based.stream import AbstractFileBasedStream, DefaultFileBasedStream
55
+ from airbyte_cdk.sources.file_based.stream import (
56
+ AbstractFileBasedStream,
57
+ DefaultFileBasedStream,
58
+ IdentitiesStream,
59
+ )
53
60
  from airbyte_cdk.sources.file_based.stream.concurrent.adapters import FileBasedStreamFacade
54
61
  from airbyte_cdk.sources.file_based.stream.concurrent.cursor import (
55
62
  AbstractConcurrentFileBasedCursor,
@@ -157,6 +164,9 @@ class FileBasedSource(ConcurrentSourceAdapter, ABC):
157
164
  errors = []
158
165
  tracebacks = []
159
166
  for stream in streams:
167
+ if isinstance(stream, IdentitiesStream):
168
+ # Probably need to check identities endpoint/api access but will skip for now.
169
+ continue
160
170
  if not isinstance(stream, AbstractFileBasedStream):
161
171
  raise ValueError(f"Stream {stream} is not a file-based stream.")
162
172
  try:
@@ -164,6 +174,7 @@ class FileBasedSource(ConcurrentSourceAdapter, ABC):
164
174
  availability_method = (
165
175
  stream.availability_strategy.check_availability
166
176
  if self._use_file_transfer(parsed_config)
177
+ or self._sync_acl_permissions(parsed_config)
167
178
  else stream.availability_strategy.check_availability_and_parsability
168
179
  )
169
180
  (
@@ -289,6 +300,13 @@ class FileBasedSource(ConcurrentSourceAdapter, ABC):
289
300
  )
290
301
 
291
302
  streams.append(stream)
303
+
304
+ identities_stream_config = self._get_identities_stream_config(parsed_config)
305
+ if identities_stream_config:
306
+ identities_stream = self._make_identities_stream(
307
+ stream_config=identities_stream_config
308
+ )
309
+ streams.append(identities_stream)
292
310
  return streams
293
311
 
294
312
  except ValidationError as exc:
@@ -312,6 +330,19 @@ class FileBasedSource(ConcurrentSourceAdapter, ABC):
312
330
  cursor=cursor,
313
331
  use_file_transfer=self._use_file_transfer(parsed_config),
314
332
  preserve_directory_structure=self._preserve_directory_structure(parsed_config),
333
+ sync_acl_permissions=self._sync_acl_permissions(parsed_config),
334
+ )
335
+
336
def _make_identities_stream(
    self,
    stream_config: IdentitiesStreamConfig,
) -> Stream:
    """
    Build the identities stream described by ``stream_config``.

    The stream shares this source's stream reader, discovery policy and errors
    collector, and receives the configured-catalog schema registered under the
    stream's name (``None`` when the catalog has no entry for it).
    """
    schema_from_catalog = self.stream_schemas.get(stream_config.name)
    return IdentitiesStream(
        config=stream_config,
        catalog_schema=schema_from_catalog,
        stream_reader=self.stream_reader,
        discovery_policy=self.discovery_policy,
        errors_collector=self.errors_collector,
    )
316
347
 
317
348
  def _get_stream_from_catalog(
@@ -387,6 +418,14 @@ class FileBasedSource(ConcurrentSourceAdapter, ABC):
387
418
  )
388
419
  return use_file_transfer
389
420
 
421
@staticmethod
def _use_records_transfer(parsed_config: AbstractFileBasedSpec) -> bool:
    """
    Return True when the configured delivery method is ``use_records_transfer``.

    A delivery method that has no ``delivery_type`` attribute counts as "not
    records transfer".
    """
    delivery_type = getattr(parsed_config.delivery_method, "delivery_type", None)
    return delivery_type == "use_records_transfer"
428
+
390
429
  @staticmethod
391
430
  def _preserve_directory_structure(parsed_config: AbstractFileBasedSpec) -> bool:
392
431
  """
@@ -408,3 +447,28 @@ class FileBasedSource(ConcurrentSourceAdapter, ABC):
408
447
  ):
409
448
  return parsed_config.delivery_method.preserve_directory_structure
410
449
  return True
450
+
451
@staticmethod
def _sync_acl_permissions(parsed_config: AbstractFileBasedSpec) -> bool:
    """
    Return the delivery method's ``sync_acl_permissions`` flag.

    The flag only applies to the records-transfer delivery type. Any other
    delivery type, a missing attribute, or an explicit ``None`` yields False.
    """
    if not FileBasedSource._use_records_transfer(parsed_config):
        return False
    acl_flag = getattr(parsed_config.delivery_method, "sync_acl_permissions", None)
    return False if acl_flag is None else acl_flag
460
+
461
@staticmethod
def _get_identities_stream_config(
    parsed_config: AbstractFileBasedSpec,
) -> Optional[IdentitiesStreamConfig]:
    """
    Return the identities stream config to sync, or None when there is none.

    A config is returned only when ACL permissions are being synced and the
    delivery method carries an ``IdentitiesStreamConfig`` with a truthy ``domain``.
    """
    if not FileBasedSource._sync_acl_permissions(parsed_config):
        return None
    identities = getattr(parsed_config.delivery_method, "identities", None)
    # isinstance() is False for None, so a missing/None attribute falls through.
    if isinstance(identities, IdentitiesStreamConfig) and identities.domain:
        return identities
    return None
@@ -135,6 +135,15 @@ class AbstractFileBasedStreamReader(ABC):
135
135
  return use_file_transfer
136
136
  return False
137
137
 
138
def use_records_transfer(self) -> bool:
    """
    Return True when the reader's config selects ``use_records_transfer`` delivery.

    A falsy ``self.config``, or a delivery method without a ``delivery_type``
    attribute, yields False.
    """
    if not self.config:
        return False
    delivery_type = getattr(self.config.delivery_method, "delivery_type", None)
    return delivery_type == "use_records_transfer"
146
+
138
147
  def preserve_directory_structure(self) -> bool:
139
148
  # fall back to preserve subdirectories if config is not present or incomplete
140
149
  if (
@@ -146,6 +155,16 @@ class AbstractFileBasedStreamReader(ABC):
146
155
  return self.config.delivery_method.preserve_directory_structure
147
156
  return True
148
157
 
158
def sync_acl_permissions(self) -> bool:
    """
    Return the delivery method's ``sync_acl_permissions`` flag.

    The result is False when there is no config, when delivery is not
    records transfer, or when the flag is missing or ``None``.
    """
    if not (self.config and self.use_records_transfer()):
        return False
    acl_flag = getattr(self.config.delivery_method, "sync_acl_permissions", None)
    return False if acl_flag is None else acl_flag
167
+
149
168
  @abstractmethod
150
169
  def get_file(
151
170
  self, file: RemoteFile, local_directory: str, logger: logging.Logger
@@ -183,3 +202,17 @@ class AbstractFileBasedStreamReader(ABC):
183
202
  makedirs(path.dirname(local_file_path), exist_ok=True)
184
203
  absolute_file_path = path.abspath(local_file_path)
185
204
  return [file_relative_path, local_file_path, absolute_file_path]
205
+
206
def get_file_acl_permissions(self, file: "RemoteFile", logger: logging.Logger) -> Dict[str, Any]:
    """
    Return the ACL-permissions record for ``file``.

    Connectors that support syncing ACL permissions from files override this.
    The base implementation returns an empty dict.
    """
    return dict()
212
+
213
def load_identity_groups(self, logger: logging.Logger) -> Iterable[Dict[str, Any]]:
    """
    Yield identity-group records for the identities stream.

    Connectors that support syncing identities must override this. The base
    implementation yields nothing. Previously it yielded a single empty ``{}``,
    which surfaced as a spurious record with none of the identity schema's
    fields. It remains a generator, so callers iterate it the same way as an
    overriding implementation.
    """
    yield from ()
@@ -23,6 +23,31 @@ file_transfer_schema = {
23
23
  "properties": {"data": {"type": "object"}, "file": {"type": "object"}},
24
24
  }
25
25
 
26
# JSON schema for the per-file ACL permission records that DefaultFileBasedStream
# emits when ``sync_acl_permissions`` is enabled.
remote_file_permissions_schema = dict(
    type="object",
    properties=dict(
        id={"type": "string"},
        file_path={"type": "string"},
        allowed_identity_remote_ids={"type": "array", "items": {"type": "string"}},
        publicly_accessible={"type": "boolean"},
    ),
)

# JSON schema for the records emitted by the identities stream.
remote_file_identity_schema = dict(
    type="object",
    properties=dict(
        id={"type": "string"},
        remote_id={"type": "string"},
        parent_id={"type": ["null", "string"]},
        name={"type": ["null", "string"]},
        description={"type": ["null", "string"]},
        email_address={"type": ["null", "string"]},
        member_email_addresses={"type": ["null", "array"]},
        type={"type": "string"},
        modified_at={"type": "string"},
    ),
)
50
+
26
51
 
27
52
  @total_ordering
28
53
  class ComparableType(Enum):
@@ -1,4 +1,5 @@
1
1
  from airbyte_cdk.sources.file_based.stream.abstract_file_based_stream import AbstractFileBasedStream
2
2
  from airbyte_cdk.sources.file_based.stream.default_file_based_stream import DefaultFileBasedStream
3
+ from airbyte_cdk.sources.file_based.stream.identities_stream import IdentitiesStream
3
4
 
4
- __all__ = ["AbstractFileBasedStream", "DefaultFileBasedStream"]
5
+ __all__ = ["AbstractFileBasedStream", "DefaultFileBasedStream", "IdentitiesStream"]
@@ -29,6 +29,7 @@ from airbyte_cdk.sources.file_based.schema_helpers import (
29
29
  SchemaType,
30
30
  file_transfer_schema,
31
31
  merge_schemas,
32
+ remote_file_permissions_schema,
32
33
  schemaless_schema,
33
34
  )
34
35
  from airbyte_cdk.sources.file_based.stream import AbstractFileBasedStream
@@ -47,6 +48,7 @@ class DefaultFileBasedStream(AbstractFileBasedStream, IncrementalMixin):
47
48
 
48
49
  FILE_TRANSFER_KW = "use_file_transfer"
49
50
  PRESERVE_DIRECTORY_STRUCTURE_KW = "preserve_directory_structure"
51
+ SYNC_ACL_PERMISSIONS_KW = "sync_acl_permissions"
50
52
  FILES_KEY = "files"
51
53
  DATE_TIME_FORMAT = "%Y-%m-%dT%H:%M:%S.%fZ"
52
54
  ab_last_mod_col = "_ab_source_file_last_modified"
@@ -56,6 +58,7 @@ class DefaultFileBasedStream(AbstractFileBasedStream, IncrementalMixin):
56
58
  airbyte_columns = [ab_last_mod_col, ab_file_name_col]
57
59
  use_file_transfer = False
58
60
  preserve_directory_structure = True
61
+ sync_acl_permissions = False
59
62
 
60
63
  def __init__(self, **kwargs: Any):
61
64
  if self.FILE_TRANSFER_KW in kwargs:
@@ -64,6 +67,8 @@ class DefaultFileBasedStream(AbstractFileBasedStream, IncrementalMixin):
64
67
  self.preserve_directory_structure = kwargs.pop(
65
68
  self.PRESERVE_DIRECTORY_STRUCTURE_KW, True
66
69
  )
70
+ if self.SYNC_ACL_PERMISSIONS_KW in kwargs:
71
+ self.sync_acl_permissions = kwargs.pop(self.SYNC_ACL_PERMISSIONS_KW, False)
67
72
  super().__init__(**kwargs)
68
73
 
69
74
  @property
@@ -105,6 +110,8 @@ class DefaultFileBasedStream(AbstractFileBasedStream, IncrementalMixin):
105
110
  self.ab_file_name_col: {"type": "string"},
106
111
  },
107
112
  }
113
+ elif self.sync_acl_permissions:
114
+ return remote_file_permissions_schema
108
115
  else:
109
116
  return super()._filter_schema_invalid_properties(configured_catalog_json_schema)
110
117
 
@@ -187,6 +194,26 @@ class DefaultFileBasedStream(AbstractFileBasedStream, IncrementalMixin):
187
194
  yield stream_data_to_airbyte_message(
188
195
  self.name, record, is_file_transfer_message=True
189
196
  )
197
+ elif self.sync_acl_permissions:
198
+ try:
199
+ metadata_record = self.stream_reader.get_file_acl_permissions(
200
+ file, logger=self.logger
201
+ )
202
+ yield stream_data_to_airbyte_message(
203
+ self.name, metadata_record, is_file_transfer_message=False
204
+ )
205
+ except Exception as e:
206
+ self.logger.error(
207
+ f"Failed to retrieve metadata for file {file.uri}: {str(e)}"
208
+ )
209
+ yield AirbyteMessage(
210
+ type=MessageType.LOG,
211
+ log=AirbyteLogMessage(
212
+ level=Level.ERROR,
213
+ message=f"Error retrieving metadata: stream={self.name} file={file.uri}",
214
+ stack_trace=traceback.format_exc(),
215
+ ),
216
+ )
190
217
  else:
191
218
  for record in parser.parse_records(
192
219
  self.config, file, self.stream_reader, self.logger, schema
@@ -284,6 +311,8 @@ class DefaultFileBasedStream(AbstractFileBasedStream, IncrementalMixin):
284
311
  def _get_raw_json_schema(self) -> JsonSchema:
285
312
  if self.use_file_transfer:
286
313
  return file_transfer_schema
314
+ elif self.sync_acl_permissions:
315
+ return remote_file_permissions_schema
287
316
  elif self.config.input_schema:
288
317
  return self.config.get_input_schema() # type: ignore
289
318
  elif self.config.schemaless:
@@ -0,0 +1,99 @@
1
+ #
2
+ # Copyright (c) 2024 Airbyte, Inc., all rights reserved.
3
+ #
4
+
5
+ import traceback
6
+ from functools import cache
7
+ from typing import Any, Iterable, List, Mapping, MutableMapping, Optional
8
+
9
+ from airbyte_protocol_dataclasses.models import SyncMode
10
+
11
+ from airbyte_cdk.models import AirbyteLogMessage, AirbyteMessage, Level
12
+ from airbyte_cdk.models import Type as MessageType
13
+ from airbyte_cdk.sources.file_based.config.file_based_stream_config import PrimaryKeyType
14
+ from airbyte_cdk.sources.file_based.config.identities_based_stream_config import (
15
+ IdentitiesStreamConfig,
16
+ )
17
+ from airbyte_cdk.sources.file_based.discovery_policy import AbstractDiscoveryPolicy
18
+ from airbyte_cdk.sources.file_based.exceptions import FileBasedErrorsCollector, FileBasedSourceError
19
+ from airbyte_cdk.sources.file_based.file_based_stream_reader import AbstractFileBasedStreamReader
20
+ from airbyte_cdk.sources.file_based.schema_helpers import remote_file_identity_schema
21
+ from airbyte_cdk.sources.file_based.types import StreamSlice
22
+ from airbyte_cdk.sources.streams import Stream
23
+ from airbyte_cdk.sources.streams.checkpoint import Cursor
24
+ from airbyte_cdk.sources.streams.core import JsonSchema
25
+ from airbyte_cdk.sources.utils.record_helper import stream_data_to_airbyte_message
26
+ from airbyte_cdk.utils.traced_exception import AirbyteTracedException
27
+
28
+
29
class IdentitiesStream(Stream):
    """
    Full-refresh stream that syncs identities for a domain.

    The records come from ``stream_reader.load_identity_groups``, which the
    connector implements. This stream only wraps each record as an Airbyte
    record message and reports failures.
    """

    is_resumable = False

    def __init__(
        self,
        config: IdentitiesStreamConfig,
        catalog_schema: Optional[Mapping[str, Any]],
        stream_reader: AbstractFileBasedStreamReader,
        discovery_policy: AbstractDiscoveryPolicy,
        errors_collector: FileBasedErrorsCollector,
    ):
        super().__init__()
        self.config = config
        self.catalog_schema = catalog_schema
        self.stream_reader = stream_reader
        self._discovery_policy = discovery_policy
        self.errors_collector = errors_collector
        # State holder exposed via ``state``; read_records never advances it.
        self._cursor: MutableMapping[str, Any] = {}

    @property
    def state(self) -> MutableMapping[str, Any]:
        return self._cursor

    @state.setter
    def state(self, value: MutableMapping[str, Any]) -> None:
        """State setter, accept state serialized by state getter."""
        self._cursor = value

    @property
    def primary_key(self) -> PrimaryKeyType:
        return None

    def read_records(
        self,
        sync_mode: SyncMode,
        cursor_field: Optional[List[str]] = None,
        stream_slice: Optional[StreamSlice] = None,
        stream_state: Optional[Mapping[str, Any]] = None,
    ) -> Iterable[Mapping[str, Any] | AirbyteMessage]:
        """
        Yield one record message per identity group returned by the stream reader.

        An ``AirbyteTracedException`` is re-raised so that the whole sync stops.
        Any other exception ends the read and yields a single ERROR log message
        that carries the traceback.
        """
        try:
            identity_groups = self.stream_reader.load_identity_groups(logger=self.logger)
            for record in identity_groups:
                yield stream_data_to_airbyte_message(self.name, record)
        except AirbyteTracedException:
            # Fatal error: propagate unchanged to stop the whole sync immediately.
            raise
        except Exception:
            yield AirbyteMessage(
                type=MessageType.LOG,
                log=AirbyteLogMessage(
                    level=Level.ERROR,
                    message=f"{FileBasedSourceError.ERROR_PARSING_RECORD.value} stream={self.name}",
                    stack_trace=traceback.format_exc(),
                ),
            )

    def get_json_schema(self) -> JsonSchema:
        # The schema is a module-level constant, so there is nothing to memoize.
        # ``functools.cache`` on an instance method would also pin every instance
        # in memory for the life of the process (ruff B019).
        return remote_file_identity_schema

    @property
    def name(self) -> str:
        return self.config.name

    def get_cursor(self) -> Optional[Cursor]:
        return None
@@ -45,7 +45,7 @@ def format_http_message(
45
45
  log_message["http"]["is_auxiliary"] = is_auxiliary # type: ignore [index]
46
46
  if stream_name:
47
47
  log_message["airbyte_cdk"] = {"stream": {"name": stream_name}}
48
- return log_message # type: ignore [return-value] # got "dict[str, object]", expected "dict[str, JsonType]"
48
+ return log_message # type: ignore[return-value] # got "dict[str, object]", expected "dict[str, JsonType]"
49
49
 
50
50
 
51
51
  def _normalize_body_string(body_str: Optional[Union[str, bytes]]) -> Optional[str]:
@@ -0,0 +1,99 @@
1
+ from abc import ABC
2
+ from datetime import datetime, timedelta
3
+ from enum import Enum
4
+ from typing import Callable
5
+
6
+ from airbyte_cdk.sources.streams.concurrent.cursor_types import CursorValueType
7
+
8
+
9
class ClampingStrategy(ABC):
    """Snaps a cursor value onto a period boundary. Subclasses implement ``clamp``."""

    def clamp(self, value: "CursorValueType") -> "CursorValueType":
        raise NotImplementedError()


class NoClamping(ClampingStrategy):
    """Identity strategy: returns the value unchanged."""

    def clamp(self, value: "CursorValueType") -> "CursorValueType":
        return value


class ClampingEndProvider:
    """
    Callable end-date provider.

    It takes the value produced by ``end_provider``, clamps it with
    ``clamping_strategy``, and then steps back by ``granularity``.
    """

    def __init__(
        self,
        clamping_strategy: ClampingStrategy,
        end_provider: Callable[[], "CursorValueType"],
        granularity: timedelta,
    ) -> None:
        self._clamping_strategy = clamping_strategy
        self._end_provider = end_provider
        self._granularity = granularity

    def __call__(self) -> "CursorValueType":
        return self._clamping_strategy.clamp(self._end_provider()) - self._granularity


class DayClampingStrategy(ClampingStrategy):
    """
    Clamp a datetime to a day boundary (midnight, keeping tzinfo).

    The floor is that day's midnight. The ceiling is the next day's midnight.
    NOTE(review): the ceiling always advances one day, even when the value is
    already exactly midnight.
    """

    def __init__(self, is_ceiling: bool = True) -> None:
        self._is_ceiling = is_ceiling

    def clamp(self, value: datetime) -> datetime:  # type: ignore # datetime implements method from CursorValueType
        midnight = value.replace(hour=0, minute=0, second=0, microsecond=0)
        return midnight + timedelta(days=1) if self._is_ceiling else midnight


class MonthClampingStrategy(ClampingStrategy):
    """
    Clamp a datetime to the first day of a month, at midnight.

    If the value is already on the 1st (at any time of day), both the floor and
    the ceiling return that day's midnight. Otherwise the floor is the 1st of the
    value's month and the ceiling is the 1st of the following month.
    """

    def __init__(self, is_ceiling: bool = True) -> None:
        self._is_ceiling = is_ceiling

    def clamp(self, value: datetime) -> datetime:  # type: ignore # datetime implements method from CursorValueType
        midnight = value.replace(hour=0, minute=0, second=0, microsecond=0)
        if value.day == 1:
            return midnight
        if self._is_ceiling:
            return self._ceil(midnight)
        return midnight.replace(day=1)

    def _ceil(self, value: datetime) -> datetime:
        # December rolls over to January of the next year.
        return value.replace(
            year=value.year + 1 if value.month == 12 else value.year,
            month=(value.month % 12) + 1,
            day=1,
            hour=0,
            minute=0,
            second=0,
            microsecond=0,
        )


class Weekday(Enum):
    """
    These integer values map to the same ones used by the Datetime.date.weekday() implementation
    """

    MONDAY = 0
    TUESDAY = 1
    WEDNESDAY = 2
    THURSDAY = 3
    FRIDAY = 4
    SATURDAY = 5
    SUNDAY = 6


class WeekClampingStrategy(ClampingStrategy):
    """
    Clamp a datetime to midnight of a given weekday.

    The ceiling moves forward to the next occurrence of ``day_of_week``. The
    floor moves back to the most recent one. A value that already falls on
    ``day_of_week`` maps to that same day's midnight for both directions, which
    matches how MonthClampingStrategy treats values already on the 1st.
    """

    def __init__(self, day_of_week: Weekday, is_ceiling: bool = True) -> None:
        self._day_of_week = day_of_week.value
        self._is_ceiling = is_ceiling

    def clamp(self, value: datetime) -> datetime:  # type: ignore # datetime implements method from CursorValueType
        midnight = value.replace(hour=0, minute=0, second=0, microsecond=0)
        weekday = value.weekday()
        if self._is_ceiling:
            # Days forward to the target weekday (0 when already on it).
            return midnight + timedelta(days=(self._day_of_week - weekday) % 7)
        # Days back to the target weekday (0 when already on it). The previous
        # implementation subtracted a full week in that case, so an end date
        # computed on the target weekday lagged by seven days.
        return midnight - timedelta(days=(weekday - self._day_of_week) % 7)