omnata-plugin-runtime 0.10.33a297__tar.gz → 0.11.0__tar.gz

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -1,6 +1,6 @@
 Metadata-Version: 2.3
 Name: omnata-plugin-runtime
-Version: 0.10.33a297
+Version: 0.11.0
 Summary: Classes and common runtime components for building and running Omnata Plugins
 Author: James Weakley
 Author-email: james.weakley@omnata.com
@@ -1,6 +1,6 @@
 [tool.poetry]
 name = "omnata-plugin-runtime"
-version = "0.10.33-a297"
+version = "0.11.0"
 description = "Classes and common runtime components for building and running Omnata Plugins"
 authors = ["James Weakley <james.weakley@omnata.com>"]
 readme = "README.md"
@@ -405,6 +405,20 @@ class FullyQualifiedTable(BaseModel):
         return self.get_fully_qualified_name(
             table_override=f"{self.table_name}_CRITERIA_DELETES"
         )
+
+    def get_fully_qualified_state_register_table_name(self) -> str:
+        """
+        Returns the fully qualified name of the state register table.
+        This is used to store state values for syncs, paired with query IDs to use with time travel.
+        """
+        return self.get_fully_qualified_name(table_override=f"{self.table_name}_STATE_REGISTER")
+
+    def get_fully_qualified_state_register_table_sequence_name(self) -> str:
+        """
+        Returns the fully qualified name of the state register table.
+        This is used to store state values for syncs, paired with query IDs to use with time travel.
+        """
+        return self.get_fully_qualified_name(table_override=f"{self.table_name}_STATE_REGISTER_SEQ")
 
 class SnowflakeViewPart(BaseModel):
     """
@@ -456,10 +470,20 @@ class SnowflakeViewPart(BaseModel):
             c.name_with_comment(binding_list) for c in self.columns
         ]
 
-    def cte_text(self,original_name: bool = False, include_only_columns:Optional[List[str]] = None) -> str:
+    def cte_text(self,original_name: bool = False,
+                 include_only_columns:Optional[List[str]] = None,
+                 include_extra_columns:Optional[List[str]] = None
+                 ) -> str:
         """
         Returns the CTE text for this view part.
         """
+        if include_extra_columns is not None:
+            # includes direct columns plus any extra specified
+            return f""" "{self.stream_name}" as (
+            select {', '.join([c.definition(original_name=original_name,remove_stream_prefix=self.stream_name) for c in self.columns
+                               if c.original_name in include_extra_columns or not c.is_join_column])}
+            from {self.raw_table_location.get_fully_qualified_name()}
+            ) """
         if include_only_columns is None:
             return f""" "{self.stream_name}" as (
             select {', '.join([c.definition(original_name=original_name,remove_stream_prefix=self.stream_name) for c in self.direct_columns()])}
@@ -490,6 +514,29 @@ class SnowflakeViewParts(BaseModel):
         ..., description="The other streams that are joined to the main stream"
     )
 
+    def column_indirectly_references_other_streams(
+            self,
+            all_view_parts:List[SnowflakeViewPart],
+            stream_name:str,column_name:str) -> bool:
+
+        for part in all_view_parts:
+            if part.stream_name == stream_name:
+                for col in part.columns:
+                    if col.original_name == column_name:
+                        if col.referenced_columns:
+                            for ref_stream, ref_cols in col.referenced_columns.items():
+                                if ref_stream != stream_name:
+                                    return True
+                                else:
+                                    # we have to call this recursively in case the referenced column also references other streams
+                                    result = any(
+                                        self.column_indirectly_references_other_streams(
+                                            all_view_parts, ref_stream, ref_col
+                                        ) for ref_col in ref_cols
+                                    )
+                                    return result
+        return False
+
     def view_body(self):
         """
         Creates a view definition from the parts.
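
The new column_indirectly_references_other_streams method decides whether a column can be computed entirely within its own stream by following referenced_columns recursively. A standalone sketch of the same traversal over a toy dependency map (hypothetical data; the real method walks SnowflakeViewPart/SnowflakeViewColumn objects):

    from typing import Dict, List, Tuple

    # (stream, column) -> {referenced stream: [referenced columns]}
    refs: Dict[Tuple[str, str], Dict[str, List[str]]] = {
        ("orders", "total_with_tax"): {"orders": ["total"]},    # references only its own stream
        ("orders", "total"): {},                                 # no references at all
        ("orders", "customer_label"): {"customers": ["name"]},   # crosses into another stream
    }

    def references_other_streams(stream: str, column: str) -> bool:
        for ref_stream, ref_cols in refs.get((stream, column), {}).items():
            if ref_stream != stream:
                return True
            # same stream: recurse in case a referenced column itself crosses streams
            if any(references_other_streams(ref_stream, c) for c in ref_cols):
                return True
        return False

    print(references_other_streams("orders", "total_with_tax"))  # False -> safe for the stream's own CTE
    print(references_other_streams("orders", "customer_label"))  # True  -> needs a join
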
@@ -505,31 +552,40 @@ class SnowflakeViewParts(BaseModel):
 
         # first, we need to collapse all referenced columns into a single map
         all_referenced_columns:Dict[str,List[str]] = {}
+
+        # if a column references other columns, but there are no dependencies outside of its own stream, we can include those columns in the initial CTE for that stream
+        # because they can be calculated directly without needing joins
+        columns_only_referencing_own_stream:Dict[str,List[str]] = {}
+
+
         for part in [self.main_part] + self.joined_parts:
-            # if the main part references any columns in this part in its joins, we need to include those columns
+            # if the main part references any columns in this part in its joins, we need to include those columns because they are used in the join condition
             aliases_for_stream = [j.join_stream_alias for j in self.main_part.joins
                                   if j.join_stream_name == part.stream_name]
             columns_used_in_joins = [
                 j.left_column for j in self.main_part.joins if j.left_alias in aliases_for_stream
             ]
-            if part.stream_name not in all_referenced_columns:
-                all_referenced_columns[part.stream_name] = []
-            all_referenced_columns[part.stream_name] += columns_used_in_joins
+            all_referenced_columns.setdefault(part.stream_name, []).extend(columns_used_in_joins)
+            # now, for each column in the part, if it references columns in other streams, we need to include those columns
             for column in part.columns:
                 if column.referenced_columns:
                     for stream_name, referenced_columns in column.referenced_columns.items():
-                        if stream_name not in all_referenced_columns:
-                            all_referenced_columns[stream_name] = []
-                        all_referenced_columns[stream_name] += referenced_columns
+                        all_referenced_columns.setdefault(stream_name, []).extend(referenced_columns)
+                    # populate columns_only_referencing_own_stream by following the chain of references until we reach a column that references another stream or has no references
+                    if self.column_indirectly_references_other_streams(
+                        [self.main_part] + self.joined_parts, part.stream_name, column.original_name
+                    ) == False:
+                        columns_only_referencing_own_stream.setdefault(part.stream_name, []).append(column.original_name)
+                else:
+                    # if the column has no references, it can be included in the initial CTE for its own stream
+                    columns_only_referencing_own_stream.setdefault(part.stream_name, []).append(column.original_name)
+            # if this part has joins to other streams, we need to include the join columns
             for join in part.joins:
-                if join.join_stream_name not in all_referenced_columns:
-                    all_referenced_columns[join.join_stream_name] = []
-                all_referenced_columns[join.join_stream_name].append(join.join_stream_column)
-                all_referenced_columns[part.stream_name].append(join.left_column)
-
+                all_referenced_columns.setdefault(join.join_stream_name, []).append(join.join_stream_column)
+                all_referenced_columns.setdefault(part.stream_name, []).append(join.left_column)
 
         ctes = [
-            self.main_part.cte_text(original_name=True)
+            self.main_part.cte_text(original_name=True,include_extra_columns=columns_only_referencing_own_stream.get(self.main_part.stream_name))
         ] + [
             part.cte_text(original_name=True,include_only_columns=all_referenced_columns.get(part.stream_name))
             for part in joined_parts_deduped
@@ -539,9 +595,9 @@ class SnowflakeViewParts(BaseModel):
         final_cte = f""" OMNATA_FINAL_CTE as (
             select {', '.join(
                 [
-                    f'"{self.main_part.stream_name}"."{c.original_name}"' for c in self.main_part.direct_columns()
+                    f'"{self.main_part.stream_name}"."{c.original_name}"' for c in self.main_part.columns if not c.is_join_column or c.original_name in columns_only_referencing_own_stream.get(self.main_part.stream_name,[])
                 ]+[
-                    c.definition(original_name=True) for c in self.main_part.join_columns()
+                    c.definition(original_name=True) for c in self.main_part.columns if c.is_join_column and c.original_name not in columns_only_referencing_own_stream.get(self.main_part.stream_name,[])
                 ])}
             from "{self.main_part.stream_name}" """
         if len(self.main_part.joins) > 0:
@@ -747,6 +803,19 @@ class SnowflakeViewParts(BaseModel):
 
         return cls(main_part=main_stream_view_part, joined_parts=joined_parts)
 
+
+# Helper function to find a view part by stream name
+def find_part(view_part: SnowflakeViewPart, joined_parts: List[SnowflakeViewPart], stream_name: str) -> Optional[SnowflakeViewPart]:
+    if stream_name == view_part.stream_name:
+        return view_part
+    for part in joined_parts:
+        if part.stream_name == stream_name:
+            return part
+    for join in view_part.joins:
+        if join.join_stream_alias == stream_name:
+            return view_part
+    return None
+
 def prune(view_part: SnowflakeViewPart, joined_parts: List[SnowflakeViewPart]) -> bool:
     """
     Prunes columns from view parts that reference fields that don't exist in the referenced streams.
@@ -759,12 +828,6 @@ def prune(view_part: SnowflakeViewPart, joined_parts: List[SnowflakeViewPart]) -
     Raises ValueError if a cyclic dependency is detected.
     """
     columns_removed = False
-
-    # Helper function to find a view part by stream name
-    def find_part(stream_name: str) -> Optional[SnowflakeViewPart]:
-        if stream_name == view_part.stream_name:
-            return view_part
-        return next((p for p in joined_parts if p.stream_name == stream_name), None)
 
     # Helper function to check if a column should be kept or removed
     def should_keep_column(column: SnowflakeViewColumn, part: SnowflakeViewPart) -> bool:
@@ -779,7 +842,7 @@ def prune(view_part: SnowflakeViewPart, joined_parts: List[SnowflakeViewPart]) -
         # Check each referenced stream and its fields
         for ref_stream_name, ref_fields in column.referenced_columns.items():
             # Find the referenced part
-            ref_part = find_part(ref_stream_name)
+            ref_part = find_part(view_part, joined_parts,ref_stream_name)
 
             # If referenced stream doesn't exist, remove the column
             if ref_part is None:
@@ -93,6 +93,9 @@ from .rate_limiting import (
     RateLimitState,
     RateLimitedSession
 )
+from .json_schema import (
+    FullyQualifiedTable
+)
 
 SortDirectionType = Literal["asc", "desc"]
 
@@ -1055,7 +1058,6 @@ class InboundSyncRequest(SyncRequest):
 
         # These are similar to the results, but represent requests to delete records by some criteria
         self._apply_results_criteria_deletes: Dict[str, List[pandas.DataFrame]] = {}
-        self._latest_states: Dict[str, Any] = {}
         self._temp_tables = {}
         self._temp_table_lock = threading.Lock()
         self._results_exist: Dict[
@@ -1085,24 +1087,47 @@ class InboundSyncRequest(SyncRequest):
             sync_id=sync_id,
             branch_name=branch_name
         )
-        # named by convention, see SyncRunProcessor.enqueue
-        self._criteria_deletes_table_name = (
-            f"{self._source_app_name}.{self._results_schema_name}.{self._results_table_name}_CRITERIA_DELETES"
+        # The results table name is also used to derive several other table/stage names
+        results_table = FullyQualifiedTable(
+            database_name= self._source_app_name,
+            schema_name= self._results_schema_name,
+            table_name= self._results_table_name
         )
-        self._apply_results: Dict[str, List[pandas.DataFrame]] = {}
+        self._criteria_deletes_table_name = results_table.get_fully_qualified_criteria_deletes_table_name()
+        self.state_register_table_name = results_table.get_fully_qualified_state_register_table_name()
+        # this is keyed on stream name, each containing a list of dataframes and state updates mixed
+        self._apply_results: Dict[str, List[pandas.DataFrame | Dict]] = {}
 
     def apply_results_queue(self):
         """
-        Merges all of the queued results and applies them
+        Merges all of the queued results and applies them, including state updates.
         """
-        logger.debug("InboundSyncRequest apply_results_queue ")
+        logger.debug("InboundSyncRequest apply_results_queue")
         if self._apply_results is not None:
             with self._apply_results_lock:
                 results:List[pandas.DataFrame] = []
-                stream_names:List[str] = []
+                stream_states_for_upload:Dict[str, Dict[str, Any]] = {}
                 for stream_name, stream_results in self._apply_results.items():
+                    # the stream results contains an ordered sequence of dataframes and state updates (append only)
+                    # we only want to apply the dataframes up until the most recent state update
+                    # so first, we iterate backwards to find the last state update
+                    last_state_index = -1
+                    for i in range(len(stream_results) - 1, -1, -1):
+                        if isinstance(stream_results[i], dict):
+                            last_state_index = i
+                            stream_states_for_upload[stream_name] = stream_results[i]
+                            break
+                    # if there are no state updates, we can't do anything with this stream
+                    if last_state_index == -1:
+                        logger.debug(
+                            f"No state updates for stream {stream_name}, skipping"
+                        )
+                        continue
+                    assert isinstance(stream_states_for_upload[stream_name], dict), "Latest state must be a dictionary"
+                    # now we can take the dataframes up to the last state update
+                    dfs = stream_results[:last_state_index]
                     non_empty_dfs = [
-                        x for x in stream_results if x is not None and len(x) > 0
+                        x for x in dfs if x is not None and isinstance(x, pandas.DataFrame) and len(x) > 0
                     ]
                     # get the total length of all the dataframes
                     total_length = sum([len(x) for x in non_empty_dfs])
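
With this change, each entry in _apply_results is an append-only mix of DataFrames and state dictionaries, and only the DataFrames preceding the most recent state dictionary are uploaded in a given batch. A toy illustration of that backwards scan (hypothetical queue contents):

    import pandas

    # Hypothetical per-stream queue: DataFrames interleaved with state dicts.
    queue = [
        pandas.DataFrame({"id": [1, 2]}),
        pandas.DataFrame({"id": [3]}),
        {"cursor": "2024-01-01"},       # state checkpoint covering the two frames above
        pandas.DataFrame({"id": [4]}),  # arrived after the checkpoint; kept for the next batch
    ]

    # Scan backwards for the most recent state dict.
    last_state_index = next(
        (i for i in range(len(queue) - 1, -1, -1) if isinstance(queue[i], dict)), -1
    )
    if last_state_index != -1:
        to_upload = [x for x in queue[:last_state_index] if isinstance(x, pandas.DataFrame) and len(x) > 0]
        latest_state = queue[last_state_index]
        remainder = queue[last_state_index + 1:]
        print(len(to_upload), latest_state, len(remainder))  # 2 {'cursor': '2024-01-01'} 1
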
@@ -1110,22 +1135,28 @@ class InboundSyncRequest(SyncRequest):
                     self._stream_record_counts[
                         stream_name
                     ] = self._stream_record_counts[stream_name] + total_length
-                    results.extend(non_empty_dfs) # remove any None/empty dataframes
-                    stream_names.append(stream_name)
+                    results.extend(non_empty_dfs)
+                    # now remove everything up to the last state update
+                    # we do this so that we don't apply the same state update multiple times
+                    self._apply_results[stream_name] = stream_results[
+                        last_state_index + 1 :
+                    ] # keep everything after the last state update
                 if len(results) > 0:
                     logger.debug(
                         f"Applying {len(results)} batches of queued results"
                     )
                     # upload all cached apply results
                     all_dfs = pandas.concat(results)
-                    self._apply_results_dataframe(stream_names, all_dfs)
-                    # update the stream state object too
-                    self._apply_latest_states()
-                    for stream_name in stream_names:
-                        self._apply_results[stream_name] = None
-                    self._apply_results = {}
-
+                    query_id = self._apply_results_dataframe(list(stream_states_for_upload.keys()), all_dfs)
+                    # now that the results have been updated, we need to insert records into the state register table
+                    # we do this by inserting the latest state for each stream
+                    self._directly_insert_to_state_register(
+                        stream_states_for_upload, query_id=query_id
+                    )
+
         # also take care of uploading delete requests
+        # technically these should be managed along with the state, however there aren't any scenarios where checkpointing is done
+        # and deletes have an impact. This is because we only checkpoint in scenarios where the target table is empty first
         if hasattr(self,'_apply_results_criteria_deletes') and self._apply_results_criteria_deletes is not None:
             with self._apply_results_lock:
                 results:List[pandas.DataFrame] = []
@@ -1149,27 +1180,57 @@ class InboundSyncRequest(SyncRequest):
         # so we need to make sure all the results are applied first
         self.apply_progress_updates()
 
+    def _directly_insert_to_state_register(
+        self, stream_states_for_upload: Dict[str, Dict[str, Any]],
+        query_id: Optional[str] = None
+    ) -> str:
+        binding_values = []
+        select_clauses = []
+
+        with self._snowflake_query_lock:
+            if query_id is None:
+                query_id = self._get_query_id_for_now()
+            for stream_name, latest_state in stream_states_for_upload.items():
+                binding_values.extend([stream_name, query_id, json.dumps(latest_state)])
+                select_clauses.append(
+                    f"select ?, ?, PARSE_JSON(?)"
+                )
+            final_query = f"""INSERT INTO {self.state_register_table_name} (STREAM_NAME, QUERY_ID, STATE_VALUE)
+                {' union all '.join(select_clauses)}"""
+            self._session.sql(final_query, binding_values).collect()
+            streams_included = list(stream_states_for_upload.keys())
+            logger.debug(f"Inserted state for streams: {streams_included} with query ID {query_id}")
+
     def apply_progress_updates(self, ignore_errors:bool = True):
         """
         Sends a message to the plugin with the current progress of the sync run, if it has changed since last time.
         """
-        if self._apply_results is not None:
-            with self._apply_results_lock:
-                new_progress_update = PluginMessageStreamProgressUpdate(
-                    stream_total_counts=self._stream_record_counts,
-                    # records could have been marked as completed, but still have results to apply
-                    completed_streams=[s for s in self._completed_streams if s not in self._apply_results or self._apply_results[s] is None],
-                    stream_errors=self._omnata_log_handler.stream_global_errors,
-                    total_records_estimate=self._total_records_estimate
-                )
-                if self._last_stream_progress_update is None or new_progress_update != self._last_stream_progress_update:
-                    result = self._plugin_message(
-                        message=new_progress_update,
-                        ignore_errors=ignore_errors
+        with self._apply_results_lock:
+            new_progress_update = PluginMessageStreamProgressUpdate(
+                stream_total_counts=self._stream_record_counts,
+                # records could have been marked as completed, but still have results to apply
+                completed_streams=[s for s in self._completed_streams
+                                   if s not in self._apply_results
+                                   or self._apply_results[s] is None
+                                   or len(self._apply_results[s]) == 0],
+                stream_errors=self._omnata_log_handler.stream_global_errors,
+                total_records_estimate=self._total_records_estimate
             )
-                    if result is None:
-                        return False
-                    self._last_stream_progress_update = new_progress_update
+            if self._last_stream_progress_update is None or new_progress_update != self._last_stream_progress_update:
+                result = self._plugin_message(
+                    message=new_progress_update,
+                    ignore_errors=ignore_errors
+                )
+                if result is None:
+                    return False
+                self._last_stream_progress_update = new_progress_update
+            completed_streams_awaiting_results_upload = [
+                s for s in self._completed_streams if s in self._apply_results and self._apply_results[s] is not None
+            ]
+            if len(completed_streams_awaiting_results_upload) > 0:
+                logger.debug(
+                    f"Streams marked as completed but awaiting upload: {', '.join(completed_streams_awaiting_results_upload)}"
+                )
         return True
 
     def apply_cancellation(self):
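
For reference, _directly_insert_to_state_register assembles a single multi-row INSERT from bind parameters. A sketch of what gets generated for two hypothetical streams (the table and column names come from the diff; the values and query ID are illustrative):

    import json

    state_register_table_name = "MY_APP.SYNC_DATA.SYNC_RESULTS_STATE_REGISTER"  # hypothetical
    stream_states = {"orders": {"cursor": "2024-01-01"}, "customers": {"cursor": "2024-01-02"}}
    query_id = "01b2c3d4-0000-0000-0000-000000000000"  # hypothetical Snowflake query ID

    binding_values, select_clauses = [], []
    for stream_name, latest_state in stream_states.items():
        binding_values.extend([stream_name, query_id, json.dumps(latest_state)])
        select_clauses.append("select ?, ?, PARSE_JSON(?)")

    final_query = f"""INSERT INTO {state_register_table_name} (STREAM_NAME, QUERY_ID, STATE_VALUE)
        {' union all '.join(select_clauses)}"""
    print(final_query)      # two "select ?, ?, PARSE_JSON(?)" rows joined with "union all"
    print(binding_values)   # six bind values, three per stream
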
@@ -1224,9 +1285,9 @@ class InboundSyncRequest(SyncRequest):
             if stream_name in self._apply_results:
                 existing_results = self._apply_results[stream_name]
                 existing_results.append(self._preprocess_results_list(stream_name, results, is_delete))
+                if new_state is not None:
+                    existing_results.append(new_state) # append the new state at the end
                 self._apply_results[stream_name] = existing_results
-            current_latest = self._latest_states or {}
-            self._latest_states = {**current_latest, **{stream_name: new_state}}
             # if the total size of all the dataframes exceeds 200MB, apply the results immediately
             # we'll use df.memory_usage(index=True) for this
             if self.development_mode is False:
@@ -1236,7 +1297,7 @@ class InboundSyncRequest(SyncRequest):
                 # if the total exceeds 200MB, we apply the results immediately
                 all_df_lists:List[List[pandas.DataFrame]] = list(self._apply_results.values())
                 # flatten
-                all_dfs:List[pandas.DataFrame] = [x for sublist in all_df_lists for x in sublist]
+                all_dfs:List[pandas.DataFrame] = [x for sublist in all_df_lists for x in sublist if isinstance(x, pandas.DataFrame)]
                 combined_length = sum([len(x) for x in all_dfs])
                 # first, don't bother if the count is less than 10000, since it's unlikely to be even close
                 if combined_length > 10000:
@@ -1286,7 +1347,7 @@ class InboundSyncRequest(SyncRequest):
                 combined_length = sum([len(x) for x in all_dfs])
                 # first, don't both if the count is less than 10000, since it's unlikely to be even close
                 if combined_length > 10000:
-                    if sum([x.memory_usage(index=True).sum() for x in all_dfs]) > 200000000:
+                    if sum([x.memory_usage(index=True).sum() for x in all_dfs if isinstance(x, pandas.DataFrame)]) > 200000000:
                         logger.debug(f"Applying criteria deletes queue immediately due to combined dataframe size")
                         self.apply_results_queue()
 
@@ -1295,9 +1356,11 @@ class InboundSyncRequest(SyncRequest):
         Marks a stream as completed, this is called automatically per stream when using @managed_inbound_processing.
         If @managed_inbound_processing is not used, call this whenever a stream has finished recieving records.
         """
-        self._completed_streams.append(stream_name)
-        # dedup just in case it's called twice
-        self._completed_streams = list(set(self._completed_streams))
+        logger.debug(f"Marking stream {stream_name} as completed locally")
+        with self._apply_results_lock:
+            self._completed_streams.append(stream_name)
+            # dedup just in case it's called twice
+            self._completed_streams = list(set(self._completed_streams))
 
     def set_stream_record_count(self, stream_name: str, count: int):
         """
@@ -1321,9 +1384,42 @@ class InboundSyncRequest(SyncRequest):
         instead you should store state using the new_state parameter in the enqueue_results
         method to ensure it's applied along with the associated new records.
         """
+        self.enqueue_state(
+            stream_name=stream_name,
+            new_state=new_state,
+            query_id=None # query_id will be generated automatically if not provided
+        )
+
+    def enqueue_state(self, stream_name: str, new_state: Any, query_id: Optional[str] = None):
+        """
+        Enqueues some new stream state to be stored. This method should be called whenever the state of a stream changes.
+
+        If there have been records enqueued here for this stream, it is assumed that the state is related to those records.
+        In this case, the state will be applied after the records are applied.
+        If there are no records enqueued for this stream, the state will be applied immediately as it is assumed that the results
+        were directly inserted, and therefore we need to capture the current query ID before more results are inserted.
+        """
         with self._apply_results_lock:
-            current_latest = self._latest_states or {}
-            self._latest_states = {**current_latest, **{stream_name: new_state}}
+            if stream_name in self._apply_results:
+                if len(self._apply_results[stream_name]) > 0:
+                    self._apply_results[stream_name].append(new_state)
+                    return
+
+            self._directly_insert_to_state_register(
+                {
+                    stream_name: new_state
+                }, query_id=query_id
+            )
+
+
+    def _get_query_id_for_now(self):
+        """
+        Gets a Snowflake query ID right now. Note that this does not require a Snowflake lock, the caller
+        should ensure that this is called in a thread-safe manner.
+        """
+        job=self._session.sql("select 1").collect_nowait()
+        job.result()
+        return job.query_id
 
     def get_queued_results(self, stream_name: str):
         """
@@ -1337,7 +1433,7 @@ class InboundSyncRequest(SyncRequest):
                 "get_queued_results was called, but no results have been queued"
             )
         concat_results = pandas.concat(self._apply_results[stream_name])
-        return concat_results
+        return [c for c in concat_results if c is not None and isinstance(c, pandas.DataFrame) and len(c) > 0]
 
     def _convert_by_json_schema(
         self, stream_name: str, data: Dict, json_schema: Dict
@@ -1512,10 +1608,11 @@ class InboundSyncRequest(SyncRequest):
         hash_object = hashlib.sha256(key_string.encode())
         return hash_object.hexdigest()
 
-    def _apply_results_dataframe(self, stream_names: List[str], results_df: pandas.DataFrame):
+    def _apply_results_dataframe(self, stream_names: List[str], results_df: pandas.DataFrame) -> Optional[str]:
         """
         Applies results for an inbound sync. The results are staged into a temporary
         table in Snowflake, so that we can make an atomic commit at the end.
+        Returns a query ID that can be used for checkpointing after the copy into command has run.
         """
         if len(results_df) > 0:
             with self._snowflake_query_lock:
@@ -1538,6 +1635,7 @@ class InboundSyncRequest(SyncRequest):
                     raise ValueError(
                         f"Failed to write results to table {self._full_results_table_name}"
                     )
+                query_id = self._get_query_id_for_now()
                 logger.debug(
                     f"Wrote {nrows} rows and {nchunks} chunks to table {self._full_results_table_name}"
                 )
@@ -1550,19 +1648,10 @@ class InboundSyncRequest(SyncRequest):
                 # )
                 for stream_name in stream_names:
                     self._results_exist[stream_name] = True
+                return query_id
         else:
             logger.debug("Results dataframe is empty, not applying")
 
-    def _apply_latest_states(self):
-        """
-        Updates the SYNC table to have the latest stream states.
-        TODO: This should be done in concert with the results, revisit
-        """
-        if self._last_states_update is None or json.dumps(self._latest_states) != json.dumps(self._last_states_update):
-            self._last_states_update = json.loads(json.dumps(self._latest_states))
-            self._plugin_message(PluginMessageStreamState(stream_state=self._latest_states))
-
-
     def _apply_criteria_deletes_dataframe(self, results_df: pandas.DataFrame):
         """
         Applies results for an inbound sync. The results are staged into a temporary
@@ -250,7 +250,6 @@ class PluginEntrypoint:
             self._plugin_instance._configuration_parameters = parameters
 
             inbound_sync_request.update_activity("Invoking plugin")
-            logger.info(f"inbound sync request: {inbound_sync_request}")
             # plugin_instance._inbound_sync_request = outbound_sync_request
             with tracer.start_as_current_span("invoke_plugin"):
                 with HttpRateLimiting(inbound_sync_request, parameters):
@@ -283,6 +282,11 @@ class PluginEntrypoint:
             # token is set. We throw it here as an error since that's currently how it flows back to the engine with a DELAYED state
             raise DeadlineReachedException()
         finally:
+            # try to upload any remaining results
+            try:
+                inbound_sync_request.apply_results_queue()
+            except Exception as e:
+                logger.warning(f"Error uploading remaining results: {str(e)}", exc_info=True)
             # cancel the thread so we don't leave anything hanging around and cop a nasty error
             try:
                 inbound_sync_request._thread_cancellation_token.set() # pylint: disable=protected-access