airbyte-cdk 6.31.2.dev0__py3-none-any.whl → 6.33.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- airbyte_cdk/cli/source_declarative_manifest/_run.py +9 -3
- airbyte_cdk/connector_builder/connector_builder_handler.py +3 -2
- airbyte_cdk/sources/declarative/async_job/job_orchestrator.py +7 -7
- airbyte_cdk/sources/declarative/auth/jwt.py +17 -11
- airbyte_cdk/sources/declarative/auth/oauth.py +89 -23
- airbyte_cdk/sources/declarative/auth/token_provider.py +4 -5
- airbyte_cdk/sources/declarative/checks/check_dynamic_stream.py +19 -9
- airbyte_cdk/sources/declarative/concurrent_declarative_source.py +145 -43
- airbyte_cdk/sources/declarative/declarative_component_schema.yaml +51 -2
- airbyte_cdk/sources/declarative/declarative_stream.py +3 -1
- airbyte_cdk/sources/declarative/extractors/record_filter.py +3 -5
- airbyte_cdk/sources/declarative/incremental/__init__.py +6 -0
- airbyte_cdk/sources/declarative/incremental/concurrent_partition_cursor.py +400 -0
- airbyte_cdk/sources/declarative/incremental/global_substream_cursor.py +3 -0
- airbyte_cdk/sources/declarative/incremental/per_partition_cursor.py +35 -3
- airbyte_cdk/sources/declarative/manifest_declarative_source.py +20 -7
- airbyte_cdk/sources/declarative/models/declarative_component_schema.py +41 -5
- airbyte_cdk/sources/declarative/parsers/custom_code_compiler.py +143 -0
- airbyte_cdk/sources/declarative/parsers/model_to_component_factory.py +313 -30
- airbyte_cdk/sources/declarative/partition_routers/async_job_partition_router.py +5 -5
- airbyte_cdk/sources/declarative/partition_routers/substream_partition_router.py +46 -12
- airbyte_cdk/sources/declarative/requesters/error_handlers/composite_error_handler.py +22 -0
- airbyte_cdk/sources/declarative/requesters/error_handlers/http_response_filter.py +4 -4
- airbyte_cdk/sources/declarative/retrievers/async_retriever.py +6 -12
- airbyte_cdk/sources/declarative/retrievers/simple_retriever.py +1 -1
- airbyte_cdk/sources/declarative/schema/__init__.py +2 -0
- airbyte_cdk/sources/declarative/schema/dynamic_schema_loader.py +44 -5
- airbyte_cdk/sources/http_logger.py +1 -1
- airbyte_cdk/sources/streams/concurrent/clamping.py +99 -0
- airbyte_cdk/sources/streams/concurrent/cursor.py +51 -57
- airbyte_cdk/sources/streams/concurrent/cursor_types.py +32 -0
- airbyte_cdk/sources/streams/concurrent/state_converters/datetime_stream_state_converter.py +22 -13
- airbyte_cdk/sources/streams/core.py +6 -6
- airbyte_cdk/sources/streams/http/http.py +1 -2
- airbyte_cdk/sources/streams/http/requests_native_auth/abstract_oauth.py +231 -62
- airbyte_cdk/sources/streams/http/requests_native_auth/oauth.py +171 -88
- airbyte_cdk/sources/types.py +4 -2
- airbyte_cdk/sources/utils/transform.py +23 -2
- airbyte_cdk/test/utils/manifest_only_fixtures.py +1 -2
- airbyte_cdk/utils/datetime_helpers.py +499 -0
- airbyte_cdk/utils/slice_hasher.py +8 -1
- airbyte_cdk-6.33.0.dist-info/LICENSE_SHORT +1 -0
- {airbyte_cdk-6.31.2.dev0.dist-info → airbyte_cdk-6.33.0.dist-info}/METADATA +6 -6
- {airbyte_cdk-6.31.2.dev0.dist-info → airbyte_cdk-6.33.0.dist-info}/RECORD +47 -41
- {airbyte_cdk-6.31.2.dev0.dist-info → airbyte_cdk-6.33.0.dist-info}/WHEEL +1 -1
- {airbyte_cdk-6.31.2.dev0.dist-info → airbyte_cdk-6.33.0.dist-info}/LICENSE.txt +0 -0
- {airbyte_cdk-6.31.2.dev0.dist-info → airbyte_cdk-6.33.0.dist-info}/entry_points.txt +0 -0
@@ -4,9 +4,9 @@ from dataclasses import InitVar, dataclass, field
|
|
4
4
|
from typing import Any, Callable, Iterable, Mapping, Optional
|
5
5
|
|
6
6
|
from airbyte_cdk.models import FailureType
|
7
|
+
from airbyte_cdk.sources.declarative.async_job.job import AsyncJob
|
7
8
|
from airbyte_cdk.sources.declarative.async_job.job_orchestrator import (
|
8
9
|
AsyncJobOrchestrator,
|
9
|
-
AsyncPartition,
|
10
10
|
)
|
11
11
|
from airbyte_cdk.sources.declarative.partition_routers.single_partition_router import (
|
12
12
|
SinglePartitionRouter,
|
@@ -42,12 +42,12 @@ class AsyncJobPartitionRouter(StreamSlicer):
|
|
42
42
|
|
43
43
|
for completed_partition in self._job_orchestrator.create_and_get_completed_partitions():
|
44
44
|
yield StreamSlice(
|
45
|
-
partition=dict(completed_partition.stream_slice.partition)
|
46
|
-
| {"partition": completed_partition},
|
45
|
+
partition=dict(completed_partition.stream_slice.partition),
|
47
46
|
cursor_slice=completed_partition.stream_slice.cursor_slice,
|
47
|
+
extra_fields={"jobs": list(completed_partition.jobs)},
|
48
48
|
)
|
49
49
|
|
50
|
-
def fetch_records(self,
|
50
|
+
def fetch_records(self, async_jobs: Iterable[AsyncJob]) -> Iterable[Mapping[str, Any]]:
|
51
51
|
"""
|
52
52
|
This method of fetching records extends beyond what a PartitionRouter/StreamSlicer should
|
53
53
|
be responsible for. However, this was added in because the JobOrchestrator is required to
|
@@ -62,4 +62,4 @@ class AsyncJobPartitionRouter(StreamSlicer):
|
|
62
62
|
failure_type=FailureType.system_error,
|
63
63
|
)
|
64
64
|
|
65
|
-
return self._job_orchestrator.fetch_records(
|
65
|
+
return self._job_orchestrator.fetch_records(async_jobs=async_jobs)
|
@@ -289,24 +289,58 @@ class SubstreamPartitionRouter(PartitionRouter):
|
|
289
289
|
return
|
290
290
|
|
291
291
|
if not parent_state and incremental_dependency:
|
292
|
-
#
|
293
|
-
|
294
|
-
substream_state = substream_state[0] if substream_state else {} # type: ignore [assignment] # Incorrect type for assignment
|
295
|
-
parent_state = {}
|
296
|
-
|
297
|
-
# Copy child state to parent streams with incremental dependencies
|
298
|
-
if substream_state:
|
299
|
-
for parent_config in self.parent_stream_configs:
|
300
|
-
if parent_config.incremental_dependency:
|
301
|
-
parent_state[parent_config.stream.name] = {
|
302
|
-
parent_config.stream.cursor_field: substream_state
|
303
|
-
}
|
292
|
+
# Migrate child state to parent state format
|
293
|
+
parent_state = self._migrate_child_state_to_parent_state(stream_state)
|
304
294
|
|
305
295
|
# Set state for each parent stream with an incremental dependency
|
306
296
|
for parent_config in self.parent_stream_configs:
|
307
297
|
if parent_config.incremental_dependency:
|
308
298
|
parent_config.stream.state = parent_state.get(parent_config.stream.name, {})
|
309
299
|
|
300
|
+
def _migrate_child_state_to_parent_state(self, stream_state: StreamState) -> StreamState:
    """
    Convert a child stream's global state into per-parent states.

    Only a state holding exactly one scalar value can be migrated, for example
    ``{"updated_at": "2023-05-27T00:00:00Z"}``. Anything else yields ``{}``:
    - an empty state;
    - more than one key;
    - a value that is a list or dict, such as the per-partition form ``{"states": [...]}``;
    - a falsy value.

    For every parent config with ``incremental_dependency`` set, the single value is placed
    under that parent's ``cursor_field``, keyed by the parent stream name.

    Args:
        stream_state: The child state to migrate.

    Returns:
        A mapping such as ``{"parent_stream_name": {"parent_cursor": "2023-05-27T00:00:00Z"}}``,
        or ``{}`` when nothing can be migrated.
    """
    values = list(stream_state.values())

    # More or fewer than one entry means this is not a simple {cursor_field: value} state.
    if len(values) != 1:
        return {}

    cursor_value = values[0]
    # Nested structures (per-partition states) and empty cursor values are not migrated.
    if isinstance(cursor_value, (list, dict)) or not cursor_value:
        return {}

    return {
        parent.stream.name: {parent.stream.cursor_field: cursor_value}
        for parent in self.parent_stream_configs
        if parent.incremental_dependency
    }
|
343
|
+
|
310
344
|
def get_stream_state(self) -> Optional[Mapping[str, StreamState]]:
|
311
345
|
"""
|
312
346
|
Get the state of the parent streams.
|
@@ -8,6 +8,7 @@ from typing import Any, List, Mapping, Optional, Union
|
|
8
8
|
import requests
|
9
9
|
|
10
10
|
from airbyte_cdk.sources.streams.http.error_handlers import ErrorHandler
|
11
|
+
from airbyte_cdk.sources.streams.http.error_handlers.backoff_strategy import BackoffStrategy
|
11
12
|
from airbyte_cdk.sources.streams.http.error_handlers.response_models import (
|
12
13
|
ErrorResolution,
|
13
14
|
ResponseAction,
|
@@ -77,3 +78,24 @@ class CompositeErrorHandler(ErrorHandler):
|
|
77
78
|
return matched_error_resolution
|
78
79
|
|
79
80
|
return create_fallback_error_resolution(response_or_exception)
|
81
|
+
|
82
|
+
@property
def backoff_strategies(self) -> Optional[List[BackoffStrategy]]:
    """
    Flatten the backoff strategies of all child error handlers into one list.

    Children are visited in ``self.error_handlers`` order. A child contributes only when it
    exposes a truthy ``backoff_strategies`` attribute, and its strategies are appended in
    their own order.

    NOTE(review): the consumer (HttpRequester, per the earlier notes on this property) is
    described as using only the first non-None strategy in this list. That would make child
    order decide which strategy wins. Confirm against HttpRequester and the error handler's
    ``backoff_time()``.

    Returns:
        The combined list, or ``None`` when no child contributed any strategy.
    """
    collected: List[BackoffStrategy] = []
    for child_handler in self.error_handlers:
        # A default of None skips handlers that lack the attribute, like the hasattr() guard.
        child_strategies = getattr(child_handler, "backoff_strategies", None)
        if child_strategies:
            collected.extend(child_strategies)
    return collected or None
|
@@ -151,16 +151,16 @@ class HttpResponseFilter:
|
|
151
151
|
:param response: The HTTP response which can be used during interpolation
|
152
152
|
:return: The evaluated error message string to be emitted
|
153
153
|
"""
|
154
|
-
return self.error_message.eval( # type: ignore
|
154
|
+
return self.error_message.eval( # type: ignore[no-any-return, union-attr]
|
155
155
|
self.config, response=self._safe_response_json(response), headers=response.headers
|
156
156
|
)
|
157
157
|
|
158
158
|
def _response_matches_predicate(self, response: requests.Response) -> bool:
|
159
159
|
return (
|
160
160
|
bool(
|
161
|
-
self.predicate.condition # type:
|
162
|
-
and self.predicate.eval( # type:
|
163
|
-
None, # type: ignore
|
161
|
+
self.predicate.condition # type:ignore[union-attr]
|
162
|
+
and self.predicate.eval( # type:ignore[union-attr]
|
163
|
+
None, # type: ignore[arg-type]
|
164
164
|
response=self._safe_response_json(response),
|
165
165
|
headers=response.headers,
|
166
166
|
)
|
@@ -6,7 +6,7 @@ from typing import Any, Iterable, Mapping, Optional
|
|
6
6
|
|
7
7
|
from typing_extensions import deprecated
|
8
8
|
|
9
|
-
from airbyte_cdk.
|
9
|
+
from airbyte_cdk.sources.declarative.async_job.job import AsyncJob
|
10
10
|
from airbyte_cdk.sources.declarative.async_job.job_orchestrator import AsyncPartition
|
11
11
|
from airbyte_cdk.sources.declarative.extractors.record_selector import RecordSelector
|
12
12
|
from airbyte_cdk.sources.declarative.partition_routers.async_job_partition_router import (
|
@@ -16,7 +16,6 @@ from airbyte_cdk.sources.declarative.retrievers.retriever import Retriever
|
|
16
16
|
from airbyte_cdk.sources.source import ExperimentalClassWarning
|
17
17
|
from airbyte_cdk.sources.streams.core import StreamData
|
18
18
|
from airbyte_cdk.sources.types import Config, StreamSlice, StreamState
|
19
|
-
from airbyte_cdk.utils.traced_exception import AirbyteTracedException
|
20
19
|
|
21
20
|
|
22
21
|
@deprecated(
|
@@ -57,9 +56,9 @@ class AsyncRetriever(Retriever):
|
|
57
56
|
|
58
57
|
return self.state
|
59
58
|
|
60
|
-
def
|
59
|
+
def _validate_and_get_stream_slice_jobs(
|
61
60
|
self, stream_slice: Optional[StreamSlice] = None
|
62
|
-
) ->
|
61
|
+
) -> Iterable[AsyncJob]:
|
63
62
|
"""
|
64
63
|
Validates the stream_slice argument and returns the partition from it.
|
65
64
|
|
@@ -73,12 +72,7 @@ class AsyncRetriever(Retriever):
|
|
73
72
|
AirbyteTracedException: If the stream_slice is not an instance of StreamSlice or if the partition is not present in the stream_slice.
|
74
73
|
|
75
74
|
"""
|
76
|
-
|
77
|
-
raise AirbyteTracedException(
|
78
|
-
message="Invalid arguments to AsyncJobRetriever.read_records: stream_slice is no optional. Please contact Airbyte Support",
|
79
|
-
failure_type=FailureType.system_error,
|
80
|
-
)
|
81
|
-
return stream_slice["partition"] # type: ignore # stream_slice["partition"] has been added as an AsyncPartition as part of stream_slices
|
75
|
+
return stream_slice.extra_fields.get("jobs", []) if stream_slice else []
|
82
76
|
|
83
77
|
def stream_slices(self) -> Iterable[Optional[StreamSlice]]:
|
84
78
|
return self.stream_slicer.stream_slices()
|
@@ -89,8 +83,8 @@ class AsyncRetriever(Retriever):
|
|
89
83
|
stream_slice: Optional[StreamSlice] = None,
|
90
84
|
) -> Iterable[StreamData]:
|
91
85
|
stream_state: StreamState = self._get_stream_state()
|
92
|
-
|
93
|
-
records: Iterable[Mapping[str, Any]] = self.stream_slicer.fetch_records(
|
86
|
+
jobs: Iterable[AsyncJob] = self._validate_and_get_stream_slice_jobs(stream_slice)
|
87
|
+
records: Iterable[Mapping[str, Any]] = self.stream_slicer.fetch_records(jobs)
|
94
88
|
|
95
89
|
yield from self.record_selector.filter_and_transform(
|
96
90
|
all_data=records,
|
@@ -163,7 +163,7 @@ class SimpleRetriever(Retriever):
|
|
163
163
|
stream_slice,
|
164
164
|
next_page_token,
|
165
165
|
self._paginator.get_request_headers,
|
166
|
-
self.
|
166
|
+
self.request_option_provider.get_request_headers,
|
167
167
|
)
|
168
168
|
if isinstance(headers, str):
|
169
169
|
raise ValueError("Request headers cannot be a string")
|
@@ -4,6 +4,7 @@
|
|
4
4
|
|
5
5
|
from airbyte_cdk.sources.declarative.schema.default_schema_loader import DefaultSchemaLoader
|
6
6
|
from airbyte_cdk.sources.declarative.schema.dynamic_schema_loader import (
|
7
|
+
ComplexFieldType,
|
7
8
|
DynamicSchemaLoader,
|
8
9
|
SchemaTypeIdentifier,
|
9
10
|
TypesMap,
|
@@ -18,6 +19,7 @@ __all__ = [
|
|
18
19
|
"SchemaLoader",
|
19
20
|
"InlineSchemaLoader",
|
20
21
|
"DynamicSchemaLoader",
|
22
|
+
"ComplexFieldType",
|
21
23
|
"TypesMap",
|
22
24
|
"SchemaTypeIdentifier",
|
23
25
|
]
|
@@ -18,7 +18,7 @@ from airbyte_cdk.sources.declarative.transformations import RecordTransformation
|
|
18
18
|
from airbyte_cdk.sources.source import ExperimentalClassWarning
|
19
19
|
from airbyte_cdk.sources.types import Config, StreamSlice, StreamState
|
20
20
|
|
21
|
-
AIRBYTE_DATA_TYPES: Mapping[str,
|
21
|
+
AIRBYTE_DATA_TYPES: Mapping[str, MutableMapping[str, Any]] = {
|
22
22
|
"string": {"type": ["null", "string"]},
|
23
23
|
"boolean": {"type": ["null", "boolean"]},
|
24
24
|
"date": {"type": ["null", "string"], "format": "date"},
|
@@ -45,6 +45,25 @@ AIRBYTE_DATA_TYPES: Mapping[str, Mapping[str, Any]] = {
|
|
45
45
|
}
|
46
46
|
|
47
47
|
|
48
|
+
@deprecated("This class is experimental. Use at your own risk.", category=ExperimentalClassWarning)
@dataclass(frozen=True)
class ComplexFieldType:
    """
    Target type for a dynamic-schema field that may carry a nested item type.

    ``field_type`` is the Airbyte type name. ``items`` optionally describes the element type,
    either as a type name or as another ``ComplexFieldType``. A non-empty ``items`` is only
    accepted when ``field_type`` is ``"array"``.
    """

    field_type: str
    items: Optional[Union[str, "ComplexFieldType"]] = None

    def __post_init__(self) -> None:
        """
        Validate that ``items`` is only set on an array field type.

        Raises:
            ValueError: if ``items`` is truthy and ``field_type`` is not ``"array"``.
        """
        if self.field_type == "array" or not self.items:
            return
        raise ValueError("'items' can only be used when 'field_type' is an array.")
|
65
|
+
|
66
|
+
|
48
67
|
@deprecated("This class is experimental. Use at your own risk.", category=ExperimentalClassWarning)
|
49
68
|
@dataclass(frozen=True)
|
50
69
|
class TypesMap:
|
@@ -52,7 +71,7 @@ class TypesMap:
|
|
52
71
|
Represents a mapping between a current type and its corresponding target type.
|
53
72
|
"""
|
54
73
|
|
55
|
-
target_type: Union[List[str], str]
|
74
|
+
target_type: Union[List[str], str, ComplexFieldType]
|
56
75
|
current_type: Union[List[str], str]
|
57
76
|
condition: Optional[str]
|
58
77
|
|
@@ -135,8 +154,9 @@ class DynamicSchemaLoader(SchemaLoader):
|
|
135
154
|
transformed_properties = self._transform(properties, {})
|
136
155
|
|
137
156
|
return {
|
138
|
-
"$schema": "
|
157
|
+
"$schema": "https://json-schema.org/draft-07/schema#",
|
139
158
|
"type": "object",
|
159
|
+
"additionalProperties": True,
|
140
160
|
"properties": transformed_properties,
|
141
161
|
}
|
142
162
|
|
@@ -188,18 +208,37 @@ class DynamicSchemaLoader(SchemaLoader):
|
|
188
208
|
first_type = self._get_airbyte_type(mapped_field_type[0])
|
189
209
|
second_type = self._get_airbyte_type(mapped_field_type[1])
|
190
210
|
return {"oneOf": [first_type, second_type]}
|
211
|
+
|
191
212
|
elif isinstance(mapped_field_type, str):
|
192
213
|
return self._get_airbyte_type(mapped_field_type)
|
214
|
+
|
215
|
+
elif isinstance(mapped_field_type, ComplexFieldType):
|
216
|
+
return self._resolve_complex_type(mapped_field_type)
|
217
|
+
|
193
218
|
else:
|
194
219
|
raise ValueError(
|
195
220
|
f"Invalid data type. Available string or two items list of string. Got {mapped_field_type}."
|
196
221
|
)
|
197
222
|
|
223
|
+
def _resolve_complex_type(self, complex_type: ComplexFieldType) -> Mapping[str, Any]:
    """
    Build the JSON-schema fragment for a ``ComplexFieldType``.

    Without ``items``, this is just the Airbyte type for ``field_type``. With ``items``, an
    ``"items"`` key is added:
    - a string ``items`` is looked up with ``_get_airbyte_type``;
    - a nested ``ComplexFieldType`` is resolved recursively.

    Args:
        complex_type: The complex type to resolve.

    Returns:
        The schema mapping for the field.
    """
    if not complex_type.items:
        return self._get_airbyte_type(complex_type.field_type)

    # Copy before adding "items". The body of _get_airbyte_type is not visible here. If it
    # returns the shared AIRBYTE_DATA_TYPES entry, mutating it in place would leak "items"
    # into every later lookup of this type. A shallow copy suffices because only a top-level
    # key is added.
    field_type = dict(self._get_airbyte_type(complex_type.field_type))

    field_type["items"] = (
        self._get_airbyte_type(complex_type.items)
        if isinstance(complex_type.items, str)
        else self._resolve_complex_type(complex_type.items)
    )

    return field_type
|
236
|
+
|
198
237
|
def _replace_type_if_not_valid(
|
199
238
|
self,
|
200
239
|
field_type: Union[List[str], str],
|
201
240
|
raw_schema: MutableMapping[str, Any],
|
202
|
-
) -> Union[List[str], str]:
|
241
|
+
) -> Union[List[str], str, ComplexFieldType]:
|
203
242
|
"""
|
204
243
|
Replaces a field type if it matches a type mapping in `types_map`.
|
205
244
|
"""
|
@@ -216,7 +255,7 @@ class DynamicSchemaLoader(SchemaLoader):
|
|
216
255
|
return field_type
|
217
256
|
|
218
257
|
@staticmethod
|
219
|
-
def _get_airbyte_type(field_type: str) ->
|
258
|
+
def _get_airbyte_type(field_type: str) -> MutableMapping[str, Any]:
|
220
259
|
"""
|
221
260
|
Maps a field type to its corresponding Airbyte type definition.
|
222
261
|
"""
|
@@ -45,7 +45,7 @@ def format_http_message(
|
|
45
45
|
log_message["http"]["is_auxiliary"] = is_auxiliary # type: ignore [index]
|
46
46
|
if stream_name:
|
47
47
|
log_message["airbyte_cdk"] = {"stream": {"name": stream_name}}
|
48
|
-
return log_message # type: ignore
|
48
|
+
return log_message # type: ignore[return-value] # got "dict[str, object]", expected "dict[str, JsonType]"
|
49
49
|
|
50
50
|
|
51
51
|
def _normalize_body_string(body_str: Optional[Union[str, bytes]]) -> Optional[str]:
|
@@ -0,0 +1,99 @@
|
|
1
|
+
from abc import ABC
|
2
|
+
from datetime import datetime, timedelta
|
3
|
+
from enum import Enum
|
4
|
+
from typing import Callable
|
5
|
+
|
6
|
+
from airbyte_cdk.sources.streams.concurrent.cursor_types import CursorValueType
|
7
|
+
|
8
|
+
|
9
|
+
class ClampingStrategy(ABC):
    """
    Interface for snapping a cursor value onto a boundary.

    Subclasses override :meth:`clamp`. The base implementation only raises
    ``NotImplementedError``. Because ``clamp`` is not declared ``@abstractmethod``, the base
    class itself can still be instantiated.
    """

    def clamp(self, value: CursorValueType) -> CursorValueType:
        """Return ``value`` snapped according to the concrete strategy."""
        raise NotImplementedError()
|
12
|
+
|
13
|
+
|
14
|
+
class NoClamping(ClampingStrategy):
    """Identity strategy: ``clamp`` returns its argument untouched."""

    def clamp(self, value: CursorValueType) -> CursorValueType:
        """Return ``value`` as-is; no boundary snapping is applied."""
        return value
|
17
|
+
|
18
|
+
|
19
|
+
class ClampingEndProvider:
    """
    Callable that produces a clamped end boundary on demand.

    Each call fetches a fresh end value from ``end_provider``, snaps it with
    ``clamping_strategy.clamp`` and then subtracts ``granularity``.
    """

    def __init__(
        self,
        clamping_strategy: ClampingStrategy,
        end_provider: Callable[[], CursorValueType],
        granularity: timedelta,
    ) -> None:
        self._clamping_strategy = clamping_strategy
        self._end_provider = end_provider
        self._granularity = granularity

    def __call__(self) -> CursorValueType:
        raw_end = self._end_provider()
        snapped_end = self._clamping_strategy.clamp(raw_end)
        # Step back one granularity unit from the snapped boundary.
        return snapped_end - self._granularity
|
32
|
+
|
33
|
+
|
34
|
+
class DayClampingStrategy(ClampingStrategy):
    """
    Snap a datetime to a day boundary.

    The value is first truncated to midnight of its own day, keeping ``tzinfo``.
    - With ``is_ceiling=True`` (the default), one day is then added. A value that is already
      exactly midnight therefore still advances to the next day.
    - Otherwise, the truncated midnight is returned.
    """

    def __init__(self, is_ceiling: bool = True) -> None:
        self._is_ceiling = is_ceiling

    def clamp(self, value: datetime) -> datetime:  # type: ignore # datetime implements method from CursorValueType
        start_of_day = value.replace(hour=0, minute=0, second=0, microsecond=0)
        return start_of_day + timedelta(days=1) if self._is_ceiling else start_of_day
|
43
|
+
|
44
|
+
|
45
|
+
class MonthClampingStrategy(ClampingStrategy):
    """
    Snap a datetime to the first day of a month at midnight, keeping ``tzinfo``.

    - A value whose day is already 1 is only truncated to midnight, in both modes.
    - Otherwise, ``is_ceiling=True`` (the default) moves to the first day of the next month,
      rolling December into January of the next year.
    - With ``is_ceiling=False``, it moves to the first day of the value's own month.
    """

    def __init__(self, is_ceiling: bool = True) -> None:
        self._is_ceiling = is_ceiling

    def clamp(self, value: datetime) -> datetime:  # type: ignore # datetime implements method from CursorValueType
        midnight = value.replace(hour=0, minute=0, second=0, microsecond=0)

        # NOTE(review): only the day is checked here. A ceiling of the 1st at, say, 10:00 therefore
        # returns the 1st at 00:00, which is earlier than the input. Confirm this is intended.
        if value.day == 1:
            return midnight

        if not self._is_ceiling:
            return midnight.replace(day=1)
        return self._ceil(midnight)

    def _ceil(self, value: datetime) -> datetime:
        """Return midnight on the first day of the month after ``value``'s month."""
        is_december = value.month == 12
        return value.replace(
            year=value.year + 1 if is_december else value.year,
            month=1 if is_december else value.month + 1,
            day=1,
            hour=0,
            minute=0,
            second=0,
            microsecond=0,
        )
|
67
|
+
|
68
|
+
|
69
|
+
class Weekday(Enum):
    """
    Days of the week.

    Values match those returned by ``datetime.date.weekday()``: Monday is 0 and Sunday is 6.
    """

    MONDAY = 0
    TUESDAY = 1
    WEDNESDAY = 2
    THURSDAY = 3
    FRIDAY = 4
    SATURDAY = 5
    SUNDAY = 6
|
81
|
+
|
82
|
+
|
83
|
+
class WeekClampingStrategy(ClampingStrategy):
    """
    Snap a datetime to midnight of a chosen weekday, keeping ``tzinfo``.

    - With ``is_ceiling=True`` (the default), the result is the first ``day_of_week`` on or
      after the value's date.
    - With ``is_ceiling=False``, it is the last ``day_of_week`` on or before that date.

    A value already on ``day_of_week`` maps to its own date in both modes. This matches how
    ``DayClampingStrategy`` and ``MonthClampingStrategy`` floor a value that already sits on
    the boundary.
    """

    def __init__(self, day_of_week: Weekday, is_ceiling: bool = True) -> None:
        self._day_of_week = day_of_week.value
        self._is_ceiling = is_ceiling

    def clamp(self, value: datetime) -> datetime:  # type: ignore # datetime implements method from CursorValueType
        start_of_day = value.replace(hour=0, minute=0, second=0, microsecond=0)
        # Days the value lies past the target weekday, which may be negative.
        offset = value.weekday() - self._day_of_week

        if self._is_ceiling:
            # Days forward to the next target weekday: 0..6, and 0 when already on it.
            return start_of_day + timedelta(days=(-offset) % 7)

        # Days back to the previous target weekday: 0..6, and 0 when already on it. Using
        # "ceiling distance - 7" instead would step back a full week when the value is already
        # on the target day.
        return start_of_day - timedelta(days=offset % 7)
|