omnata_plugin_runtime-0.1.0-py3-none-any.whl
- omnata_plugin_runtime/__init__.py +39 -0
- omnata_plugin_runtime/api.py +73 -0
- omnata_plugin_runtime/configuration.py +593 -0
- omnata_plugin_runtime/forms.py +306 -0
- omnata_plugin_runtime/logging.py +91 -0
- omnata_plugin_runtime/omnata_plugin.py +1154 -0
- omnata_plugin_runtime/plugin_entrypoints.py +286 -0
- omnata_plugin_runtime/rate_limiting.py +232 -0
- omnata_plugin_runtime/record_transformer.py +50 -0
- omnata_plugin_runtime-0.1.0.dist-info/LICENSE +504 -0
- omnata_plugin_runtime-0.1.0.dist-info/METADATA +28 -0
- omnata_plugin_runtime-0.1.0.dist-info/RECORD +13 -0
- omnata_plugin_runtime-0.1.0.dist-info/WHEEL +4 -0
@@ -0,0 +1,286 @@
import json
import sys
import os
import importlib
import datetime
from logging import getLogger
from typing import Dict, List, Union, Optional
from omnata_plugin_runtime.forms import ConnectionMethod
from omnata_plugin_runtime.logging import OmnataPluginLogHandler
from pydantic import BaseModel, parse_obj_as  # pylint: disable=no-name-in-module
from pydantic.json import pydantic_encoder as pydantic_json_encoder  # pylint: disable=no-name-in-module
from snowflake.snowpark import Session
from omnata_plugin_runtime.omnata_plugin import OmnataPlugin, HttpRateLimiting, OutboundSyncRequest, InboundSyncRequest
from omnata_plugin_runtime.configuration import (
    OutboundSyncConfigurationParameters,
    SyncDirection,
    InboundSyncConfigurationParameters,
    OutboundSyncStrategy,
    StoredConfigurationValue,
    SyncConfigurationParameters,
    InboundSyncStreamsConfiguration,
    StreamConfiguration,
    StoredStreamConfiguration,
    StoredMappingValue,
)
from omnata_plugin_runtime.rate_limiting import ApiLimits, RetryLaterException, RateLimitState, RequestRateLimit
from omnata_plugin_runtime.api import ApplyPayload

logger = getLogger(__name__)
# this is the new API for secrets access (https://docs.snowflake.com/en/LIMITEDACCESS/secret-api-reference)
import _snowflake  # pylint: disable=import-error

IMPORT_DIRECTORY_NAME = "snowflake_import_directory"

class PluginEntrypoint():
    """
    This class gives each plugin's stored procs an initial point of contact.
    It will only work within Snowflake because it uses the _snowflake module.
    """
    def __init__(self, plugin_fqn: str, session: Session, module_name: str, class_name: str):
        logger.info('Initialising plugin entrypoint')
        self._session = session
        import_dir = sys._xoptions[IMPORT_DIRECTORY_NAME]
        sys.path.append(os.path.join(import_dir, 'app.zip'))
        module = importlib.import_module(module_name)
        class_obj = getattr(module, class_name)
        self._plugin_instance: OmnataPlugin = class_obj()
    def apply(self, apply_request: Dict):
        logger.info('Entered apply method')
        request = parse_obj_as(ApplyPayload, apply_request)
        connection_secrets = self.get_secrets(request.oauth_secret_name, request.other_secrets_name)
        omnata_log_handler = OmnataPluginLogHandler(session=self._session,
                                                    sync_id=request.sync_id,
                                                    sync_branch_id=request.sync_branch_id,
                                                    connection_id=request.connection_id,
                                                    sync_run_id=request.run_id)
        omnata_log_handler.register(request.logging_level, self._plugin_instance.additional_loggers())
        # construct some generic parameters for the purpose of getting the api limits
        base_parameters = SyncConfigurationParameters(request.connection_method,
                                                      connection_parameters=request.connection_parameters,
                                                      connection_secrets=connection_secrets,
                                                      sync_parameters=request.sync_parameters,
                                                      current_form_parameters={})
        all_api_limits = self._plugin_instance.api_limits(base_parameters)
        logger.info(f"Default API limits: {json.dumps(all_api_limits, default=pydantic_json_encoder)}")
        all_api_limits_by_category = {api_limit.endpoint_category: api_limit for api_limit in all_api_limits}
        all_api_limits_by_category.update({x.endpoint_category: x for x in request.api_limit_overrides})
        api_limits = list(all_api_limits_by_category.values())
        # if any endpoint categories have no state, give them an empty state
        for api_limit in api_limits:
            if api_limit.endpoint_category not in request.rate_limits_state:
                request.rate_limits_state[api_limit.endpoint_category] = RateLimitState()
        logger.info(f"Rate limits state: {json.dumps(request.rate_limits_state, default=pydantic_json_encoder)}")
        # initialised here so the streams returned at the end are well-defined for outbound syncs too
        streams_list: List[StoredStreamConfiguration] = []
        if request.sync_direction == 'outbound':
            parameters = OutboundSyncConfigurationParameters(
                connection_method=request.connection_method,
                connection_parameters=request.connection_parameters,
                connection_secrets=connection_secrets,
                sync_parameters=request.sync_parameters,
                current_form_parameters={},
                sync_strategy=request.sync_strategy,
                field_mappings=request.field_mappings,
            )

            outbound_sync_request = OutboundSyncRequest(run_id=request.run_id,
                                                        session=self._session,
                                                        source_app_name=request.source_app_name,
                                                        results_schema_name=request.results_schema_name,
                                                        results_table_name=request.results_table_name,
                                                        plugin_instance=self._plugin_instance,
                                                        api_limits=api_limits,  # the merged limits calculated above
                                                        rate_limit_state=request.rate_limits_state,
                                                        run_deadline=datetime.datetime.now() + datetime.timedelta(hours=4),
                                                        development_mode=False)

            with HttpRateLimiting(outbound_sync_request, parameters):
                combined_results = self._plugin_instance.sync_outbound(parameters, outbound_sync_request)
                if combined_results is not None:
                    outbound_sync_request.enqueue_results(combined_results)
                outbound_sync_request.apply_results_queue()
        elif request.sync_direction == 'inbound':
            logger.info("Running inbound sync")
            parameters = InboundSyncConfigurationParameters(
                connection_method=request.connection_method,
                connection_parameters=request.connection_parameters,
                connection_secrets=connection_secrets,
                sync_parameters=request.sync_parameters,
                current_form_parameters={})

            # build streams object from parameters
            streams_list = streams_list + list(request.streams_configuration.included_streams.values())

            # if new streams are included, we need to fetch the list first to find them
            if request.streams_configuration.include_new_streams:
                # we have to invoke the inbound_configuration_form to get the StreamLister, as the list
                # of streams may vary based on the sync parameters
                form = self._plugin_instance.inbound_configuration_form(parameters)
                if form.stream_lister is None:
                    logger.info("No stream lister defined, skipping new stream detection")
                else:
                    all_streams: List[StreamConfiguration] = getattr(self._plugin_instance, form.stream_lister.source_function)(parameters)
                    for s in all_streams:
                        if s.stream_name not in request.streams_configuration.included_streams and \
                                s.stream_name not in request.streams_configuration.excluded_streams:
                            if request.streams_configuration.new_stream_sync_strategy not in s.supported_sync_strategies:
                                raise ValueError(f"New object {s.stream_name} was found, but does not support the defined sync strategy {request.streams_configuration.new_stream_sync_strategy}")

                            new_stream = StoredStreamConfiguration(stream_name=s.stream_name,
                                                                   cursor_field=s.source_defined_cursor,
                                                                   primary_key_field=s.source_defined_primary_key,
                                                                   latest_state={},
                                                                   storage_behaviour=request.streams_configuration.new_stream_storage_behaviour,
                                                                   stream=s,
                                                                   sync_strategy=request.streams_configuration.new_stream_sync_strategy)
                            streams_list.append(new_stream)

            for stream in streams_list:
                if stream.stream_name in request.latest_stream_state:
                    stream.latest_state = request.latest_stream_state[stream.stream_name]
                    logger.info(f"Updating stream state for {stream.stream_name}: {stream.latest_state}")
                else:
                    logger.info(f"Existing stream state for {stream.stream_name} not found")
            # update the sync run to show the total number of streams to check
            # TODO: make this a function call
            # sync_run.INBOUND_STREAM_TOTAL_COUNTS = {stream.stream_name: 0 for stream in streams_list}
            # sync_run.update_object()
            logger.info(f"streams list: {streams_list}")
            logger.info(f"streams config: {request.streams_configuration}")
            inbound_sync_request = InboundSyncRequest(run_id=request.run_id,
                                                      session=self._session,
                                                      source_app_name=request.source_app_name,
                                                      results_schema_name=request.results_schema_name,
                                                      results_table_name=request.results_table_name,
                                                      plugin_instance=self._plugin_instance,
                                                      api_limits=api_limits,  # the merged limits calculated above
                                                      rate_limit_state=request.rate_limits_state,
                                                      run_deadline=datetime.datetime.now() + datetime.timedelta(hours=4),
                                                      development_mode=False,
                                                      streams=streams_list)

            inbound_sync_request.update_activity("Invoking plugin")
            logger.info(f"inbound sync request: {inbound_sync_request}")
            # plugin_instance._inbound_sync_request = inbound_sync_request
            with HttpRateLimiting(inbound_sync_request, parameters):
                combined_results = self._plugin_instance.sync_inbound(parameters, inbound_sync_request)
            logger.info("Finished invoking plugin")
            inbound_sync_request.update_activity("Staging remaining records")
            logger.info("Calling apply_results_queue")
            inbound_sync_request.apply_results_queue()
            omnata_log_handler.flush()
            # we need to calculate counts for:
            # CHANGED_COUNT by counting up the records in INBOUND_STREAM_RECORD_COUNTS
        logger.info('Finished applying records')
        omnata_log_handler.flush()
        omnata_log_handler.unregister()
        # report the latest state of each stream, keyed by stream name
        return {
            "streams": {stream.stream_name: stream.latest_state for stream in streams_list}
        }
    def configuration_form(self,
                           connection_method: str,
                           connection_parameters: Dict,
                           oauth_secret_name: Optional[str],
                           other_secrets_name: Optional[str],
                           sync_direction: str,
                           sync_strategy: Dict,
                           function_name: str,
                           sync_parameters: Dict,
                           current_form_parameters: Optional[Dict]):
        logger.info('Entered configuration_form method')
        sync_strategy = normalise_nulls(sync_strategy)
        oauth_secret_name = normalise_nulls(oauth_secret_name)
        other_secrets_name = normalise_nulls(other_secrets_name)
        connection_secrets = self.get_secrets(oauth_secret_name, other_secrets_name)
        connection_parameters = parse_obj_as(Dict[str, StoredConfigurationValue], connection_parameters)
        sync_parameters = parse_obj_as(Dict[str, StoredConfigurationValue], sync_parameters)
        form_parameters = None
        if current_form_parameters is not None:
            form_parameters = parse_obj_as(Dict[str, StoredConfigurationValue], current_form_parameters)
        if sync_direction == 'outbound':
            sync_strat = OutboundSyncStrategy.parse_obj(sync_strategy)
            parameters = OutboundSyncConfigurationParameters(connection_parameters=connection_parameters,
                                                             connection_secrets=connection_secrets,
                                                             sync_strategy=sync_strat,
                                                             sync_parameters=sync_parameters,
                                                             connection_method=connection_method,
                                                             current_form_parameters=form_parameters)
        elif sync_direction == 'inbound':
            parameters = InboundSyncConfigurationParameters(connection_parameters=connection_parameters,
                                                            connection_secrets=connection_secrets,
                                                            sync_parameters=sync_parameters,
                                                            connection_method=connection_method,
                                                            current_form_parameters=form_parameters)
        else:
            raise ValueError(f'Unknown direction {sync_direction}')
        the_function = getattr(self._plugin_instance, function_name or f"{sync_direction}_configuration_form")
        script_result = the_function(parameters)
        if isinstance(script_result, BaseModel):
            script_result = script_result.dict()
        elif isinstance(script_result, list):
            if len(script_result) > 0 and isinstance(script_result[0], BaseModel):
                script_result = [r.dict() for r in script_result]
        return script_result
    def connection_form(self):
        try:
            logger.info('Entered connection_form method')
            form: List[ConnectionMethod] = self._plugin_instance.connection_form()
            return [f.dict() for f in form]
        except Exception as exception:
            logger.error(str(exception), exc_info=True, stack_info=True)
            return {
                "success": False,
                "error": str(exception)
            }
    def get_secrets(self, oauth_secret_name: Optional[str], other_secrets_name: Optional[str]) -> Dict[str, StoredConfigurationValue]:
        connection_secrets = {}
        if oauth_secret_name is not None:
            connection_secrets['access_token'] = StoredConfigurationValue(value=_snowflake.get_oauth_access_token(oauth_secret_name))
        if other_secrets_name is not None:
            try:
                secret_string_content = _snowflake.get_generic_secret_string(other_secrets_name)
                other_secrets = json.loads(secret_string_content)
            except Exception as exception:
                logger.error(f'Error parsing secrets content: {str(exception)}')
                raise ValueError(f'Error parsing secrets content: {str(exception)}') from exception
            # merge rather than replace, so an OAuth access token and other secrets can coexist
            connection_secrets.update(parse_obj_as(Dict[str, StoredConfigurationValue], other_secrets))
        return connection_secrets
    def connect(self, method, connection_parameters: Dict, network_rule_name: str, oauth_secret_name: Optional[str], other_secrets_name: Optional[str]):
        logger.info('Entered connect method')
        logger.info(f'Connection parameters: {connection_parameters}')
        connection_secrets = self.get_secrets(oauth_secret_name, other_secrets_name)

        from omnata_plugin_runtime.omnata_plugin import ConnectionConfigurationParameters
        connect_response = self._plugin_instance.connect(
            ConnectionConfigurationParameters(connection_method=method,
                                              connection_parameters=parse_obj_as(Dict[str, StoredConfigurationValue], connection_parameters),
                                              connection_secrets=parse_obj_as(Dict[str, StoredConfigurationValue], connection_secrets)))
        # the connect method can also return more network addresses. If so, we need to update the
        # network rule associated with the external access integration
        if connect_response.network_addresses is not None:
            existing_rule_result = self._session.sql(f'desc network rule {network_rule_name}').collect()
            rule_values: List[str] = existing_rule_result[0].value_list.split(',')
            for network_address in connect_response.network_addresses:
                if network_address not in rule_values:
                    rule_values.append(network_address)
            rule_values_string = ','.join([f"'{value}'" for value in rule_values])
            self._session.sql(f'alter network rule {network_rule_name} set value_list = ({rule_values_string})').collect()

        return connect_response.dict()
    def api_limits(self):
        logger.info('Entered api_limits method')
        response: List[ApiLimits] = self._plugin_instance.api_limits(None)
        return [api_limit.dict() for api_limit in response]
def normalise_nulls(obj):
    """
    If an object came through a SQL interface with a null value, we convert it to a regular None here
    """
    if type(obj).__name__ == 'sqlNullWrapper':
        return None
    # handle a bunch of objects being given at once
    if isinstance(obj, list):
        return [normalise_nulls(x) for x in obj]
    return obj
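For orientation, the class above is what a plugin's generated Snowflake stored procedure handlers call into. Below is a minimal sketch of one such handler, assuming hypothetical names (`MY_DB.MY_SCHEMA.MY_PLUGIN`, `my_plugin_module`, `MyPlugin`); the real handlers are generated during plugin packaging, and this only runs inside Snowflake, where the `_snowflake` module exists.

```python
# Hypothetical handler sketch; the names used here are assumptions, not part of the package.
from snowflake.snowpark import Session
from omnata_plugin_runtime.plugin_entrypoints import PluginEntrypoint

def connection_form_handler(session: Session):
    entrypoint = PluginEntrypoint(plugin_fqn='MY_DB.MY_SCHEMA.MY_PLUGIN',
                                  session=session,
                                  module_name='my_plugin_module',
                                  class_name='MyPlugin')
    # returns a list of dicts, one per supported connection method
    return entrypoint.connection_form()
```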
@@ -0,0 +1,232 @@
# it's not the 1980s anymore
# pylint: disable=line-too-long,multiple-imports,logging-fstring-interpolation
"""
Contains functionality for limiting http requests made by Omnata plugins
"""
from __future__ import annotations
from pydantic import Field
from typing import List, Literal, Optional
from logging import getLogger
from email.utils import parsedate_to_datetime
from .configuration import SubscriptableBaseModel
import datetime
import threading
import requests
import re
logger = getLogger(__name__)

TimeUnitType = Literal['second', 'minute', 'hour', 'day']

HttpMethodType = Literal['GET', 'HEAD', 'POST', 'PUT', 'DELETE', 'CONNECT', 'OPTIONS', 'TRACE', 'PATCH']
class HttpRequestMatcher(SubscriptableBaseModel):
    """
    A class used to match an HTTP request
    """
    http_methods: List[HttpMethodType]
    url_regex: str

    @classmethod
    def match_all(cls):
        """
        An HttpRequestMatcher which will match all requests.
        """
        return cls(http_methods=['GET', 'HEAD', 'POST', 'PUT', 'DELETE', 'CONNECT', 'OPTIONS', 'TRACE', 'PATCH'],
                   url_regex='.*')
class ApiLimits(SubscriptableBaseModel):
    """
    Encapsulates the constraints imposed by an app's APIs
    """
    endpoint_category: str = Field(
        'All endpoints',
        description='the name of the API category (e.g. "Data loading endpoints")',
    )
    request_matchers: List[HttpRequestMatcher] = Field(
        [HttpRequestMatcher.match_all()],
        description="a list of request matchers. If None is provided, all requests will be matched",
    )
    request_rates: List[RequestRateLimit] = Field(
        None,
        description="imposes time delays between requests to stay under a defined rate limit",
    )

    def request_matches(self, method: HttpMethodType, url: str):
        """
        Given the request matchers that exist, determines whether the provided HTTP method and url is a match
        """
        for request_matcher in self.request_matchers:
            if method in request_matcher.http_methods and re.search(request_matcher.url_regex, url):
                return True
        return False

    @classmethod
    def apply_overrides(cls, default_api_limits: List[ApiLimits], overridden_values: List[ApiLimits]) -> List[ApiLimits]:
        """
        Takes a list of default api limits, and replaces their request rates with any overridden values
        """
        if overridden_values is None or len(overridden_values) == 0:
            return default_api_limits
        # key the overrides by endpoint category so they can be matched to the defaults
        overrides_keyed = {override.endpoint_category: override for override in overridden_values}
        for api_limit in default_api_limits:
            if api_limit.endpoint_category in overrides_keyed:
                api_limit.request_rates = overrides_keyed[api_limit.endpoint_category].request_rates
        return default_api_limits

    @classmethod
    def request_match(cls, all_api_limits: List[ApiLimits], method: HttpMethodType, url: str) -> Optional[ApiLimits]:
        """
        Given a set of defined API limits, return the first one that matches, or None if none of them match
        """
        for api_limits in all_api_limits:
            if api_limits.request_matches(method, url):
                return api_limits
        return None

    def calculate_wait(self, rate_limit_state: RateLimitState) -> datetime.datetime:
        """
        Based on the configured rate limits, given a sorted list of previous requests (newest to oldest),
        determine when the next request is allowed to occur.
        Each rate limit is a number of requests over a time window.
        Example:
        If the rate limit is 5 requests every 10 seconds, we:
        - determine the timestamp of the 5th most recent request
        - add 10 seconds to that timestamp
        The resulting timestamp is when the next request can be made (if it's in the past, it can be made immediately).
        If multiple rate limits exist, the maximum timestamp is used (i.e. the most restrictive rate limit applies).
        """
        logger.info(f"calculating wait time, given previous requests as {rate_limit_state.previous_request_timestamps}")
        if self.request_rates is None:
            return datetime.datetime.utcnow()
        longest_wait = datetime.datetime.utcnow()
        if rate_limit_state.wait_until is not None and rate_limit_state.wait_until > longest_wait:
            longest_wait = rate_limit_state.wait_until
        for request_rate in self.request_rates:
            if len(rate_limit_state.previous_request_timestamps) > 0:
                request_index = request_rate.request_count - 1
                # clamp the index to the history we actually have
                if request_index >= len(rate_limit_state.previous_request_timestamps):
                    request_index = len(rate_limit_state.previous_request_timestamps) - 1
                timestamp_at_horizon = rate_limit_state.previous_request_timestamps[request_index]
                next_allowed_request = timestamp_at_horizon + datetime.timedelta(seconds=request_rate.number_of_seconds())
                if next_allowed_request > longest_wait:
                    longest_wait = next_allowed_request

        return longest_wait
class RateLimitState(SubscriptableBaseModel):
    """
    Encapsulates the rate limiting state of an endpoint category
    for a particular connection (as opposed to configuration)
    """
    wait_until: Optional[datetime.datetime] = Field(
        None,
        description="Providing a value here means that no requests should occur until a specific moment in the future",
    )
    previous_request_timestamps: Optional[List[datetime.datetime]] = Field(
        [],
        description="A list of timestamps where previous requests have been made, used to calculate the next request time",
    )
    _request_timestamps_lock = threading.Lock()

    def register_http_request(self):
        """
        Registers a request as having just occurred, for rate limiting purposes.
        You only need to use this if your HTTP requests are not automatically being
        registered, which happens if http.client.HTTPConnection is not being used.
        """
        with self._request_timestamps_lock:
            append_time = datetime.datetime.utcnow()
            self.previous_request_timestamps.insert(0, append_time)

    def prune_history(self, request_rates: List[RequestRateLimit] = None):
        """
        When we store the request history, it doesn't make sense to go back indefinitely.
        We only need the requests which fall within the longest rate limiting window.
        """
        longest_window_seconds = max([rate.number_of_seconds() for rate in request_rates])
        # use utcnow() to match the timestamps recorded in register_http_request
        irrelevance_horizon = datetime.datetime.utcnow() - datetime.timedelta(seconds=longest_window_seconds)
        self.previous_request_timestamps = [ts for ts in self.previous_request_timestamps if ts > irrelevance_horizon]
class RequestRateLimit(SubscriptableBaseModel):
    """
    Request rate limits.
    Defined as a request count, time unit and number of units,
    e.g. (1, "second", 5) = 1 request per 5 seconds, or (100, "minute", 15) = 100 requests per 15 minutes.
    """
    request_count: int
    time_unit: TimeUnitType
    unit_count: int

    def number_of_seconds(self):
        """
        Converts the time_unit and unit_count to a number of seconds.
        E.g. 5 minutes = 300
        2 hours = 7200
        """
        if self.time_unit == 'second':
            return self.unit_count
        elif self.time_unit == 'minute':
            return self.unit_count * 60
        elif self.time_unit == 'hour':
            return self.unit_count * 3600
        elif self.time_unit == 'day':
            return self.unit_count * 86400
        else:
            raise ValueError(f"Unknown time unit: {self.time_unit}")

    def to_description(self) -> str:
        """Returns a readable description of this limit.
        For example:
        "1 request per minute"
        "5 requests per 2 seconds"

        Returns:
            str: the description as described above
        """
        return str(self.request_count) + " " + \
            "request" + ("s" if self.request_count > 1 else "") + \
            " per " + \
            (self.time_unit if self.unit_count == 1 else f"{self.unit_count} {self.time_unit}s")
class RetryLaterException(Exception):
    """
    Exception raised when the app has notified that rate limits are exceeded.
    Throwing this during record apply imposes a temporary extra API constraint,
    meaning we need to wait until a future date before more requests are made.
    """

    def __init__(self, future_datetime: datetime.datetime):
        self.future_datetime = future_datetime
        message = "Remote system wants us to retry later"
        self.message = message
        super().__init__(self.message)

class InterruptedWhileWaitingException(Exception):
    """
    Indicates that while waiting for rate limiting to expire, the sync was interrupted
    """
def too_many_requests_hook(fallback_future_datetime: Optional[datetime.datetime] = None):
    """
    A Requests hook which raises a RetryLaterException if an HTTP 429 response is returned.
    Examines the Retry-After header (https://developer.mozilla.org/en-US/docs/Web/HTTP/Headers/Retry-After)
    to determine the appropriate future datetime to retry.
    If that isn't available, it falls back to fallback_future_datetime (24 hours from now, by default).
    """
    def hook(resp: requests.Response, *args, **kwargs):
        """
        The actual hook implementation
        """
        if resp.status_code == 429:
            if 'Retry-After' in resp.headers:
                retry_after: str = resp.headers['Retry-After']
                if retry_after.isnumeric():
                    raise RetryLaterException(future_datetime=datetime.datetime.utcnow() + datetime.timedelta(seconds=int(retry_after)))
                retry_date = parsedate_to_datetime(retry_after)
                raise RetryLaterException(future_datetime=retry_date)
            # evaluate the fallback lazily; a datetime default argument would be fixed at import time
            raise RetryLaterException(future_datetime=fallback_future_datetime or (datetime.datetime.utcnow() + datetime.timedelta(hours=24)))
    return hook

ApiLimits.update_forward_refs()
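The waiting algorithm in `ApiLimits.calculate_wait` is easiest to see with numbers. A minimal worked sketch, assuming the package is importable; the six evenly spaced request timestamps are invented for illustration:

```python
import datetime
from omnata_plugin_runtime.rate_limiting import ApiLimits, RateLimitState, RequestRateLimit

# a single rate limit: 5 requests per 10 seconds
limits = ApiLimits(endpoint_category='All endpoints',
                   request_rates=[RequestRateLimit(request_count=5, time_unit='second', unit_count=10)])

now = datetime.datetime.utcnow()
# six previous requests, newest first, one second apart
state = RateLimitState(previous_request_timestamps=[now - datetime.timedelta(seconds=i) for i in range(6)])

# the 5th most recent request was 4 seconds ago, so the next request is allowed
# 10 seconds after it: roughly 6 seconds from now
print(limits.calculate_wait(state))
```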
@@ -0,0 +1,50 @@
from typing import Dict, List
from pydantic import parse_obj_as  # pylint: disable=no-name-in-module
from logging import getLogger
from .configuration import StoredConfigurationValue, StoredMappingValue, StoredFieldMapping
logger = getLogger(__name__)

class RecordTransformerResult(dict):
    """
    The result of transforming an individual record
    """
    def __init__(self, success: bool, errors: List[str] = None, transformed_record: dict = None):
        """
        Constructs a RecordTransformerResult
        """
        if success:
            dict.__init__(self, success=success, transformed_record=transformed_record)
        else:
            dict.__init__(self, success=success, errors=errors)

class RecordTransformer():
    """
    The default record transformer, used to convert Snowflake records into an app-specific representation
    using the configured field mappings, and optionally perform data validation so that failures surface
    at the source
    """
    def transform_record(self, source_metadata: dict, sync_parameters: Dict[str, StoredConfigurationValue],
                         field_mappings: StoredMappingValue, source_record: dict) -> RecordTransformerResult:
        """
        Default transformer.
        When the visual field mapper is used, simply picks out the columns and renames them.
        When the jinja mapper is used, copies the template into the output.
        """
        transformed_record = {}
        errors = []
        if field_mappings.mapper_type == 'jinja_template':
            transformed_record['jinja_template'] = field_mappings.jinja_template
            transformed_record['source_record'] = source_record
        elif field_mappings.mapper_type == 'field_mapping_selector':
            parsed_field_mappings: List[StoredFieldMapping] = parse_obj_as(List[StoredFieldMapping], field_mappings.field_mappings)
            for field_mapping in parsed_field_mappings:
                if field_mapping.source_column not in source_record:
                    errors.append(f"Column '{field_mapping.source_column}' not found in record")
                else:
                    transformed_record[field_mapping.app_field] = source_record[field_mapping.source_column]
        else:
            errors.append(f"Unrecognised mapper type: {field_mappings.mapper_type}")
        logger.debug(f"Transformed record: {transformed_record}")
        if len(errors) > 0:
            logger.info(f"Record transformer errors: {errors}")
            return dict(RecordTransformerResult(success=False, errors=errors))
        return dict(RecordTransformerResult(success=True, transformed_record=transformed_record))
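To see the field-mapping path in action, here is a minimal sketch; the mapping shape (`mapper_type`, `field_mappings` entries with `source_column`/`app_field`) is inferred from what `transform_record` reads above, and the column and field names are invented:

```python
from pydantic import parse_obj_as
from omnata_plugin_runtime.configuration import StoredMappingValue
from omnata_plugin_runtime.record_transformer import RecordTransformer

# a visual field mapping: rename NAME -> FirstName and EMAIL -> Email
mappings = parse_obj_as(StoredMappingValue, {
    'mapper_type': 'field_mapping_selector',
    'field_mappings': [
        {'source_column': 'NAME', 'app_field': 'FirstName'},
        {'source_column': 'EMAIL', 'app_field': 'Email'},
    ],
})

result = RecordTransformer().transform_record(source_metadata={},
                                              sync_parameters={},
                                              field_mappings=mappings,
                                              source_record={'NAME': 'Ada', 'EMAIL': 'ada@example.com'})
# -> {'success': True, 'transformed_record': {'FirstName': 'Ada', 'Email': 'ada@example.com'}}
```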