omnata-plugin-runtime 0.1.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -0,0 +1,1154 @@
1
+ # it's not the 1980s anymore
2
+ # pylint: disable=line-too-long,multiple-imports,logging-fstring-interpolation
3
+ """
4
+ Omnata Plugin Runtime.
5
+ Includes data container classes and defines the contract for a plugin.
6
+ """
7
+ from __future__ import annotations
8
+ import http
9
+ import pandas
10
+ import jinja2
11
+ from functools import wraps, partial
12
+ from abc import ABC, abstractmethod
13
+ from typing import Callable, Dict, Iterable, List, Literal, Type
14
+ import threading, datetime, json, queue
15
+ from logging import getLogger
16
+ from dateutil.parser import parse
17
+ from jinja2 import Environment
18
+ from snowflake.snowpark import Session
19
+ from snowflake.snowpark.functions import col, lit, parse_json, is_null, not_
20
+ from snowflake.connector.pandas_tools import write_pandas
21
+ from .forms import ConnectionMethod, OutboundSyncConfigurationForm,InboundSyncConfigurationForm
22
+ from .configuration import STANDARD_OUTBOUND_SYNC_ACTIONS, OutboundSyncAction, OutboundSyncStrategy, OutboundSyncConfigurationParameters, InboundSyncConfigurationParameters, ConnectionConfigurationParameters, StreamConfiguration, SubscriptableBaseModel, SyncConfigurationParameters, StoredStreamConfiguration
23
+ from .rate_limiting import ApiLimits, InterruptedWhileWaitingException, RateLimitState
24
+
25
+ logger = getLogger(__name__)
26
+ SortDirectionType = Literal['asc','desc']
27
+
28
+ class PluginManifest(SubscriptableBaseModel):
29
+ """
30
+ Constructs a Plugin Manifest, which identifies the application, describes how it can work, and defines any runtime code dependencies.
31
+ :param str plugin_id: A short, string identifier for the application, a combination of lowercase alphanumeric and underscores, e.g. "google_sheets"
32
+ :param str plugin_name: A descriptive name for the application, e.g. "Google Sheets"
33
+ :param str developer_id: A short, string identifier for the developer, a combination of lowercase alphanumeric and underscores, e.g. "acme_corp"
34
+ :param str developer_name: A descriptive name for the developer, e.g. "Acme Corp"
35
+ :param str docs_url: The URL where plugin documentation can be found, e.g. "https://docs.omnata.com"
36
+ :param bool supports_inbound: A flag to indicate whether or not the plugin supports inbound sync. Support for inbound sync behaviours (full/incremental) is defined per inbound stream.
37
+ :param List[OutboundSyncStrategy] supported_outbound_strategies: A list of sync strategies that the plugin can support, e.g. create,upsert.
38
+ """
39
+ plugin_id:str
40
+ plugin_name:str
41
+ developer_id:str
42
+ developer_name:str
43
+ docs_url:str
44
+ supports_inbound:bool
45
+ supported_outbound_strategies:List[OutboundSyncStrategy]
46
+
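+ # Illustrative sketch only: constructing a PluginManifest the way a plugin's get_manifest()
+ # typically would. The identifiers below are hypothetical (taken from the docstring examples),
+ # and supported_outbound_strategies is left empty purely to keep the sketch self-contained.
+ def _example_manifest() -> PluginManifest:
+     return PluginManifest(plugin_id='google_sheets',
+                           plugin_name='Google Sheets',
+                           developer_id='acme_corp',
+                           developer_name='Acme Corp',
+                           docs_url='https://docs.omnata.com',
+                           supports_inbound=True,
+                           supported_outbound_strategies=[])
+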
47
+ def jinja_filter(func):
48
+ """
49
+ This annotation designates a function as a jinja filter.
50
+ Adding it will put the function into the jinja globals so that it can be used in templates.
51
+ """
52
+ func.is_jinja_filter = True
53
+ return func
54
+
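+ # Illustrative sketch only: a callable marked with @jinja_filter becomes available in field
+ # mapping templates under its own name, e.g. {{ row['PHONE'] | digits_only }}. In a real plugin
+ # the filter is declared as a method on your OmnataPlugin subclass, since filters are discovered
+ # on the plugin instance; a standalone function is shown here only for brevity, and the filter
+ # name is hypothetical.
+ @jinja_filter
+ def digits_only(value) -> str:
+     """Strips everything except digits from a value."""
+     return ''.join(character for character in str(value) if character.isdigit())
+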
55
+ class SyncRequest(ABC):
56
+ """
57
+ Functionality common to inbound and outbound syncs requests.
58
+
59
+ Both inbound and outbound syncs have records to apply back to Snowflake (outbound syncs have load results, inbound syncs have fetched records).
60
+ So there's common functionality for feeding them in, as well as logging, other housekeeping tasks, and rate limiting.
61
+ """
62
+ def __init__(self,run_id:int,
63
+ session:Session,
64
+ source_app_name:str,
65
+ results_schema_name:str,
66
+ results_table_name:str,
67
+ plugin_instance:Type[OmnataPlugin],
68
+ api_limits:List[ApiLimits],
69
+ rate_limit_state:Dict[str,RateLimitState],
70
+ run_deadline:datetime.datetime,
71
+ development_mode:bool=False):
72
+ """
73
+ Constructs a SyncRequest.
74
+
75
+ :param int run_id: The ID number for the run, used to report back status to the engine
76
+ :param any session: The snowpark session object, only used internally
77
+ :param OmnataPlugin plugin_instance: The instance of the Omnata Plugin this request is for
78
+ :param ApiLimits api_limits: Constraints to observe when performing HTTP requests
79
+ :param bool development_mode: In development mode, apply_results_queue does not load into Snowflake, instead they are cached locally and can be retrieved via get_queued_results
80
+ :return: nothing
81
+ """
82
+ logger.info(f"Initiating SyncRequest for sync run {run_id}")
83
+ self._run_deadline = run_deadline
84
+ self.plugin_instance = plugin_instance
85
+ self._source_app_name = source_app_name
86
+ self._results_schema_name = results_schema_name
87
+ self._results_table_name = results_table_name
88
+ self._full_results_table_name = f"{source_app_name}.{results_schema_name}.{results_table_name}"
89
+ if self.plugin_instance is not None:
90
+ self.plugin_instance._sync_request = self
91
+ self._session:Session = session
92
+ self._run_id = run_id
93
+ self.api_limits = api_limits
94
+ self.rate_limit_state = rate_limit_state
95
+ # these deal with applying the results, not sure they belong here
96
+ self._apply_results_lock = threading.Lock()
97
+ # Snowflake connector appears to not be thread safe
98
+ # # File \"/var/task/snowflake/snowpark/table.py\", line 221, in _get_update_result\n
99
+ # return UpdateResult(int(rows[0][0]), int(rows[0][1]))\nIndexError: list index out of range"
100
+ self._snowflake_query_lock = threading.Lock()
101
+ self._loadbatch_id = 0
102
+ self._loadbatch_id_lock = threading.Lock()
103
+ self.development_mode = development_mode
104
+ # This is used internally by the testing framework, when we're loading records in a behave test
105
+ self._prebaked_record_state:pandas.DataFrame = None
106
+ # create a stop requestor to cease thread activity
107
+ self._thread_cancellation_token = threading.Event()
108
+ self._thread_exception_thrown = None
109
+ self._apply_results_task = None
110
+ self._cancel_checking_task = None
111
+ self._apply_results = None # initialised properly in subclasses
112
+ # create an exception handler for the threads
113
+ def thread_exception_hook(args):
114
+ logger.error('Thread exception',exc_info=True)
115
+ self._thread_cancellation_token.set() # this will tell the other threads to stop working
116
+ logger.info(f'thread_cancellation_token: {self._thread_cancellation_token.is_set()}')
117
+ #nonlocal thread_exception_thrown
118
+ self._thread_exception_thrown = args
119
+ threading.excepthook = thread_exception_hook
120
+ # start another worker thread to handle uploads of results every 10 seconds
121
+ # we don't join on this thread, instead we cancel it once the workers have finished
122
+ if self.development_mode is False:
123
+ if self._apply_results_task is None:
124
+ self._apply_results_task = threading.Thread(target=self.__apply_results_worker, args=(self._thread_cancellation_token,))
125
+ self._apply_results_task.start()
126
+ if self._cancel_checking_task is None:
127
+ self._cancel_checking_task = threading.Thread(target=self.__cancel_checking_worker, args=(self._thread_cancellation_token,))
128
+ self._cancel_checking_task.start()
129
+ # also spin up a thread to monitor for run cancellation
130
+
131
+ def __apply_results_worker(self,cancellation_token):
132
+ """
133
+ Designed to be run in a thread, this method polls the results every 10 seconds and sends them back to Snowflake.
134
+ """
135
+ while not cancellation_token.is_set():
136
+ logger.info("apply results worker checking for results")
137
+ self.apply_results_queue()
138
+ cancellation_token.wait(10)
139
+ logger.info("apply results worker exiting")
140
+
141
+ def __cancel_checking_worker(self,cancellation_token):
142
+ """
143
+ Designed to be run in a thread, this method checks to see if the sync run has been cancelled.
144
+ """
145
+ while not cancellation_token.is_set():
146
+ logger.info("cancel checking worked checking for results")
147
+ with self._snowflake_query_lock:
148
+ run_cancelled_results = self._session.table('DATA.SYNC_RUN') \
149
+ .where((col('SYNC_RUN_ID')==lit(self._run_id))) \
150
+ .select(not_(is_null(col('CANCELLED_DATETIME'))).alias('IS_CANCELLED')).collect()
151
+ if len(run_cancelled_results)==0:
152
+ raise ValueError(f"Sync run {self._run_id} did not exist when checking for cancellation")
153
+ if run_cancelled_results[0].IS_CANCELLED:
154
+ self.apply_cancellation()
155
+ cancellation_token.wait(10)
156
+ logger.info("cancel checking worker exiting")
157
+
158
+ @abstractmethod
159
+ def apply_results_queue(self):
160
+ """
161
+ Abstract method to apply the queued results. Inbound and Outbound syncs will each implement their own results
162
+ processing logic
163
+ """
164
+ logger.error('apply_results_queue called on SyncRequest base class, this should never occur')
165
+
166
+ @abstractmethod
167
+ def apply_cancellation(self):
168
+ """
169
+ Abstract method to handle run cancellation.
170
+ """
171
+
172
+ @abstractmethod
173
+ def apply_deadline_reached(self):
174
+ """
175
+ Abstract method to handle a run deadline being reached
176
+ """
177
+
178
+ def register_http_request(self, endpoint_category:str):
179
+ """
180
+ Registers a request as having just occurred, for rate limiting purposes.
181
+ You only need to use this if your HTTP requests are not automatically being
182
+ registered, which happens if http.client.HTTPConnection is not being used.
183
+ """
184
+ if endpoint_category in self.rate_limit_state:
185
+ self.rate_limit_state[endpoint_category].register_http_request()
186
+
187
+ def wait_for_rate_limiting(self, api_limit:ApiLimits) -> bool:
188
+ """
189
+ Waits for rate limits to pass before returning. Uses the api_limits and the history of
190
+ request timestamps to determine how long to wait.
191
+
192
+ :return: true if wait for rate limits was successful, otherwise false (thread was interrupted)
193
+ :raises: DeadlineReachedException if rate limiting is going to require us to wait past the run deadline
194
+ """
195
+ if api_limit is None:
196
+ return True
197
+ wait_until = api_limit.calculate_wait(self.rate_limit_state[api_limit.endpoint_category])
198
+ if wait_until > self._run_deadline:
199
+ # if the rate limiting is going to require us to wait past the run deadline, we bail out now
200
+ raise DeadlineReachedException()
201
+ time_now = datetime.datetime.utcnow()
202
+ logger.info(f"calculated wait until date was {wait_until}, comparing to {time_now}")
203
+
204
+ while wait_until > time_now:
205
+ seconds_to_sleep = (wait_until - time_now).total_seconds()
206
+ if self._thread_cancellation_token.wait(seconds_to_sleep):
207
+ return False
208
+ wait_until = api_limit.calculate_wait(self.rate_limit_state[api_limit.endpoint_category])
209
+ time_now = datetime.datetime.utcnow()
210
+ return True
211
+
212
+ def wait(self,seconds:float) -> bool:
213
+ """
214
+ Waits for a given number of seconds, provided the current sync run isn't cancelled in the meantime.
215
+ Returns True if no cancellation occurred, otherwise False.
216
+ If False is returned, the plugin should exit immediately.
217
+ """
218
+ return not self._thread_cancellation_token.wait(seconds)
219
+
220
+ def update_activity(self,current_activity:str):
221
+ """
222
+ Provides an update to the user on what's happening inside the sync run. It should
223
+ be used before commencing a potential long-running phase, like polling and waiting or
224
+ calling an API (keep in mind, rate limiting may delay even a fast API).
225
+ Keep this to a very concise string, like 'Fetching records from API'.
226
+ Avoid lengthy diagnostic messages, anything like this should be logged the normal way.
227
+ """
228
+ logger.info(f"Activity update: {current_activity}")
229
+ with self._snowflake_query_lock:
230
+ try:
231
+ self._session.sql(f"call {self._source_app_name}.API.PLUGIN_ACTIVITY_UPDATE(:1,:2)",
232
+ [self._run_id,current_activity]).collect()
233
+ except Exception as e:
234
+ logger.error(f"Error updating activity: {e}")
235
+
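+ # Illustrative sketch only (hypothetical helper, not part of the runtime): a polling loop that
+ # reports progress and respects cancellation. check_job_status is a placeholder callable that a
+ # plugin author would supply.
+ def _example_poll_until_complete(sync_request: SyncRequest, check_job_status: Callable[[], bool]):
+     sync_request.update_activity('Waiting for job to complete')
+     while not check_job_status():
+         if not sync_request.wait(30):
+             # wait() returned False, meaning the run was cancelled; exit immediately
+             return
+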
236
+ class HttpRateLimiting():
237
+ """
238
+ A custom context manager which applies rate limiting automatically.
239
+ Not thread safe but shouldn't need to be, since it'll be used once spanning all HTTP activity
240
+ """
241
+ def __init__(self,sync_request: SyncRequest, parameters:SyncConfigurationParameters):
242
+ self.sync_request = sync_request
243
+ self.original_putrequest = None
244
+ self.parameters = parameters
245
+
246
+ def __enter__(self):
247
+ """
248
+ Used to manage the outbound http requests made by Omnata Plugins.
249
+ It does this by patching http.client.HTTPConnection.putrequest
250
+ """
251
+ self_outer = self
252
+ self.original_putrequest = http.client.HTTPConnection.putrequest
253
+ def new_putrequest(self, method:str, url:str, skip_host:bool=False, skip_accept_encoding:bool=False):
254
+ # first, we do any waiting that we need to do (possibly none)
255
+ matched_api_limit = ApiLimits.request_match(self_outer.sync_request.api_limits,method,url)
256
+ if matched_api_limit is not None:
257
+ if not self_outer.sync_request.wait_for_rate_limiting(matched_api_limit):
258
+ logger.info('Interrupted while waiting for rate limiting')
259
+ raise InterruptedWhileWaitingException()
260
+ # and also register this current request in its limit category
261
+ self_outer.sync_request.register_http_request(matched_api_limit.endpoint_category)
262
+ return self_outer.original_putrequest(self, method, url, skip_host, skip_accept_encoding)
263
+ http.client.HTTPConnection.putrequest = new_putrequest
264
+
265
+ def __exit__(self, exc_type, exc_value, traceback):
266
+ http.client.HTTPConnection.putrequest = self.original_putrequest
267
+
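+ # Illustrative sketch only: wrapping HTTP activity in the HttpRateLimiting context manager so
+ # that requests made via http.client (which urllib and requests use under the hood) are
+ # automatically throttled according to the plugin's ApiLimits. The URL is a placeholder.
+ def _example_rate_limited_fetch(sync_request: SyncRequest, parameters: SyncConfigurationParameters):
+     import urllib.request
+     with HttpRateLimiting(sync_request, parameters):
+         return urllib.request.urlopen('https://example.com/api/records').read()
+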
268
+ class OutboundSyncRequest(SyncRequest):
269
+ """
270
+ A request to sync data outbound (from Snowflake to an app)
271
+ """
272
+ def __init__(self,run_id:int,
273
+ session:Session,
274
+ source_app_name:str,
275
+ results_schema_name:str,
276
+ results_table_name:str,
277
+ plugin_instance:Type[OmnataPlugin],
278
+ api_limits:List[ApiLimits],
279
+ rate_limit_state:Dict[str,RateLimitState],
280
+ run_deadline:datetime.datetime,
281
+ development_mode:bool=False):
282
+ """
283
+ Constructs an OutboundSyncRequest.
284
+
285
+ :param int run_id: The ID number for the run, only used to report back on status
286
+ :param any session: The snowpark session object, only used internally
287
+ :param OmnataPlugin plugin_instance: The instance of the Omnata Plugin this request is for
288
+ :param ApiLimits api_limits: Constraints to observe when performing HTTP requests
289
+ :param bool development_mode: In development mode, apply_results_queue does not load into Snowflake, instead they are cached locally and can be retrieved via get_queued_results
290
+ :return: nothing
291
+ """
292
+ SyncRequest.__init__(self,run_id,session,source_app_name,results_schema_name,results_table_name,plugin_instance,api_limits,rate_limit_state,run_deadline,development_mode)
293
+ self._apply_results:List[pandas.DataFrame] = []
294
+
295
+ def _get_next_loadbatch_id(self):
296
+ with self._loadbatch_id_lock:
297
+ self._loadbatch_id = self._loadbatch_id + 1
298
+ return self._loadbatch_id
299
+
300
+ def apply_results_queue(self):
301
+ """
302
+ Merges all of the queued results and applies them
303
+ """
304
+ logger.info('OutboundSyncRequest apply_results_queue')
305
+ if self._apply_results is not None:
306
+ with self._apply_results_lock:
307
+ self._apply_results = [x for x in self._apply_results if x is not None and len(x) > 0] # remove any None/empty dataframes
308
+ if len(self._apply_results)>0:
309
+ logger.info(f"Applying {len(self._apply_results)} batches of queued results")
310
+ # upload all cached apply results
311
+ all_dfs = pandas.concat(self._apply_results)
312
+ logger.info(f"applying: {all_dfs}")
313
+ self._apply_results_dataframe(all_dfs)
314
+ self._apply_results.clear()
315
+ else:
316
+ logger.info("No queued results to apply")
317
+
318
+ def apply_cancellation(self):
319
+ """
320
+ Handles a cancellation of an outbound sync.
321
+ 1. Signals an interruption to the load process for the other threads
322
+ 2. Applies remaining queued results
323
+ 3. Marks remaining active records as delayed
324
+ """
325
+ # set the token so that the other threads stop
326
+ logger.info('Applying cancellation for OutboundSyncRequest')
327
+ self._thread_cancellation_token.set()
328
+ self.apply_results_queue()
329
+
330
+ def apply_deadline_reached(self):
331
+ """
332
+ Handles the reaching of a deadline for an outbound sync.
333
+ The behaviour is the same as for a cancellation, since the record state looks the same
334
+ """
335
+ logger.info('Apply deadline reached for OutboundSyncRequest')
336
+ self.apply_cancellation()
337
+
338
+ def enqueue_results(self,results:pandas.DataFrame):
339
+ """
340
+ Adds some results to the queue for applying asynchronously
341
+ """
342
+ logger.info(f"Enqueueing {len(results)} results for upload")
343
+ with self._apply_results_lock:
344
+ self._apply_results.append(self._preprocess_results_dataframe(results))
345
+
346
+ def get_queued_results(self):
347
+ """
348
+ Returns results queued during processing
349
+ """
350
+ if len(self._apply_results)==0:
351
+ raise ValueError('get_queued_results was called, but no results have been queued')
352
+ concat_results = pandas.concat(self._apply_results)
353
+ return concat_results
354
+
355
+ def _preprocess_results_dataframe(self,results_df:pandas.DataFrame):
356
+ """
357
+ Validates and pre-processes outbound sync results dataframe.
358
+ The result is a dataframe containing all (and only):
359
+ 'IDENTIFIER' string
360
+ 'APP_IDENTIFIER' string
361
+ 'APPLY_STATE' string
362
+ 'APPLY_STATE_DATETIME' datetime (UTC)
363
+ 'LOADBATCH_ID' int
364
+ 'RESULT' object
365
+ """
366
+ for required_column in ['IDENTIFIER','RESULT','SUCCESS']:
367
+ if required_column not in results_df.columns:
368
+ raise ValueError(f'{required_column} column was not included in results')
369
+ results_df.set_index('IDENTIFIER',inplace=True, drop=False)
370
+ append_time = datetime.datetime.utcnow()
371
+ results_df['APPLY_STATE_DATETIME'] = pandas.Timestamp(append_time).tz_localize('UTC')
372
+ if results_df is not None:
373
+ logger.info(f"Applying a queued results dataframe of {len(results_df)} records")
374
+ # change the success flag to an appropriate APPLY STATUS
375
+ results_df.loc[results_df['SUCCESS']==True,'APPLY_STATE'] = 'SUCCESS'
376
+ results_df.loc[results_df['SUCCESS']==False,'APPLY_STATE'] = 'DESTINATION_FAILURE'
377
+ results_df = results_df.drop('SUCCESS',axis=1)
378
+ # if results weren't added by enqueue_results, we'll add the status datetime column now
379
+ if 'APPLY_STATE_DATETIME' not in results_df.columns:
380
+ append_time = datetime.datetime.utcnow()
381
+ results_df['APPLY_STATE_DATETIME'] = pandas.Timestamp(append_time).tz_localize('UTC')
382
+ if 'APP_IDENTIFIER' not in results_df:
383
+ results_df['APP_IDENTIFIER'] = None
384
+ if 'LOADBATCH_ID' not in results_df:
385
+ results_df['LOADBATCH_ID'] = self._get_next_loadbatch_id()
386
+ # trim out the columns we don't need to return
387
+ return results_df[results_df.columns.intersection(['IDENTIFIER','APP_IDENTIFIER','APPLY_STATE','APPLY_STATE_DATETIME','LOADBATCH_ID','RESULT'])]
388
+
389
+ def _apply_results_dataframe(self,results_df:pandas.DataFrame):
390
+ """
391
+ Applies results for an outbound sync. This involves merging back onto the record state table
392
+ """
393
+ logger.info('applying results to table')
394
+ # use a random table name with a random string to avoid collisions
395
+ with self._snowflake_query_lock:
396
+ success, nchunks, nrows, _ = write_pandas(conn=self._session._conn._cursor.connection,
397
+ df=results_df,
398
+ table_name=self._results_table_name,
399
+ auto_create_table=False)
400
+
401
+ def __dataframe_wrapper(self,data_frame,render_jinja:bool=True):
402
+ """
403
+ Takes care of some common stuff we need to do for each dataframe for outbound syncs.
404
+ Parses the JSON in the transformed record column (Snowflake passes it as a string).
405
+ Also when the mapper is a jinja template, renders it.
406
+ """
407
+ if data_frame is None:
408
+ logger.info("Dataframe wrapper skipping pre-processing as dataframe is None")
409
+ return None
410
+ logger.info(f"Dataframe wrapper pre-processing {len(data_frame)} records: {data_frame}")
411
+ if len(data_frame) > 0:
412
+ try:
413
+ data_frame['TRANSFORMED_RECORD'] = data_frame['TRANSFORMED_RECORD'].apply(json.loads)
414
+ except TypeError as type_error:
415
+ logger.error('Error parsing transformed record output as JSON',exc_info=True)
416
+ if 'the JSON object must be str, bytes or bytearray, not NoneType' in str(type_error):
417
+ raise ValueError('null was returned from the record transformer, an object must always be returned') from type_error
418
+ if render_jinja and 'jinja_template' in data_frame.iloc[0]['TRANSFORMED_RECORD']:
419
+ logger.info("Rendering jinja template")
420
+ env = Environment()
421
+ # examine the plugin instance for jinja_filter decorated methods
422
+ if self.plugin_instance is not None:
423
+ for name in dir(self.plugin_instance):
424
+ member = getattr(self.plugin_instance,name)
425
+ if callable(member) and hasattr(member,'is_jinja_filter'):
426
+ logger.info(f"Adding jinja filter to environment: {name}")
427
+ env.filters[name] = member
428
+ def do_jinja_render(jinja_env,row_value):
429
+ logger.info(f"do_jinja_render: {row_value}")
430
+ jinja_template = jinja_env.from_string(row_value['jinja_template'])
431
+ try:
432
+ rendered_result = jinja_template.render({'row':row_value['source_record']})
433
+ logger.info(f"Individual jinja rendering result: {rendered_result}")
434
+ return rendered_result
435
+ except TypeError as type_error:
436
+ # re-throw as a template error so that we can handle it nicely
437
+ logger.error('Error during jinja render',exc_info=True)
438
+ raise jinja2.TemplateError(type_error)
439
+
440
+ # bit iffy about using apply since historically it's not guaranteed only-once, apparently tries to be clever with vectorizing
441
+ data_frame['TRANSFORMED_RECORD'] = data_frame.apply(lambda row: do_jinja_render(env,row['TRANSFORMED_RECORD']),axis=1)
442
+ # if it breaks things in future, switch to iterrows() and at[]
443
+ return data_frame
444
+
445
+ def get_records(self,sync_actions:List[OutboundSyncAction]=None,batched:bool=False,render_jinja:bool=True,
446
+ sort_column:str=None,sort_direction:SortDirectionType='desc') -> pandas.DataFrame | Iterable[pandas.DataFrame]:
447
+ """
448
+ Retrieves a dataframe of records to create, update or delete in the app.
449
+ :param List[OutboundSyncAction] sync_actions: Which sync actions to include (includes all standard actions by default)
450
+ :param bool batched: If set to true, requests an iterator for a batch of dataframes. This is needed if a large data size (multiple GBs or more) is expected, so that the whole dataset isn't held in memory at one time.
451
+ :param bool render_jinja: If set to true and a jinja template is used, renders it automatically.
452
+ :param str sort_column: Applies a sort order to the dataframe.
453
+ :param SortDirectionType sort_direction: The sort direction, 'asc' or 'desc'
454
+ :type SortDirectionType: Literal['asc','desc']
455
+ :return: A pandas dataframe if batched is False (the default), otherwise an iterator of pandas dataframes
456
+ :rtype: pandas.DataFrame or iterator
457
+ """
458
+ if sync_actions is None:
459
+ sync_actions = [action() for action in list(STANDARD_OUTBOUND_SYNC_ACTIONS.values())]
460
+ # ignore null sync actions
461
+ sync_actions = [s for s in sync_actions if s]
462
+ # only used by testing framework when running a behave test
463
+ if self._prebaked_record_state is not None:
464
+ logger.info('returning prebaked record state')
465
+ dataframe = self._prebaked_record_state[self._prebaked_record_state['SYNC_ACTION'].isin(sync_actions)] # pylint: disable=unsubscriptable-object
466
+ if len(dataframe)==0:
467
+ # no need to do the whole FixedSizeGenerator thing for 0 records
468
+ return self.__dataframe_wrapper(dataframe,render_jinja)
469
+ # these were recorded earlier, we should preserve batch size and concurrency from the original
470
+ #wrapped_df = FixedSizeGenerator(dataframe,self.api_limits.batch_size)
471
+ mapfunc = partial(self.__dataframe_wrapper, render_jinja=render_jinja)
472
+ return map(mapfunc, [dataframe])
473
+ #return self.__dataframe_wrapper(wrapped_df,render_jinja)
474
+ sync_action_names:List[str] = [action.action_name for action in sync_actions]
475
+ with self._snowflake_query_lock:
476
+ dataframe = self._session.table(self._full_results_table_name) \
477
+ .filter((col("SYNC_ACTION").in_(sync_action_names)) &
478
+ (col("APPLY_STATE") == lit('ACTIVE'))) \
479
+ .select(col('IDENTIFIER'),col('SYNC_ACTION'),col('TRANSFORMED_RECORD'))
480
+ # apply sorting
481
+ if sort_column is not None:
482
+ sort_col = col(sort_column)
483
+ sorted_col = sort_col.desc() if sort_direction=='desc' else sort_col.asc()
484
+ dataframe = dataframe.sort(sorted_col)
485
+ if batched:
486
+ # we use map to create an iterable wrapper around the pandas batches which are also iterable
487
+ # we use an intermediate partial to allow us to pass the extra parameter
488
+ mapfunc = partial(self.__dataframe_wrapper, render_jinja=render_jinja)
489
+ return map(mapfunc, dataframe.to_pandas_batches())
490
+ #return map(self.__dataframe_wrapper,dataframe.to_pandas_batches(),render_jinja)
491
+ return self.__dataframe_wrapper(dataframe.to_pandas(),render_jinja)
492
+
493
+
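+ # Illustrative sketch only (hypothetical helper, not part of the runtime): a manual outbound flow
+ # that fetches the active records in batches and reports an outcome per record. The IDENTIFIER,
+ # SUCCESS and RESULT columns shown here are the ones enqueue_results expects.
+ def _example_manual_outbound_flow(outbound_sync_request: OutboundSyncRequest):
+     for records_df in outbound_sync_request.get_records(batched=True):
+         if records_df is None or len(records_df) == 0:
+             continue
+         # a real plugin would send each TRANSFORMED_RECORD to the app's API here
+         results_df = records_df[['IDENTIFIER']].copy()
+         results_df['SUCCESS'] = True
+         results_df['RESULT'] = None  # normally the per-record API response payload
+         outbound_sync_request.enqueue_results(results_df)
+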
494
+ class InboundSyncRequest(SyncRequest):
495
+ """
496
+ Encapsulates a request to retrieve records from an application.
497
+ """
498
+ def __init__(self,run_id:int,
499
+ session:Session,
500
+ source_app_name:str,
501
+ results_schema_name:str,
502
+ results_table_name:str,
503
+ plugin_instance:Type[OmnataPlugin],
504
+ api_limits:List[ApiLimits],
505
+ rate_limit_state:Dict[str,RateLimitState],
506
+ run_deadline:datetime.datetime,
507
+ development_mode:bool = False,
508
+ streams:List[StoredStreamConfiguration] = None
509
+ ):
510
+ """
511
+ Constructs an InboundSyncRequest.
512
+
513
517
+ :param int run_id: The ID number for the run, only used internally
518
+ :param any session: The snowpark session object, only used internally
519
+ :param OmnataPlugin plugin_instance: The instance of the Omnata Plugin this request is for
520
+ :param ApiLimits api_limits: Constraints to observe when performing HTTP requests
521
+ :param bool development_mode: In development mode, apply_results_queue does not load into Snowflake, instead they are cached locally and can be retrieved via get_queued_results
522
+ :param StoredStreamConfiguration streams: The configuration for each stream to fetch
523
+ :return: nothing
524
+ """
525
+ SyncRequest.__init__(self,run_id,session,source_app_name,results_schema_name,results_table_name,plugin_instance,api_limits,rate_limit_state,run_deadline,development_mode)
526
+ self.streams = streams
527
+ self._streams_dict:Dict[str,StoredStreamConfiguration] = {s.stream_name:s for s in streams}
528
+ self._apply_results:Dict[str,List[pandas.DataFrame]] = {}
529
+ self._latest_states:Dict[str,any] = {}
530
+ self._temp_tables = {}
531
+ self._temp_table_lock = threading.Lock()
532
+ self._results_exist:Dict[str,bool] = {} # track whether or not results exist for stream
533
+ self._stream_record_counts:Dict[str,int] = {stream_name:0 for stream_name in self._streams_dict.keys()}
534
+ self._stream_change_counts:Dict[str,int] = {stream_name:0 for stream_name in self._streams_dict.keys()}
535
+ self._completed_streams:List[str] = []
536
+
537
+ def apply_results_queue(self):
538
+ """
539
+ Merges all of the queued results and applies them
540
+ """
541
+ logger.info('InboundSyncRequest apply_results_queue ')
542
+ if self._apply_results is not None:
543
+ with self._apply_results_lock:
544
+ for stream_name,stream_results in self._apply_results.items():
545
+ results = [x for x in stream_results if x is not None and len(x) > 0] # remove any None/empty dataframes
546
+ if len(results)>0:
547
+ logger.info(f"Applying {len(results)} batches of queued results")
548
+ # upload all cached apply results
549
+ all_dfs = pandas.concat(results)
550
+ logger.info(f"applying: {all_dfs}")
551
+ self._apply_results_dataframe(stream_name,all_dfs)
552
+ # add the count of this batch to the total for this stream
553
+ self._stream_record_counts[stream_name] = self._stream_record_counts[stream_name] + len(all_dfs)
554
+ # update the stream state object too
555
+ self._apply_latest_states()
556
+ self._apply_results[stream_name] = None
557
+ self._apply_results = {}
558
+ # update the inbound stream record counts, so we can see progress
559
+ # TODO: report back via function call
560
+ with self._snowflake_query_lock:
561
+ stream_record_counts = json.dumps(self._stream_record_counts)
562
+ completed_streams = json.dumps(self._completed_streams)
563
+ logger.info(f"Updating sync run {self._run_id} stream record counts: {stream_record_counts}, completed streams: {completed_streams}")
564
+ self._session.table('DATA.SYNC_RUN').update({
565
+ "INBOUND_STREAM_TOTAL_COUNTS": parse_json(lit(stream_record_counts)),
566
+ "INBOUND_COMPLETED_STREAMS": parse_json(lit(completed_streams))
567
+ },(col('SYNC_RUN_ID')==lit(self._run_id)))
568
+
569
+ def apply_cancellation(self):
570
+ """
571
+ Signals an interruption to the load process for the other threads.
572
+ Also updates the Sync Run to include which streams were cancelled.
573
+ """
574
+ # set the token so that the other threads stop
575
+ self._thread_cancellation_token.set()
576
+ # any stream which didn't complete at this point is considered cancelled
577
+ cancelled_streams = [stream.stream_name for stream in self.streams if stream.stream_name not in self._completed_streams]
578
+ with self._snowflake_query_lock:
579
+ self._session.table('DATA.SYNC_RUN').update({
580
+ "INBOUND_CANCELLED_STREAMS": parse_json(lit(json.dumps(cancelled_streams)))
581
+ },(col('SYNC_RUN_ID')==lit(self._run_id)))
582
+
583
+ def apply_deadline_reached(self):
584
+ """
585
+ Signals an interruption to the load process for the other threads.
586
+ Also updates the Sync Run to include which streams were abandoned.
587
+ """
588
+ # set the token so that the other threads stop
589
+ self._thread_cancellation_token.set()
590
+ # any stream which didn't complete at this point is considered abandoned
591
+ abandoned_streams = [stream.stream_name for stream in self.streams if stream.stream_name not in self._completed_streams]
592
+ with self._snowflake_query_lock:
593
+ self._session.table('DATA.SYNC_RUN').update({
594
+ "INBOUND_ABANDONED_STREAMS": parse_json(lit(json.dumps(abandoned_streams)))
595
+ },(col('SYNC_RUN_ID')==lit(self._run_id)))
596
+
597
+ def enqueue_results(self,stream_name:str,results:List[Dict],new_state:any):
598
+ """
599
+ Adds some results to the queue for applying asynchronously
600
+ """
601
+ # TODO: maybe also have a mechanism to apply immediately if the queued results are getting too large
602
+ logger.info(f"Enqueueing {len(results)} results for upload")
603
+ if stream_name is None or len(stream_name)==0:
604
+ raise ValueError("Stream name cannot be empty")
605
+ with self._apply_results_lock:
606
+ existing_results:List[pandas.DataFrame] = []
607
+ if stream_name in self._apply_results:
608
+ existing_results = self._apply_results[stream_name]
609
+ existing_results.append(self._preprocess_results_list(stream_name,results))
610
+ self._apply_results[stream_name] = existing_results
611
+ current_latest = self._latest_states or {}
612
+ self._latest_states = {**current_latest,**{stream_name:new_state}}
613
+
614
+ def mark_stream_complete(self,stream_name:str):
615
+ """
616
+ Marks a stream as completed, this is called automatically per stream when using @managed_inbound_processing.
617
+ If @managed_inbound_processing is not used, call this whenever a stream has finished receiving records.
618
+ """
619
+ self._completed_streams.append(stream_name)
620
+ # dedup just in case it's called twice
621
+ self._completed_streams = list(set(self._completed_streams))
622
+
623
+
624
+ def _enqueue_state(self,stream_name:str,new_state:any):
625
+ """
626
+ Enqueues some new stream state to be stored. This method should not be called directly,
627
+ instead you should store state using the new_state parameter in the enqueue_results
628
+ method to ensure it's applied along with the associated new records.
629
+ """
630
+ with self._apply_results_lock:
631
+ current_latest = self._latest_states or {}
632
+ self._latest_states = {**current_latest,**{stream_name:new_state}}
633
+
634
+
635
+ def get_queued_results(self,stream_name:str):
636
+ """
637
+ Returns results queued during processing
638
+ """
639
+ if stream_name not in self._apply_results or len(self._apply_results[stream_name])==0:
640
+ raise ValueError('get_queued_results was called, but no results have been queued')
641
+ concat_results = pandas.concat(self._apply_results[stream_name])
642
+ return concat_results
643
+
644
+ def _convert_by_json_schema(self,stream_name:str,data:Dict,json_schema:Dict) -> Dict:
645
+ """
646
+ Apply opportunistic normalization before loading into Snowflake
647
+ """
648
+ try:
649
+ datetime_properties = [k for k,v in json_schema['properties'].items() if 'format' in v and v['format']=='date-time']
650
+ for datetime_property in datetime_properties:
651
+ try:
652
+ if datetime_property in data:
653
+ data[datetime_property] = parse(data[datetime_property]).isoformat()
654
+ except Exception as exception2:
655
+ logger.debug(f"Failure to convert inbound data property {datetime_property} on stream {stream_name}: {str(exception2)}")
656
+ except Exception as exception:
657
+ logger.debug(f"Failure to convert inbound data: {str(exception)}")
658
+ return data
659
+
660
+ def _preprocess_results_list(self,stream_name:str,results:List[Dict]):
661
+ """
662
+ Creates a dataframe from the enqueued list, ready to upload.
663
+ The result is a dataframe containing all (and only):
664
+ 'APP_IDENTIFIER' string
665
+ 'STREAM_NAME' string
666
+ 'RETRIEVE_DATE' datetime (UTC)
667
+ 'RECORD_DATA' object
668
+ """
669
+ #for required_column in ['RECORD_DATA']:
670
+ # if required_column not in results_df.columns:
671
+ # raise ValueError(f'{required_column} column was not included in results')
672
+ if stream_name not in self._streams_dict:
673
+ raise ValueError(f"Cannot preprocess results for stream {stream_name} as its configuration doesn't exist")
674
+ logger.info(f"preprocessing for stream: {self._streams_dict[stream_name]}")
675
+ if len(results) > 0:
676
+ stream_obj:StreamConfiguration = self._streams_dict[stream_name].stream
677
+ results_df = pandas.DataFrame.from_dict([{'RECORD_DATA':self._convert_by_json_schema(stream_name,data,stream_obj.json_schema)} for data in results])
678
+ if self._streams_dict[stream_name].stream.source_defined_primary_key is not None:
679
+ primary_key_field = self._streams_dict[stream_name].stream.source_defined_primary_key
680
+ results_df['APP_IDENTIFIER'] = results_df['RECORD_DATA'].apply(lambda x: x[primary_key_field])
681
+ elif self._streams_dict[stream_name].primary_key_field is not None:
682
+ primary_key_field = self._streams_dict[stream_name].primary_key_field
683
+ results_df['APP_IDENTIFIER'] = results_df['RECORD_DATA'].apply(lambda x: x[primary_key_field])
684
+ else:
685
+ results_df['APP_IDENTIFIER'] = None
686
+ results_df['RETRIEVE_DATE'] = datetime.datetime.utcnow()
687
+ results_df['STREAM_NAME'] = stream_name
688
+ else:
689
+ results_df = pandas.DataFrame([],columns=['APP_IDENTIFIER','STREAM_NAME','RECORD_DATA','RETRIEVE_DATE'])
690
+ # trim out the columns we don't need to return
691
+ return results_df[results_df.columns.intersection(['APP_IDENTIFIER','STREAM_NAME','RECORD_DATA','RETRIEVE_DATE'])]
692
+
693
+ def _apply_results_dataframe(self,stream_name:str,results_df:pandas.DataFrame):
694
+ """
695
+ Applies results for an inbound sync. The results are staged into a temporary
696
+ table in Snowflake, so that we can make an atomic commit at the end.
697
+ """
698
+ if len(results_df) > 0:
699
+ with self._snowflake_query_lock:
700
+ logger.info(f"Applying {len(results_df)} results to {self._full_results_table_name}")
701
+ success, nchunks, nrows, _ = write_pandas(conn=self._session._conn._cursor.connection,
702
+ df=results_df,
703
+ table_name=self._full_results_table_name,
704
+ quote_identifiers=False, # already done in get_temp_table_name
705
+ #schema='INBOUND_RAW', # it seems to be ok to provide schema in the table name
706
+ auto_create_table=True,
707
+ table_type='transient')
708
+ if not success:
709
+ raise ValueError(f"Failed to write results to Snowflake table {self._full_results_table_name}")
710
+ # temp tables aren't allowed
711
+ #snowflake_df = self._session.create_dataframe(results_df)
712
+ #snowflake_df.write.save_as_table(table_name=temp_table,
713
+ # mode='append',
714
+ # column_order='index',
715
+ # #create_temp_table=True
716
+ # )
717
+ self._results_exist[stream_name] = True
718
+ else:
719
+ logger.info("Results dataframe is empty, not applying")
720
+
721
+ def _apply_latest_states(self):
722
+ """
723
+ Updates the SYNC table to have the latest stream states.
724
+ TODO: This should be done in concert with the results, revisit
725
+ """
726
+ with self._snowflake_query_lock:
727
+ try:
728
+ self._session.sql(f"call {self._source_app_name}.API.PLUGIN_STREAM_STATE_UPDATE(:1,PARSE_JSON($$:2$$))",
729
+ [self._run_id,json.dumps(self._latest_states)]).collect()
730
+ except Exception as e:
731
+ logger.error(f"Error updating activity: {e}")
732
+
733
+
734
+ class OAuthParameters(SubscriptableBaseModel):
735
+ """
736
+ Encapsulates a set of OAuth Parameters
737
+ """
738
+ scope:str
739
+ authorization_url:str
740
+ access_token_url:str
741
+ client_id:str
742
+ state:str=None
743
+ response_type:str='code'
744
+ access_type:str='offline'
745
+
746
+ class Config:
747
+ extra = 'allow' # OAuth can contain extra fields
748
+
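+ # Illustrative sketch only: the kind of object an oauth_parameters() implementation returns.
+ # The scope, URLs and client_id below are hypothetical placeholders.
+ def _example_oauth_parameters() -> OAuthParameters:
+     return OAuthParameters(scope='records.read records.write',
+                            authorization_url='https://example.com/oauth/authorize',
+                            access_token_url='https://example.com/oauth/token',
+                            client_id='my-client-id')
+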
749
+ class ConnectResponse(SubscriptableBaseModel):
750
+ """
751
+ Encapsulates the response to a connection request. This is used to pass back any additional
752
+ information that may be discovered during connection that's relevant to the plugin (e.g. Account Identifiers).
753
+ It can also specify any additional network addresses that are needed to connect to the app, which might not
754
+ have been known until the connection was made.
755
+ """
756
+ connection_parameters:dict=None
757
+ connection_secrets:dict=None
758
+ network_addresses:List[str]=None
759
+
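+ # Illustrative sketch only: a connect() implementation might return discovered account details
+ # along with any extra hostnames that must be reachable. All values below are hypothetical.
+ def _example_connect_response() -> ConnectResponse:
+     return ConnectResponse(connection_parameters={'account_identifier': 'acme-123'},
+                            network_addresses=['acme-123.api.example.com'])
+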
760
+ class OmnataPlugin(ABC):
761
+ """
762
+ Class which defines the contract for an Omnata Push Plugin
763
+ """
764
+ def __init__(self):
765
+ """
766
+ Plugin constructors must never take parameters
767
+ """
768
+ self._sync_request:SyncRequest = None
769
+
770
+ @abstractmethod
771
+ def get_manifest(self) -> PluginManifest:
772
+ """
773
+ Returns a manifest object to describe the plugin and its capabilities
774
+ """
775
+ raise NotImplementedError('Your plugin class must implement the get_manifest method')
776
+
777
+ @abstractmethod
778
+ def connection_form(self) -> List[ConnectionMethod]:
779
+ """
780
+ Returns a form definition so that user input can be collected, in order to connect to an app
781
+
782
+ :return A list of ConnectionMethods, each of which offer a way of authenticating to the app and describing what information must be captured
783
+ :rtype List[ConnectionMethod]
784
+ """
785
+ raise NotImplementedError('Your plugin class must implement the connection_form method')
786
+
787
+ def outbound_configuration_form(self,parameters:OutboundSyncConfigurationParameters) -> OutboundSyncConfigurationForm:
788
+ """
789
+ Returns a form definition so that user input can be collected. This function may be called repeatedly with new parameter values
790
+ when dependent fields are used
791
+
792
+ :param OutboundSyncConfigurationParameters parameters the parameters of the sync, configured so far.
793
+ :return An OutboundSyncConfigurationForm, which describes what information must be collected to configure the sync
794
+ :rtype OutboundSyncConfigurationForm
795
+ """
796
+ raise NotImplementedError('Your plugin class must implement the outbound_configuration_form method')
797
+
798
+ def inbound_configuration_form(self,parameters:InboundSyncConfigurationParameters) -> InboundSyncConfigurationForm:
799
+ """
800
+ Returns a form definition so that user input can be collected. This function may be called repeatedly with new parameter values
801
+ when dependent fields are used
802
+
803
+ :param InboundSyncConfigurationParameters parameters the parameters of the sync, configured so far.
804
+ :return An InboundSyncConfigurationForm, which describes what information must be collected to configure the sync
805
+ :rtype InboundSyncConfigurationForm
806
+ """
807
+ raise NotImplementedError('Your plugin class must implement the inbound_configuration_form method')
808
+
809
+ @abstractmethod
810
+ def connect(self,parameters: ConnectionConfigurationParameters) -> ConnectResponse:
811
+ """
812
+ Connects to an app, validating that the information provided by the user was correct.
813
+ For OAuth connection methods, this will be called after the OAuth flow has completed, so the
814
+ access token will be available in the parameters.
815
+
816
+ :param PluginConfigurationParameters parameters the parameters of the sync, as configured by the user
817
+ :return A ConnectResponse, which may provide further information about the app instance for storing
818
+ :rtype ConnectResponse
819
+ :raises ValueError: if issues were encountered during connection
820
+ """
821
+ raise NotImplementedError('Your plugin class must implement the connect method')
822
+
823
+ def oauth_parameters(self,parameters: ConnectionConfigurationParameters) -> OAuthParameters:
824
+ """
825
+ This function is called for any connection method where the "oauth" flag is set to true.
826
+ Connection Parameters are provided in case they are needed to construct the OAuth parameters.
827
+
828
+ :param PluginConfigurationParameters parameters the parameters of the sync, as configured by the user
829
+ :return An OAuthParameters object, which contains the information needed to commence an OAuth flow
830
+ :rtype OAuthParameters
831
+ """
832
+ raise NotImplementedError('Your plugin class must implement the oauth_parameters method')
833
+
834
+ def sync_outbound(self,parameters: OutboundSyncConfigurationParameters,outbound_sync_request:OutboundSyncRequest):
835
+ """
836
+ Applies a set of changed records to an app. This function is called whenever a run occurs and changed records
837
+ are found.
838
+ To return results, invoke outbound_sync_request.enqueue_results() during the load process.
839
+
840
+ :param PluginConfigurationParameters parameters the parameters of the sync, as configured by the user
841
+ :param OutboundSyncRequest outbound_sync_request an object describing what has changed
842
+ :return None
843
+ :raises ValueError: if issues were encountered during connection
844
+ """
845
+ raise NotImplementedError('Your plugin class must implement the sync_outbound method')
846
+
847
+ def sync_inbound(self,parameters: InboundSyncConfigurationParameters,inbound_sync_request:InboundSyncRequest):
848
+ """
849
+ Retrieves the next set of records from an application.
850
+ The inbound_sync_request contains the list of streams to be synchronized.
851
+ To return results, invoke inbound_sync_request.enqueue_results() during the load process.
852
+
853
+ :param PluginConfigurationParameters parameters the parameters of the sync, as configured by the user
854
+ :param InboundSyncRequest inbound_sync_request an object describing what needs to be sync'd
855
+ :return None
856
+ :raises ValueError: if issues were encountered during connection
857
+ """
858
+ raise NotImplementedError('Your plugin class must implement the sync_inbound method')
859
+
860
+ def api_limits(self,parameters: SyncConfigurationParameters) -> List[ApiLimits]:
861
+ """
862
+ Defines the API limits in place for the app's API
863
+ """
864
+ return []
865
+
866
+ def additional_loggers(self) -> List[str]:
867
+ """
868
+ Ordinarily, your plugin code will log to a logger named 'omnata_plugin' and these
869
+ messages will automatically be stored in Snowflake and associated with the current
870
+ sync run, so that they appear in the UI's logs.
871
+ However, if you leverage third party python libraries, it may be useful to capture
872
+ log messages from those as well. Overriding this method and returning the names of
873
+ any additional loggers will cause them to be captured as well.
874
+ For example, if the source code of a third party library includes:
875
+ logging.getLogger(name='our_api_wrapper'), then returning ['our_api_wrapper']
876
+ will capture its log messages.
877
+ The capture level of third party loggers will be whatever is configured for the sync.
878
+ """
879
+ return []
880
+
881
+ class FixedSizeGenerator():
882
+ """
883
+ A thread-safe class which wraps the pandas batches generator provided by Snowflake,
884
+ but provides batches of a fixed size.
885
+ """
886
+ def __init__(self,generator,batch_size):
887
+ self.generator = generator
888
+ # handle dataframe as well as a dataframe generator, just to be more flexible
889
+ if self.generator.__class__.__name__=='DataFrame':
890
+ logger.info(f"Wrapping a dataframe of length {len(self.generator)} in a map so it acts as a generator")
891
+ self.generator = map(lambda x:x,[self.generator])
892
+ self.leftovers = None
893
+ self.batch_size = batch_size
894
+ self.thread_lock = threading.Lock()
895
+ def __next__(self):
896
+ with self.thread_lock:
897
+ logger.info(f"initial leftovers: {self.leftovers}")
898
+ records_df = self.leftovers
899
+ self.leftovers = None
900
+ try:
901
+ # build up a dataframe until we reach the batch size
902
+ while records_df is None or len(records_df) < self.batch_size:
903
+ current_count = 0 if records_df is None else len(records_df)
904
+ logger.info(f"fetching another dataframe from the generator, got {current_count} out of a desired {self.batch_size}")
905
+ next_df = next(self.generator)
906
+ if next_df is not None and next_df.__class__.__name__ != 'DataFrame':
907
+ logger.error(f"Dataframe generator provided an unexpected object, type {next_df.__class__.__name__}")
908
+ raise ValueError(f"Dataframe generator provided an unexpected object, type {next_df.__class__.__name__}")
909
+ if next_df is None and records_df is None:
910
+ logger.info('Original and next dataframes were None, returning None')
911
+ return None
912
+ records_df = pandas.concat([records_df,next_df])
913
+ logger.info(f"after concatenation, dataframe has {len(records_df)} records")
914
+ except StopIteration:
915
+ logger.info('FixedSizeGenerator consumed the last pandas batch')
916
+
917
+ if records_df is None:
918
+ logger.info('No records left, returning None')
919
+ return None
920
+ elif records_df is not None and len(records_df) > self.batch_size:
921
+ logger.info(f'putting {len(records_df[self.batch_size:])} records back ({len(records_df)} > {self.batch_size})')
922
+ self.leftovers = records_df[self.batch_size:].reset_index(drop=True)
923
+ records_df = records_df[0:self.batch_size].reset_index(drop=True)
924
+ else:
925
+ current_count = 0 if records_df is None else len(records_df)
926
+ logger.info(f'{current_count} records does not exceed batch size, not putting any back')
927
+ logger.info(f"FixedSizeGenerator about to return dataframe {records_df}")
928
+ return records_df
929
+ def __iter__(self):
930
+ ''' Returns the Iterator object '''
931
+ return self
932
+
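+ # Illustrative sketch only: FixedSizeGenerator signals exhaustion by returning None from
+ # __next__ rather than raising StopIteration, so consume it with a while loop instead of a
+ # for loop. The processing step is left as a placeholder.
+ def _example_consume_fixed_size_batches(source_df: pandas.DataFrame, batch_size: int = 100):
+     generator = FixedSizeGenerator(source_df, batch_size=batch_size)
+     batch = next(generator)
+     while batch is not None and len(batch) > 0:
+         # process this batch of at most batch_size records here
+         batch = next(generator)
+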
933
+ def __managed_outbound_processing_worker(plugin_class_obj:OmnataPlugin, method:Callable, worker_index:int,
934
+ dataframe_generator:FixedSizeGenerator, cancellation_token:threading.Event, method_args, method_kwargs):
935
+ """
936
+ A worker thread for the managed_outbound_processing annotation.
937
+ Consumes a fixed sized set of records by passing them to the wrapped function,
938
+ while adhering to the defined API constraints.
939
+ """
940
+ while not cancellation_token.is_set():
941
+ # Get the next fixed-size batch from the shared generator
942
+ logger.info(f"worker {worker_index} processing. Cancelled: {cancellation_token.is_set()}")
943
+ if datetime.datetime.now() > plugin_class_obj._sync_request._run_deadline:
944
+ # if we've reached the deadline for the run, end it
945
+ plugin_class_obj._sync_request.apply_deadline_reached()
946
+ return
947
+ records_df = next(dataframe_generator)
948
+ logger.info(f"records returned from dataframe generator: {records_df}")
949
+ if records_df is None:
950
+ logger.info(f"worker {worker_index} has no records left to process")
951
+ return
952
+ elif len(records_df) == 0:
953
+ logger.info(f"worker {worker_index} has 0 records left to process")
954
+ return
955
+
956
+ logger.info(f"worker {worker_index} fetched {len(records_df)} records for processing")
957
+ # threads block while waiting for their allocation of records, it's possible there's been
958
+ # a cancellation in the meantime
959
+ if cancellation_token.is_set():
960
+ logger.info(f"worker {worker_index} exiting before applying records, due to cancellation")
961
+ return
962
+ logger.info(f'worker {worker_index} processing {len(records_df)} records')
963
+ # restore the first argument, was originally the dataframe/generator but now it's the appropriately sized dataframe
964
+ try:
965
+ results_df = method(plugin_class_obj, *(records_df,*method_args), **method_kwargs)
966
+ except InterruptedWhileWaitingException:
967
+ # If an outbound run is cancelled while waiting for rate limiting, this should mean that
968
+ # the cancellation is handled elsewhere, so we don't need to do anything special here other than stop waiting
969
+ return
970
+ logger.info(f'worker {worker_index} received {len(results_df)} results, applying')
971
+
972
+ # we want to write the results of the batch back to Snowflake, so we
973
+ # enqueue them and they'll be picked up by the apply_results worker
974
+ plugin_class_obj._sync_request.enqueue_results(results_df)
975
+ logger.info(f'worker {worker_index} applied results, marking queue task as done')
976
+
977
+ def managed_outbound_processing(concurrency:int, batch_size:int):
978
+ """
979
+ This is a decorator which can be added to a method on an OmnataPlugin class.
980
+ It expects to be invoked with either a DataFrame or a DataFrame generator, and
981
+ the method will receive a DataFrame of the correct size based on the batch_size parameter.
982
+
983
+ The decorator itself must be used as a function call with tuning parameters like so:
984
+ @managed_outbound_processing(concurrency=5, batch_size=100)
985
+ def my_function(param1,param2)
986
+
987
+ Threaded workers will be used to invoke in parallel, according to the concurrency constraints.
988
+
989
+ The decorated method is expected to return a DataFrame with the outcome of each record that was provided.
990
+ """
991
+ def actual_decorator(method):
992
+ @wraps(method)
993
+ def _impl(self, *method_args, **method_kwargs):
994
+ logger.info(f"method_args: {method_args}")
995
+ logger.info(f"method_kwargs: {method_kwargs}")
996
+ if self._sync_request is None:
997
+ raise ValueError('To use the managed_outbound_processing decorator, you must attach an apply request to the plugin instance (via the outbound_sync_request property)')
998
+ #if self._sync_request.api_limits is None:
999
+ # raise ValueError('To use the managed_outbound_processing decorator, API constraints must be defined. These can be provided in the response to the connect method')
1000
+ logger.info(f"Batch size: {batch_size}. Concurrency: {concurrency}")
1001
+ if len(method_args)==0:
1002
+ raise ValueError('You must provide at least one method argument, and the first argument must be a DataFrame or DataFrame generator (from outbound_sync_request.get_records)')
1003
+ first_arg = method_args[0]
1004
+ logger.info(first_arg.__class__.__name__)
1005
+ if first_arg.__class__.__name__ == 'DataFrame':
1006
+ logger.info('managed_outbound_processing received a DataFrame')
1007
+ elif hasattr(first_arg,'__next__'):
1008
+ logger.info('managed_outbound_processing received an iterator function')
1009
+ else:
1010
+ raise ValueError(f'The first argument to a @managed_outbound_processing method must be a DataFrame or DataFrame generator (from outbound_sync_request.get_records). Instead, a {first_arg.__class__.__name__} was provided.')
1011
+
1012
+ # put the record iterator on the queue, ready for the first task to read it
1013
+ fixed_size_generator = FixedSizeGenerator(first_arg,batch_size = batch_size)
1014
+ tasks = []
1015
+ logger.info(f"Creating {concurrency} worker(s) for applying records")
1016
+ for i in range(concurrency):
1017
+ # the dataframe/generator was put on the queue, so we remove it from the method args
1018
+ task = threading.Thread(target=__managed_outbound_processing_worker, args=(self,method,i,
1019
+ fixed_size_generator,self._sync_request._thread_cancellation_token,
1020
+ method_args[1:],method_kwargs))
1021
+ tasks.append(task)
1022
+ task.start()
1023
+
1024
+ # wait for workers to finish
1025
+ for task in tasks:
1026
+ task.join()
1027
+ logger.info("Task joined")
1028
+ logger.info("All workers completed processing")
1029
+
1030
+ # it's possible that some records weren't applied, since they are processed asynchronously on a timer
1031
+ if self._sync_request.development_mode is False:
1032
+ self._sync_request.apply_results_queue()
1033
+
1034
+ self._sync_request._thread_cancellation_token.set()
1035
+ # the thread cancellation should be detected by the apply results tasks, so it finishes gracefully
1036
+ if self._sync_request.development_mode is False and self._sync_request._apply_results_task is not None:
1037
+ self._sync_request._apply_results_task.join()
1038
+ logger.info("Checking for thread exception")
1039
+ if self._sync_request._thread_exception_thrown:
1040
+ raise self._sync_request._thread_exception_thrown.exc_value
1041
+
1042
+ logger.info("Main managed_outbound_processing thread completing")
1043
+ return
1044
+ return _impl
1045
+ return actual_decorator
1046
+
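+ # Illustrative sketch only: a hypothetical plugin wiring up @managed_outbound_processing.
+ # sync_outbound passes the whole dataframe/generator as the first argument; the decorator then
+ # slices it into batches of batch_size, invokes the method from up to `concurrency` worker
+ # threads, and enqueues the returned outcome dataframes automatically.
+ class _ExampleOutboundPlugin(OmnataPlugin):
+     def sync_outbound(self, parameters: OutboundSyncConfigurationParameters,
+                       outbound_sync_request: OutboundSyncRequest):
+         self.upload_batch(outbound_sync_request.get_records(batched=True))
+
+     @managed_outbound_processing(concurrency=5, batch_size=100)
+     def upload_batch(self, records_df: pandas.DataFrame) -> pandas.DataFrame:
+         # a real plugin would send each TRANSFORMED_RECORD to the app's API here
+         results_df = records_df[['IDENTIFIER']].copy()
+         results_df['SUCCESS'] = True
+         results_df['RESULT'] = None  # normally the per-record API response payload
+         return results_df
+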
1047
+ def __managed_inbound_processing_worker(plugin_class_obj:Type[OmnataPlugin], method:Callable, worker_index:int,
1048
+ streams_queue:queue.Queue, cancellation_token:threading.Event, method_args, method_kwargs):
1049
+ """
1050
+ A worker thread for the managed_inbound_processing annotation.
1051
+ Takes streams off the queue one at a time and passes them to the wrapped function,
1052
+ while adhering to the defined API constraints.
1053
+ """
1054
+ while not cancellation_token.is_set():
1055
+ # Get the next stream configuration out of the queue
1056
+ logger.info(f"worker {worker_index} processing. Cancelled: {cancellation_token.is_set()}")
1057
+ if datetime.datetime.now() > plugin_class_obj._sync_request._run_deadline:
1058
+ # if we've reached the deadline for the run, end it
1059
+ plugin_class_obj._sync_request.apply_deadline_reached()
1060
+ return
1061
+ try:
1062
+ stream:StoredStreamConfiguration = streams_queue.get_nowait()
1063
+ logger.info(f"stream returned from queue: {stream}")
1064
+ # restore the first argument, was originally the dataframe/generator but now it's the appropriately sized dataframe
1065
+ try:
1066
+ method(plugin_class_obj, *(stream,*method_args), **method_kwargs)
1067
+ plugin_class_obj._sync_request.mark_stream_complete(stream.stream_name)
1068
+ except InterruptedWhileWaitingException:
1069
+ # If an inbound run is cancelled while waiting for rate limiting, this should mean that
1070
+ # the cancellation is handled elsewhere, so we don't need to do anything special here other than stop waiting
1071
+ return
1072
+ except queue.Empty:
1073
+ logger.info("streams queue is empty")
1074
+ return
1075
+
1076
+ def managed_inbound_processing(concurrency:int):
1077
+ """
1078
+ This is a decorator which can be added to a method on an OmnataPlugin class.
1079
+ It expects to be invoked with a list of StoredStreamConfiguration objects as the
1080
+ first parameter.
1081
+ The method will receive a single StoredStreamConfiguration object at a time as its
1082
+ first parameter, and is expected to publish its results via
1083
+ inbound_sync_request.enqueue_results() during the load process.
1084
+
1085
+ The decorator itself must be used as a function call with a tuning parameter like so:
1086
+ @managed_inbound_processing(concurrency=5)
1087
+ def my_function(param1,param2)
1088
+
1089
+ Based on the concurrency constraints, it will create threaded workers to retrieve
1090
+ the streams in parallel.
1091
+ """
1092
+ def actual_decorator(method):
1093
+ @wraps(method)
1094
+ def _impl(self, *method_args, **method_kwargs):
1095
+ logger.info(f"method_args: {method_args}")
1096
+ logger.info(f"method_kwargs: {method_kwargs}")
1097
+ if self._sync_request is None:
1098
+ raise ValueError('To use the managed_inbound_processing decorator, you must attach an apply request to the plugin instance (via the inbound_sync_request property)')
1099
+ #if self._sync_request.api_limits is None:
1100
+ # raise ValueError('To use the managed_inbound_processing decorator, API constraints must be defined. These can be provided in the response to the connect method')
1101
+ if len(method_args)==0:
1102
+ raise ValueError('You must provide at least one method argument, and the first argument must be a list of StoredStreamConfiguration objects (from inbound_sync_request.streams)')
1103
+ first_arg:List[StoredStreamConfiguration] = method_args[0]
1104
+ logger.info(first_arg.__class__.__name__)
1105
+ if first_arg.__class__.__name__ == 'list':
1106
+ logger.info('managed_inbound_processing received a list')
1107
+ else:
1108
+ raise ValueError(f'The first argument to a @managed_inbound_processing method must be a list of StoredStreamConfigurations. Instead, a {first_arg.__class__.__name__} was provided.')
1109
+
1110
+ streams_list:List[StoredStreamConfiguration] = first_arg
1111
+ # create a queue full of all the streams to process
1112
+ streams_queue = queue.Queue()
1113
+ for stream in streams_list:
1114
+ streams_queue.put(stream)
1115
+
1116
+ tasks = []
1117
+ logger.info(f"Creating {concurrency} worker(s) for applying records")
1118
+
1119
+ for i in range(concurrency):
1120
+ # the dataframe/generator was put on the queue, so we remove it from the method args
1121
+ task = threading.Thread(target=__managed_inbound_processing_worker, args=(self,method,i,
1122
+ streams_queue,self._sync_request._thread_cancellation_token,
1123
+ method_args[1:],method_kwargs))
1124
+ tasks.append(task)
1125
+ task.start()
1126
+
1127
+ # wait for workers to finish
1128
+ for task in tasks:
1129
+ task.join()
1130
+ logger.info("Task joined")
1131
+ logger.info("All workers completed processing")
1132
+
1133
+ # it's possible that some records weren't applied, since they are processed asynchronously on a timer
1134
+ if self._sync_request.development_mode is False:
1135
+ self._sync_request.apply_results_queue()
1136
+
1137
+ self._sync_request._thread_cancellation_token.set()
1138
+ # the thread cancellation should be detected by the apply results tasks, so it finishes gracefully
1139
+ if self._sync_request.development_mode is False and self._sync_request._apply_results_task is not None:
1140
+ self._sync_request._apply_results_task.join()
1141
+ logger.info("Checking for thread exception")
1142
+ if self._sync_request._thread_exception_thrown:
1143
+ raise self._sync_request._thread_exception_thrown.exc_value
1144
+
1145
+ logger.info("Main managed_inbound_processing thread completing")
1146
+ return
1147
+ return _impl
1148
+ return actual_decorator
1149
+
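+ # Illustrative sketch only: a hypothetical plugin wiring up @managed_inbound_processing.
+ # sync_inbound passes the full list of streams as the first argument (any further arguments are
+ # forwarded unchanged); the decorator then invokes the method with one StoredStreamConfiguration
+ # at a time from up to `concurrency` worker threads, and marks each stream complete when the
+ # method returns.
+ class _ExampleInboundPlugin(OmnataPlugin):
+     def sync_inbound(self, parameters: InboundSyncConfigurationParameters,
+                      inbound_sync_request: InboundSyncRequest):
+         self.fetch_stream(inbound_sync_request.streams, inbound_sync_request)
+
+     @managed_inbound_processing(concurrency=5)
+     def fetch_stream(self, stream: StoredStreamConfiguration, inbound_sync_request: InboundSyncRequest):
+         # a real plugin would fetch records for stream.stream_name from the app's API here
+         records: List[Dict] = []
+         inbound_sync_request.enqueue_results(stream.stream_name, records, new_state=None)
+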
1150
+ class DeadlineReachedException(Exception):
1151
+ """
1152
+ Indicates that a sync needed to be abandoned due to reaching a deadline, or needing to wait past a future
1153
+ deadline.
1154
+ """