omnata-plugin-runtime 0.1.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -0,0 +1,286 @@
+
+ import json
+ import sys
+ import os
+ import importlib
+ import datetime
+ from logging import getLogger
+ from typing import Dict, List, Union, Optional
+ from omnata_plugin_runtime.forms import ConnectionMethod
+ from omnata_plugin_runtime.logging import OmnataPluginLogHandler
+ from pydantic import BaseModel, parse_obj_as  # pylint: disable=no-name-in-module
+ from pydantic.json import pydantic_encoder as pydantic_json_encoder  # pylint:disable=no-name-in-module
+ from snowflake.snowpark import Session
+ from omnata_plugin_runtime.omnata_plugin import OmnataPlugin, HttpRateLimiting, OutboundSyncRequest, InboundSyncRequest
+ from omnata_plugin_runtime.configuration import OutboundSyncConfigurationParameters, SyncDirection, InboundSyncConfigurationParameters, OutboundSyncStrategy, StoredConfigurationValue, SyncConfigurationParameters, InboundSyncStreamsConfiguration, StreamConfiguration, StoredStreamConfiguration, StoredMappingValue
+ from omnata_plugin_runtime.rate_limiting import ApiLimits, RetryLaterException, RateLimitState, RequestRateLimit
+ from omnata_plugin_runtime.api import ApplyPayload
+ logger = getLogger(__name__)
+ # this is the new API for secrets access (https://docs.snowflake.com/en/LIMITEDACCESS/secret-api-reference)
+ import _snowflake  # pylint: disable=import-error
+ IMPORT_DIRECTORY_NAME = "snowflake_import_directory"
+
+ class PluginEntrypoint():
+     """
+     This class gives each plugin's stored procs an initial point of contact.
+     It will only work within Snowflake because it uses the _snowflake module.
+     """
+     def __init__(self,
+                  plugin_fqn: str,
+                  session: Session,
+                  module_name: str,
+                  class_name: str):
+         logger.info('Initialising plugin entrypoint')
+         self._session = session
+         import_dir = sys._xoptions[IMPORT_DIRECTORY_NAME]
+         sys.path.append(os.path.join(import_dir, 'app.zip'))
+         module = importlib.import_module(module_name)
+         class_obj = getattr(module, class_name)
+         self._plugin_instance: OmnataPlugin = class_obj()
+
+     def apply(self, apply_request: Dict):
+         logger.info('Entered apply method')
+         request = parse_obj_as(ApplyPayload, apply_request)
+         connection_secrets = self.get_secrets(request.oauth_secret_name, request.other_secrets_name)
+         omnata_log_handler = OmnataPluginLogHandler(session=self._session,
+                                                     sync_id=request.sync_id,
+                                                     sync_branch_id=request.sync_branch_id,
+                                                     connection_id=request.connection_id,
+                                                     sync_run_id=request.run_id)
+         omnata_log_handler.register(request.logging_level, self._plugin_instance.additional_loggers())
+         # construct some generic parameters for the purpose of getting the api limits
+         base_parameters = SyncConfigurationParameters(request.connection_method,
+                                                       connection_parameters=request.connection_parameters,
+                                                       connection_secrets=connection_secrets,
+                                                       sync_parameters=request.sync_parameters,
+                                                       current_form_parameters={})
+         all_api_limits = self._plugin_instance.api_limits(base_parameters)
+         logger.info(f"Default API limits: {json.dumps(all_api_limits, default=pydantic_json_encoder)}")
+         all_api_limits_by_category = {api_limit.endpoint_category: api_limit for api_limit in all_api_limits}
+         all_api_limits_by_category.update({override.endpoint_category: override for override in request.api_limit_overrides})
+         api_limits = list(all_api_limits_by_category.values())
+         # if any endpoint categories have no state, give them an empty state
+         for api_limit in api_limits:
+             if api_limit.endpoint_category not in request.rate_limits_state:
+                 request.rate_limits_state[api_limit.endpoint_category] = RateLimitState()
+         logger.info(f"Rate limits state: {json.dumps(request.rate_limits_state, default=pydantic_json_encoder)}")
+         if request.sync_direction == 'outbound':
+             parameters = OutboundSyncConfigurationParameters(
+                 connection_method=request.connection_method,
+                 connection_parameters=request.connection_parameters,
+                 connection_secrets=connection_secrets,
+                 sync_parameters=request.sync_parameters,
+                 current_form_parameters={},
+                 sync_strategy=request.sync_strategy,
+                 field_mappings=request.field_mappings,
+             )
+
+             outbound_sync_request = OutboundSyncRequest(run_id=request.run_id,
+                                                         session=self._session,
+                                                         source_app_name=request.source_app_name,
+                                                         results_schema_name=request.results_schema_name,
+                                                         results_table_name=request.results_table_name,
+                                                         plugin_instance=self._plugin_instance,
+                                                         api_limits=api_limits,
+                                                         rate_limit_state=request.rate_limits_state,
+                                                         run_deadline=datetime.datetime.now() + datetime.timedelta(hours=4),
+                                                         development_mode=False)
+
+             with HttpRateLimiting(outbound_sync_request, parameters):
+                 combined_results = self._plugin_instance.sync_outbound(parameters, outbound_sync_request)
+             if combined_results is not None:
+                 outbound_sync_request.enqueue_results(combined_results)
+             outbound_sync_request.apply_results_queue()
+
+         elif request.sync_direction == 'inbound':
+             logger.info("Running inbound sync")
+             parameters = InboundSyncConfigurationParameters(
+                 connection_method=request.connection_method,
+                 connection_parameters=request.connection_parameters,
+                 connection_secrets=connection_secrets,
+                 sync_parameters=request.sync_parameters,
+                 current_form_parameters={})
+
+             # build streams object from parameters
+             streams_list: List[StoredStreamConfiguration] = []
+             streams_list = streams_list + list(request.streams_configuration.included_streams.values())
+
+             # if new streams are included, we need to fetch the list first to find them
+             if request.streams_configuration.include_new_streams:
+                 # we have to invoke the inbound_configuration_form to get the StreamLister, as the list
+                 # of streams may vary based on the sync parameters
+                 form = self._plugin_instance.inbound_configuration_form(parameters)
+                 if form.stream_lister is None:
+                     logger.info("No stream lister defined, skipping new stream detection")
+                 else:
+                     all_streams: List[StreamConfiguration] = getattr(self._plugin_instance, form.stream_lister.source_function)(parameters)
+                     for s in all_streams:
+                         if s.stream_name not in request.streams_configuration.included_streams and \
+                                 s.stream_name not in request.streams_configuration.excluded_streams:
+                             if request.streams_configuration.new_stream_sync_strategy not in s.supported_sync_strategies:
+                                 raise ValueError(f"New object {s.stream_name} was found, but does not support the defined sync strategy {request.streams_configuration.new_stream_sync_strategy}")
+
+                             new_stream = StoredStreamConfiguration(stream_name=s.stream_name,
+                                                                    cursor_field=s.source_defined_cursor,
+                                                                    primary_key_field=s.source_defined_primary_key,
+                                                                    latest_state={},
+                                                                    storage_behaviour=request.streams_configuration.new_stream_storage_behaviour,
+                                                                    stream=s,
+                                                                    sync_strategy=request.streams_configuration.new_stream_sync_strategy)
+                             streams_list.append(new_stream)
+
+             for stream in streams_list:
+                 if stream.stream_name in request.latest_stream_state:
+                     stream.latest_state = request.latest_stream_state[stream.stream_name]
+                     logger.info(f"Updating stream state for {stream.stream_name}: {stream.latest_state}")
+                 else:
+                     logger.info(f"Existing stream state for {stream.stream_name} not found")
+             # update the sync run to show the total number of streams to check
+             # TODO: make this a function call
+             # sync_run.INBOUND_STREAM_TOTAL_COUNTS = {stream.stream_name:0 for stream in streams_list}
+             # sync_run.update_object()
+             logger.info(f"streams list: {streams_list}")
+             logger.info(f"streams config: {request.streams_configuration}")
+             inbound_sync_request = InboundSyncRequest(run_id=request.run_id,
+                                                       session=self._session,
+                                                       source_app_name=request.source_app_name,
+                                                       results_schema_name=request.results_schema_name,
+                                                       results_table_name=request.results_table_name,
+                                                       plugin_instance=self._plugin_instance,
+                                                       api_limits=api_limits,
+                                                       rate_limit_state=request.rate_limits_state,
+                                                       run_deadline=datetime.datetime.now() + datetime.timedelta(hours=4),
+                                                       development_mode=False,
+                                                       streams=streams_list)
+
+             inbound_sync_request.update_activity("Invoking plugin")
+             logger.info(f"inbound sync request: {inbound_sync_request}")
+             # plugin_instance._inbound_sync_request = inbound_sync_request
+             with HttpRateLimiting(inbound_sync_request, parameters):
+                 combined_results = self._plugin_instance.sync_inbound(parameters, inbound_sync_request)
+             logger.info("Finished invoking plugin")
+             inbound_sync_request.update_activity("Staging remaining records")
+             logger.info("Calling apply_results_queue")
+             inbound_sync_request.apply_results_queue()
+             omnata_log_handler.flush()
+             # we need to calculate counts for:
+             # CHANGED_COUNT by counting up the records in INBOUND_STREAM_RECORD_COUNTS
+             logger.info('Finished applying records')
+         omnata_log_handler.flush()
+         omnata_log_handler.unregister()
+         # only inbound syncs build a streams_list; key it by stream name so the result is serialisable
+         if request.sync_direction == 'inbound':
+             return {
+                 "streams": {stream.stream_name: stream.dict() for stream in streams_list}
+             }
+         return {}
+
+     def configuration_form(self,
+                            connection_method: str,
+                            connection_parameters: Dict,
+                            oauth_secret_name: Optional[str],
+                            other_secrets_name: Optional[str],
+                            sync_direction: str,
+                            sync_strategy: Dict,
+                            function_name: str,
+                            sync_parameters: Dict,
+                            current_form_parameters: Optional[Dict]):
+         logger.info('Entered configuration_form method')
+         sync_strategy = normalise_nulls(sync_strategy)
+         oauth_secret_name = normalise_nulls(oauth_secret_name)
+         other_secrets_name = normalise_nulls(other_secrets_name)
+         connection_secrets = self.get_secrets(oauth_secret_name, other_secrets_name)
+         connection_parameters = parse_obj_as(Dict[str, StoredConfigurationValue], connection_parameters)
+         sync_parameters = parse_obj_as(Dict[str, StoredConfigurationValue], sync_parameters)
+         form_parameters = None
+         if current_form_parameters is not None:
+             form_parameters = parse_obj_as(Dict[str, StoredConfigurationValue], current_form_parameters)
+         if sync_direction == 'outbound':
+             sync_strat = OutboundSyncStrategy.parse_obj(sync_strategy)
+             parameters = OutboundSyncConfigurationParameters(connection_parameters=connection_parameters,
+                                                              connection_secrets=connection_secrets,
+                                                              sync_strategy=sync_strat,
+                                                              sync_parameters=sync_parameters,
+                                                              connection_method=connection_method,
+                                                              current_form_parameters=form_parameters)
+         elif sync_direction == 'inbound':
+             parameters = InboundSyncConfigurationParameters(connection_parameters=connection_parameters,
+                                                             connection_secrets=connection_secrets,
+                                                             sync_parameters=sync_parameters,
+                                                             connection_method=connection_method,
+                                                             current_form_parameters=form_parameters)
+         else:
+             raise ValueError(f'Unknown direction {sync_direction}')
+         the_function = getattr(self._plugin_instance, function_name or f"{sync_direction}_configuration_form")
+         script_result = the_function(parameters)
+         if isinstance(script_result, BaseModel):
+             script_result = script_result.dict()
+         elif isinstance(script_result, list):
+             if len(script_result) > 0 and isinstance(script_result[0], BaseModel):
+                 script_result = [r.dict() for r in script_result]
+         return script_result
+
+     def connection_form(self):
+         try:
+             logger.info('Entered connection_form method')
+             form: List[ConnectionMethod] = self._plugin_instance.connection_form()
+             return [f.dict() for f in form]
+         except Exception as exception:
+             logger.error(str(exception), exc_info=True, stack_info=True)
+             return {
+                 "success": False,
+                 "error": str(exception)
+             }
+
+     def get_secrets(self, oauth_secret_name: Optional[str], other_secrets_name: Optional[str]) -> Dict[str, StoredConfigurationValue]:
+         connection_secrets = {}
+         if oauth_secret_name is not None:
+             connection_secrets['access_token'] = StoredConfigurationValue(value=_snowflake.get_oauth_access_token(oauth_secret_name))
+         if other_secrets_name is not None:
+             try:
+                 secret_string_content = _snowflake.get_generic_secret_string(other_secrets_name)
+                 other_secrets = json.loads(secret_string_content)
+             except Exception as exception:
+                 logger.error(f'Error parsing secrets content: {str(exception)}')
+                 raise ValueError(f'Error parsing secrets content: {str(exception)}') from exception
+             # merge rather than replace, so an OAuth access token stored above is not discarded
+             connection_secrets.update(parse_obj_as(Dict[str, StoredConfigurationValue], other_secrets))
+         return connection_secrets
+
+     def connect(self, method, connection_parameters: Dict, network_rule_name: str, oauth_secret_name: Optional[str], other_secrets_name: Optional[str]):
+         logger.info('Entered connect method')
+         logger.info(f'Connection parameters: {connection_parameters}')
+         connection_secrets = self.get_secrets(oauth_secret_name, other_secrets_name)
+
+         from omnata_plugin_runtime.omnata_plugin import ConnectionConfigurationParameters
+         connect_response = self._plugin_instance.connect(ConnectionConfigurationParameters(
+             connection_method=method,
+             connection_parameters=parse_obj_as(Dict[str, StoredConfigurationValue], connection_parameters),
+             connection_secrets=parse_obj_as(Dict[str, StoredConfigurationValue], connection_secrets)))
+         # the connect method can also return more network addresses. If so, we need to update the
+         # network rule associated with the external access integration
+         if connect_response.network_addresses is not None:
+             existing_rule_result = self._session.sql(f'desc network rule {network_rule_name}').collect()
+             rule_values: List[str] = existing_rule_result[0].value_list.split(',')
+             for network_address in connect_response.network_addresses:
+                 if network_address not in rule_values:
+                     rule_values.append(network_address)
+             rule_values_string = ','.join([f"'{value}'" for value in rule_values])
+             self._session.sql(f'alter network rule {network_rule_name} set value_list = ({rule_values_string})').collect()
+
+         return connect_response.dict()
+
+     def api_limits(self):
+         logger.info('Entered api_limits method')
+         response: List[ApiLimits] = self._plugin_instance.api_limits(None)
+         return [api_limit.dict() for api_limit in response]
+
+ def normalise_nulls(obj):
+     """
+     If an object came through a SQL interface with a null value, we convert it to a regular None here
+     """
+     if type(obj).__name__ == 'sqlNullWrapper':
+         return None
+     # handle a bunch of objects being given at once
+     if isinstance(obj, list):
+         return [normalise_nulls(x) for x in obj]
+     return obj
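
For context, the entrypoint above is what a plugin's Snowflake stored procedure handlers delegate to. The following is a minimal sketch of such a handler module; the import path omnata_plugin_runtime.plugin_entrypoints, the plugin module/class names (my_plugin, MyPlugin) and the fully-qualified plugin name are illustrative assumptions, not part of the diff above.

    from snowflake.snowpark import Session
    from omnata_plugin_runtime.plugin_entrypoints import PluginEntrypoint  # module path assumed

    def connection_form(session: Session):
        # each handler builds a fresh entrypoint, which imports the plugin class from the uploaded app.zip
        entrypoint = PluginEntrypoint(plugin_fqn='MY_DB.MY_SCHEMA.MY_PLUGIN',  # placeholder FQN
                                      session=session,
                                      module_name='my_plugin',
                                      class_name='MyPlugin')
        return entrypoint.connection_form()

    def apply_sync(session: Session, apply_request: dict):
        entrypoint = PluginEntrypoint('MY_DB.MY_SCHEMA.MY_PLUGIN', session, 'my_plugin', 'MyPlugin')
        return entrypoint.apply(apply_request)
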
@@ -0,0 +1,232 @@
+ # it's not the 1980s anymore
+ # pylint: disable=line-too-long,multiple-imports,logging-fstring-interpolation
+ """
+ Contains functionality for limiting http requests made by Omnata plugins
+ """
+ from __future__ import annotations
+ from pydantic import Field
+ from typing import List, Literal, Optional
+ from logging import getLogger
+ from email.utils import parsedate_to_datetime
+ from .configuration import SubscriptableBaseModel
+ import datetime
+ import threading
+ import requests
+ import re
+ logger = getLogger(__name__)
+
+ TimeUnitType = Literal['second', 'minute', 'hour', 'day']
+
+ HttpMethodType = Literal['GET', 'HEAD', 'POST', 'PUT', 'DELETE', 'CONNECT', 'OPTIONS', 'TRACE', 'PATCH']
+
+ class HttpRequestMatcher(SubscriptableBaseModel):
+     """
+     A class used to match an HTTP request
+     """
+     http_methods: List[HttpMethodType]
+     url_regex: str
+
+     @classmethod
+     def match_all(cls):
+         """
+         A HttpRequestMatcher which will match all requests.
+         """
+         return cls(http_methods=['GET', 'HEAD', 'POST', 'PUT', 'DELETE', 'CONNECT', 'OPTIONS', 'TRACE', 'PATCH'],
+                    url_regex='.*')
+
+ class ApiLimits(SubscriptableBaseModel):
+     """
+     Encapsulates the constraints imposed by an app's APIs
+     """
+     endpoint_category: str = Field(
+         'All endpoints',
+         description='the name of the API category (e.g. "Data loading endpoints")',
+     )
+     request_matchers: List[HttpRequestMatcher] = Field(
+         [HttpRequestMatcher.match_all()],
+         description="a list of request matchers. If None is provided, all requests will be matched",
+     )
+     request_rates: List[RequestRateLimit] = Field(
+         None,
+         description="imposes time delays between requests to stay under a defined rate limit",
+     )
+
+     def request_matches(self, method: HttpMethodType, url: str):
+         """
+         Given the request matchers that exist, determines whether the provided HTTP method and url are a match
+         """
+         for request_matcher in self.request_matchers:
+             if method in request_matcher.http_methods and re.search(request_matcher.url_regex, url):
+                 return True
+         return False
+
+     @classmethod
+     def apply_overrides(cls, default_api_limits: List[ApiLimits], overridden_values: List[ApiLimits]) -> List[ApiLimits]:
+         """
+         Takes a list of default api limits, and replaces their request rates with any overridden values
+         """
+         if overridden_values is None or len(overridden_values) == 0:
+             return default_api_limits
+         overrides_keyed = {override.endpoint_category: override for override in overridden_values}
+         for api_limit in default_api_limits:
+             if api_limit.endpoint_category in overrides_keyed:
+                 api_limit.request_rates = overrides_keyed[api_limit.endpoint_category].request_rates
+         return default_api_limits
+
+ @classmethod
77
+ def request_match(cls, all_api_limits:List[ApiLimits],method:HttpMethodType,url:str) -> Optional[ApiLimits]:
78
+ """
79
+ Given a set of defined API limits, return the first one that matches, or None if none of them match
80
+ """
81
+ for api_limits in all_api_limits:
82
+ if api_limits.request_matches(method,url):
83
+ return api_limits
84
+ return None
85
+
86
+ def calculate_wait(self,rate_limit_state:RateLimitState) -> datetime.datetime:
87
+ """
88
+ Based on the configured wait limits, given a sorted list of previous requests (newest to oldest),
89
+ determine when the next request is allowed to occur.
90
+ Each rate limit is a number of requests over a time window.
91
+ Examples:
92
+ If the rate limit is 5 requests every 10 seconds, we:
93
+ - determine the timestamp of the 5th most recent request
94
+ - add 10 seconds to that timestamp
95
+ The resulting timestamp is when the next request can be made (if it's in the past, it can be done immediately)
96
+ If multiple rate limits exist, the maximum timestamp is used (i.e. the most restrictive rate limit applies)
97
+ """
98
+ logger.info(f"calculating wait time, given previous requests as {rate_limit_state.previous_request_timestamps}")
99
+ if self.request_rates is None:
100
+ return datetime.datetime.utcnow()
101
+ longest_wait = datetime.datetime.utcnow()
102
+ if rate_limit_state.wait_until is not None and rate_limit_state.wait_until > longest_wait:
103
+ longest_wait = rate_limit_state.wait_until
104
+ for request_rate in self.request_rates:
105
+ if len(rate_limit_state.previous_request_timestamps) > 0:
106
+ request_index = request_rate.request_count - 1
107
+ if len(rate_limit_state.previous_request_timestamps) < request_index:
108
+ request_index = len(rate_limit_state.previous_request_timestamps)-1
109
+ timestamp_at_horizon = rate_limit_state.previous_request_timestamps[request_index]
110
+ next_allowed_request = timestamp_at_horizon + datetime.timedelta(seconds=request_rate.number_of_seconds())
111
+ if next_allowed_request > longest_wait:
112
+ longest_wait = next_allowed_request
113
+
114
+ return longest_wait
115
+
+ class RateLimitState(SubscriptableBaseModel):
+     """
+     Encapsulates the rate limiting state of an endpoint category
+     for a particular connection (as opposed to configuration)
+     """
+     wait_until: Optional[datetime.datetime] = Field(
+         None,
+         description="Providing a value here means that no requests should occur until a specific moment in the future",
+     )
+     previous_request_timestamps: Optional[List[datetime.datetime]] = Field(
+         [],
+         description="A list of timestamps where previous requests have been made, used to calculate the next request time",
+     )
+     _request_timestamps_lock = threading.Lock()
+
+     def register_http_request(self):
+         """
+         Registers a request as having just occurred, for rate limiting purposes.
+         You only need to use this if your HTTP requests are not automatically being
+         registered, which happens if http.client.HTTPConnection is not being used.
+         """
+         with self._request_timestamps_lock:
+             append_time = datetime.datetime.utcnow()
+             self.previous_request_timestamps.insert(0, append_time)
+
+     def prune_history(self, request_rates: List[RequestRateLimit] = None):
+         """
+         When we store the request history, it doesn't make sense to go back indefinitely.
+         We only need the requests which fall within the longest rate limiting window.
+         """
+         if request_rates is None or len(request_rates) == 0:
+             return
+         longest_window_seconds = max([rate.number_of_seconds() for rate in request_rates])
+         # use utcnow() to match the timestamps recorded by register_http_request
+         irrelevance_horizon = datetime.datetime.utcnow() - datetime.timedelta(seconds=longest_window_seconds)
+         self.previous_request_timestamps = [ts for ts in self.previous_request_timestamps if ts > irrelevance_horizon]
+
+ class RequestRateLimit(SubscriptableBaseModel):
+     """
+     Request rate limits.
+     Defined as a request count, time unit and number of units, e.g. (1, "second", 5) = 1 request per 5 seconds, or (100, "minute", 15) = 100 requests per 15 minutes
+     """
+     request_count: int
+     time_unit: TimeUnitType
+     unit_count: int
+
+     def number_of_seconds(self):
+         """
+         Converts the time_unit and unit_count to a number of seconds.
+         E.g. 5 minutes = 300, 2 hours = 7200
+         """
+         if self.time_unit == 'second':
+             return self.unit_count
+         elif self.time_unit == 'minute':
+             return self.unit_count * 60
+         elif self.time_unit == 'hour':
+             return self.unit_count * 3600
+         elif self.time_unit == 'day':
+             return self.unit_count * 86400
+         else:
+             raise ValueError(f"Unknown time unit: {self.time_unit}")
+
+     def to_description(self) -> str:
+         """Returns a readable description of this limit.
+         For example:
+         "1 request per minute"
+         "5 requests per 2 seconds"
+
+         Returns:
+             str: the description as described above
+         """
+         return str(self.request_count) + " " + \
+                "request" + ("s" if self.request_count > 1 else "") + \
+                " per " + \
+                (self.time_unit if self.unit_count == 1 else f"{self.unit_count} {self.time_unit}s")
+
+ class RetryLaterException(Exception):
+     """
+     Exception raised when the app has notified that rate limits are exceeded.
+     Throwing this during record apply imposes a temporary extra API constraint:
+     we need to wait until a future date before more requests are made.
+     """
+
+     def __init__(self, future_datetime: datetime.datetime):
+         self.future_datetime = future_datetime
+         message = "Remote system wants us to retry later"
+         self.message = message
+         super().__init__(self.message)
+
+ class InterruptedWhileWaitingException(Exception):
+     """
+     Indicates that while waiting for rate limiting to expire, the sync was interrupted
+     """
+
+ def too_many_requests_hook(fallback_future_datetime: Optional[datetime.datetime] = None):
+     """
+     A Requests hook which raises a RetryLaterException if an HTTP 429 response is returned.
+     Examines the Retry-After header (https://developer.mozilla.org/en-US/docs/Web/HTTP/Headers/Retry-After)
+     to determine the appropriate future datetime to retry.
+     If that isn't available, it falls back to fallback_future_datetime (24 hours from now, by default).
+     """
+     if fallback_future_datetime is None:
+         # evaluate the default lazily, rather than once at module import time
+         fallback_future_datetime = datetime.datetime.utcnow() + datetime.timedelta(hours=24)
+
+     def hook(resp: requests.Response, *args, **kwargs):
+         """
+         The actual hook implementation
+         """
+         if resp.status_code == 429:
+             if 'Retry-After' in resp.headers:
+                 retry_after: str = resp.headers['Retry-After']
+                 if retry_after.isnumeric():
+                     raise RetryLaterException(future_datetime=datetime.datetime.utcnow() + datetime.timedelta(seconds=int(retry_after)))
+                 retry_date = parsedate_to_datetime(retry_after)
+                 raise RetryLaterException(future_datetime=retry_date)
+             raise RetryLaterException(future_datetime=fallback_future_datetime)
+     return hook
+
+ ApiLimits.update_forward_refs()
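
To show how these classes fit together, here is a small usage sketch. It exercises only what is defined above (ApiLimits, HttpRequestMatcher, RequestRateLimit, RateLimitState, too_many_requests_hook); the endpoint category, regex and limit values are illustrative.

    import datetime
    import requests
    from omnata_plugin_runtime.rate_limiting import (
        ApiLimits, HttpRequestMatcher, RequestRateLimit, RateLimitState, too_many_requests_hook)

    # at most 1 request every 5 seconds against anything under /api/
    limits = ApiLimits(endpoint_category='Data loading endpoints',
                       request_matchers=[HttpRequestMatcher(http_methods=['POST'], url_regex='/api/')],
                       request_rates=[RequestRateLimit(request_count=1, time_unit='second', unit_count=5)])

    state = RateLimitState()
    state.register_http_request()                # record a request that was just made
    next_allowed = limits.calculate_wait(state)  # timestamp of that request + 5 seconds
    wait_seconds = max(0, (next_allowed - datetime.datetime.utcnow()).total_seconds())

    # attach the hook so an HTTP 429 (with or without a Retry-After header) surfaces as RetryLaterException
    session = requests.Session()
    session.hooks['response'].append(too_many_requests_hook())
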
@@ -0,0 +1,50 @@
+
+ from typing import Dict, List
+ from pydantic import parse_obj_as  # pylint:disable=no-name-in-module
+ from logging import getLogger
+ from .configuration import StoredConfigurationValue, StoredMappingValue, StoredFieldMapping
+ logger = getLogger(__name__)
+
+ class RecordTransformerResult(dict):
+     """
+     The result of transforming an individual record
+     """
+     def __init__(self, success: bool, errors: List[str] = None, transformed_record: dict = None):
+         """
+         Constructs a RecordTransformerResult
+         """
+         if success:
+             dict.__init__(self, success=success, transformed_record=transformed_record)
+         else:
+             dict.__init__(self, success=success, errors=errors)
+
+ class RecordTransformer():
+     """
+     The default record transformer, used to convert Snowflake records into an app-specific representation
+     using the configured field mappings, and optionally perform data validation so that failures occur at the source
+     """
+     def transform_record(self, source_metadata: dict, sync_parameters: Dict[str, StoredConfigurationValue], field_mappings: StoredMappingValue, source_record: dict) -> RecordTransformerResult:
+         """
+         Default transformer.
+         When the visual field mapper is used, simply picks out the columns and renames them.
+         When the jinja mapper is used, copies the template into the output.
+         """
+         transformed_record = {}
+         errors = []
+         if field_mappings.mapper_type == 'jinja_template':
+             transformed_record['jinja_template'] = field_mappings.jinja_template
+             transformed_record['source_record'] = source_record
+         elif field_mappings.mapper_type == 'field_mapping_selector':
+             parsed_field_mappings: List[StoredFieldMapping] = parse_obj_as(List[StoredFieldMapping], field_mappings.field_mappings)
+             for field_mapping in parsed_field_mappings:
+                 if field_mapping.source_column not in source_record:
+                     errors.append(f"Column '{field_mapping.source_column}' not found in record")
+                 else:
+                     transformed_record[field_mapping.app_field] = source_record[field_mapping.source_column]
+         else:
+             errors.append(f"Unrecognised mapper type: {field_mappings.mapper_type}")
+         logger.debug(f"Transformed record: {transformed_record}")
+         if len(errors) > 0:
+             logger.info(f"Record transformer errors: {errors}")
+             return dict(RecordTransformerResult(success=False, errors=errors))
+         return dict(RecordTransformerResult(success=True, transformed_record=transformed_record))
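
A minimal sketch of driving the default transformer directly follows. The attribute names used (mapper_type, field_mappings, source_column, app_field) match the code above, but the exact constructor signatures of StoredMappingValue and StoredFieldMapping live in the configuration module and are assumed here, as is the module path of RecordTransformer.

    from omnata_plugin_runtime.configuration import StoredMappingValue, StoredFieldMapping
    from omnata_plugin_runtime.record_transformer import RecordTransformer  # module path assumed

    # rename two Snowflake columns into the app's field names
    mappings = StoredMappingValue(mapper_type='field_mapping_selector',
                                  field_mappings=[StoredFieldMapping(source_column='FIRST_NAME', app_field='firstName'),
                                                  StoredFieldMapping(source_column='EMAIL', app_field='email')])

    result = RecordTransformer().transform_record(source_metadata={},
                                                  sync_parameters={},
                                                  field_mappings=mappings,
                                                  source_record={'FIRST_NAME': 'Ada', 'EMAIL': 'ada@example.com'})
    # result == {'success': True, 'transformed_record': {'firstName': 'Ada', 'email': 'ada@example.com'}}
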