omnata-plugin-runtime 0.11.0a302__tar.gz → 0.11.7a324__tar.gz

This diff shows the content of publicly available package versions released to one of the supported registries. It is provided for informational purposes only and reflects the changes between package versions as they appear in their respective public registries.

Note: this release of omnata-plugin-runtime has been flagged as potentially problematic.

@@ -1,12 +1,12 @@
1
- Metadata-Version: 2.3
1
+ Metadata-Version: 2.4
2
2
  Name: omnata-plugin-runtime
3
- Version: 0.11.0a302
3
+ Version: 0.11.7a324
4
4
  Summary: Classes and common runtime components for building and running Omnata Plugins
5
+ License-File: LICENSE
5
6
  Author: James Weakley
6
7
  Author-email: james.weakley@omnata.com
7
- Requires-Python: >=3.8,<=3.11
8
+ Requires-Python: >=3.9,<=3.11
8
9
  Classifier: Programming Language :: Python :: 3
9
- Classifier: Programming Language :: Python :: 3.8
10
10
  Classifier: Programming Language :: Python :: 3.9
11
11
  Classifier: Programming Language :: Python :: 3.10
12
12
  Classifier: Programming Language :: Python :: 3.11
@@ -1,13 +1,13 @@
1
1
  [tool.poetry]
2
2
  name = "omnata-plugin-runtime"
3
- version = "0.11.0-a302"
3
+ version = "0.11.7-a324"
4
4
  description = "Classes and common runtime components for building and running Omnata Plugins"
5
5
  authors = ["James Weakley <james.weakley@omnata.com>"]
6
6
  readme = "README.md"
7
7
  packages = [{include = "omnata_plugin_runtime", from = "src"}]
8
8
 
9
9
  [tool.poetry.dependencies]
10
- python = ">=3.8, <=3.11"
10
+ python = ">=3.9, <=3.11"
11
11
  snowflake-snowpark-python = ">=1.20.0,<=1.24.0" # latest version available on Snowflake Anaconda, but allow pinning to 1.20.0 for to_pandas_batches workaround
12
12
  snowflake-connector-python = "^3, <=3.12.0" # latest version available on Snowflake Anaconda
13
13
  cryptography = "<=43.0.0"
@@ -7,6 +7,7 @@ from typing import Any, Dict, Optional, Literal, List, Union, Tuple
7
7
  from typing_extensions import Self
8
8
  from pydantic import BaseModel, Field, model_validator, computed_field
9
9
  from jinja2 import Environment
10
+ from graphlib import TopologicalSorter
10
11
  from .logging import logger
11
12
 
12
13
  class JsonSchemaProperty(BaseModel):
@@ -271,57 +272,60 @@ class SnowflakeViewColumn(BaseModel):
271
272
  )
272
273
 
273
274
  @classmethod
274
- def order_by_reference(cls,current_stream_name:str,columns:List[Self]) -> List[Self]:
275
+ def order_by_reference(cls, current_stream_name: str, columns: List[Self]) -> List[Self]:
275
276
  """
276
- In some situations, column expressions may reference the alias of another column
277
- This is allowed in Snowflake, as long as the aliased column is defined before it's used in a later column
278
- So we need to sort the columns so that if the name of the column appears (in quotes) in the expression of another column, it is ordered first
277
+ Uses topological sorting to order columns so that if a column references another column,
278
+ the referenced column appears first in the list. This is required by Snowflake when
279
+ column expressions reference the alias of another column.
280
+
281
+ OMNATA_ system columns are always placed at the front of the result.
279
282
  """
280
283
  logger.debug(
281
284
  f"Ordering columns by reference for stream: {current_stream_name} ({len(columns)} columns)"
282
285
  )
283
- # Collect columns to be moved
284
- columns_to_move:List[Self] = []
285
- # Collect Omnata System columns and keep them at the front
286
- omnata_system_columns_start = []
287
- for column in columns[:]:
288
- if column.original_name.startswith("OMNATA_"):
289
- columns.remove(column)
290
- omnata_system_columns_start.append(column)
291
-
286
+
287
+ # Separate OMNATA system columns - they always go first
288
+ omnata_system_columns = []
289
+ regular_columns = []
292
290
  for column in columns:
293
- for other_column in columns:
294
- if column==other_column:
295
- continue
296
- if column.original_name in (other_column.referenced_columns or {}).get(current_stream_name,[]):
297
- if column not in columns_to_move:
298
- logger.debug(
299
- f"Column {column.original_name} references {other_column.original_name}, moving it to the front"
300
- )
301
- columns_to_move.append(column)
302
- # we need to do another pass just on columns_to_move, because they may reference each other
303
- # if any do, they go to the front, otherwise they are appended
304
- columns_to_move_final:List[Self] = []
305
- for column in columns_to_move:
306
- for other_column in columns_to_move:
307
- if column==other_column:
308
- continue
309
- if column.original_name in (other_column.referenced_columns or {}).get(current_stream_name,[]):
310
- if column not in columns_to_move_final:
311
- logger.debug(
312
- f"Second pass: Column {column.original_name} is referenced by {other_column.original_name}, moving it to the front"
313
- )
314
- columns_to_move_final.insert(0, column)
315
- continue
316
- if column not in columns_to_move_final:
317
- columns_to_move_final.append(column)
291
+ if column.original_name.startswith("OMNATA_"):
292
+ omnata_system_columns.append(column)
293
+ else:
294
+ regular_columns.append(column)
318
295
 
319
- # Move collected columns to the front
320
- columns_to_move_final.reverse()
321
- for column in columns_to_move_final:
322
- columns.remove(column)
323
- columns.insert(0, column)
324
- return omnata_system_columns_start + columns
296
+ # Build dependency graph: column_name -> list of columns it depends on
297
+ # (i.e., columns that must appear BEFORE it in the final order)
298
+ graph: Dict[str, List[str]] = {}
299
+ column_by_name: Dict[str, Self] = {}
300
+
301
+ for column in regular_columns:
302
+ column_by_name[column.original_name] = column
303
+ # Initialize with empty dependencies
304
+ graph[column.original_name] = []
305
+
306
+ # Add dependencies from referenced_columns
307
+ if column.referenced_columns:
308
+ referenced_in_current_stream = column.referenced_columns.get(current_stream_name, [])
309
+ for ref_col_name in referenced_in_current_stream:
310
+ # This column depends on ref_col_name, so ref_col_name must come first
311
+ graph[column.original_name].append(ref_col_name)
312
+ logger.debug(
313
+ f"Column {column.original_name} depends on {ref_col_name}"
314
+ )
315
+
316
+ # Use TopologicalSorter to sort the columns
317
+ try:
318
+ ts = TopologicalSorter(graph)
319
+ sorted_column_names = list(ts.static_order())
320
+ except ValueError as e:
321
+ # This would indicate a circular dependency
322
+ raise ValueError(f"Circular dependency detected in column references for stream {current_stream_name}: {e}")
323
+
324
+ # Reconstruct the column list in topological order
325
+ sorted_columns = [column_by_name[name] for name in sorted_column_names if name in column_by_name]
326
+
327
+ # Return OMNATA system columns first, followed by sorted regular columns
328
+ return omnata_system_columns + sorted_columns
325
329
 
326
330
 
327
331
  class SnowflakeViewJoin(BaseModel):
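For readers unfamiliar with graphlib, the rewritten ordering logic above reduces to a small standalone sketch. The column names below are invented for illustration; only the use of TopologicalSorter.static_order(), and the fact that CycleError is a subclass of ValueError (which is why the diff catches ValueError), carry over from the change.

    # Minimal sketch of dependency-ordered columns, using made-up column names.
    from graphlib import TopologicalSorter, CycleError

    # column -> columns it depends on (which must be defined earlier in the SELECT)
    graph = {
        "FULL_NAME": ["FIRST_NAME", "LAST_NAME"],  # built from the two plain columns
        "FIRST_NAME": [],
        "LAST_NAME": [],
        "GREETING": ["FULL_NAME"],                 # depends on a derived column
    }

    try:
        ordered = list(TopologicalSorter(graph).static_order())
    except CycleError as exc:  # CycleError subclasses ValueError
        raise ValueError(f"Circular dependency detected: {exc}")

    print(ordered)  # e.g. ['FIRST_NAME', 'LAST_NAME', 'FULL_NAME', 'GREETING']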
@@ -470,10 +474,20 @@ class SnowflakeViewPart(BaseModel):
470
474
  c.name_with_comment(binding_list) for c in self.columns
471
475
  ]
472
476
 
473
- def cte_text(self,original_name: bool = False, include_only_columns:Optional[List[str]] = None) -> str:
477
+ def cte_text(self,original_name: bool = False,
478
+ include_only_columns:Optional[List[str]] = None,
479
+ include_extra_columns:Optional[List[str]] = None
480
+ ) -> str:
474
481
  """
475
482
  Returns the CTE text for this view part.
476
483
  """
484
+ if include_extra_columns is not None:
485
+ # includes direct columns plus any extra specified
486
+ return f""" "{self.stream_name}" as (
487
+ select {', '.join([c.definition(original_name=original_name,remove_stream_prefix=self.stream_name) for c in self.columns
488
+ if c.original_name in include_extra_columns or not c.is_join_column])}
489
+ from {self.raw_table_location.get_fully_qualified_name()}
490
+ ) """
477
491
  if include_only_columns is None:
478
492
  return f""" "{self.stream_name}" as (
479
493
  select {', '.join([c.definition(original_name=original_name,remove_stream_prefix=self.stream_name) for c in self.direct_columns()])}
@@ -504,6 +518,30 @@ class SnowflakeViewParts(BaseModel):
504
518
  ..., description="The other streams that are joined to the main stream"
505
519
  )
506
520
 
521
+ def column_indirectly_references_other_streams(
522
+ self,
523
+ all_view_parts:List[SnowflakeViewPart],
524
+ stream_name:str,column_name:str) -> bool:
525
+
526
+ for part in all_view_parts:
527
+ if part.stream_name == stream_name:
528
+ for col in part.columns:
529
+ if col.original_name == column_name:
530
+ if col.referenced_columns:
531
+ for ref_stream, ref_cols in col.referenced_columns.items():
532
+ if ref_stream != stream_name:
533
+ return True
534
+ else:
535
+ # we have to call this recursively in case the referenced column also references other streams
536
+ result = any(
537
+ self.column_indirectly_references_other_streams(
538
+ all_view_parts, ref_stream, ref_col
539
+ ) for ref_col in ref_cols
540
+ )
541
+ if result:
542
+ return True
543
+ return False
544
+
507
545
  def view_body(self):
508
546
  """
509
547
  Creates a view definition from the parts.
@@ -519,31 +557,55 @@ class SnowflakeViewParts(BaseModel):
519
557
 
520
558
  # first, we need to collapse all referenced columns into a single map
521
559
  all_referenced_columns:Dict[str,List[str]] = {}
560
+
561
+ # if a column references other columns, but there are no dependencies outside of its own stream, we can include those columns in the initial CTE for that stream
562
+ # because they can be calculated directly without needing joins
563
+ columns_only_referencing_own_stream:Dict[str,List[str]] = {}
564
+
522
565
  for part in [self.main_part] + self.joined_parts:
523
- # if the main part references any columns in this part in its joins, we need to include those columns
566
+ # if the main part references any columns in this part in its joins, we need to include those columns because they are used in the join condition
524
567
  aliases_for_stream = [j.join_stream_alias for j in self.main_part.joins
525
568
  if j.join_stream_name == part.stream_name]
526
569
  columns_used_in_joins = [
527
570
  j.left_column for j in self.main_part.joins if j.left_alias in aliases_for_stream
528
571
  ]
529
- if part.stream_name not in all_referenced_columns:
530
- all_referenced_columns[part.stream_name] = []
531
- all_referenced_columns[part.stream_name] += columns_used_in_joins
572
+ all_referenced_columns.setdefault(part.stream_name, []).extend(columns_used_in_joins)
573
+ # now, for each column in the part, if it references columns in other streams, we need to include those columns
532
574
  for column in part.columns:
533
575
  if column.referenced_columns:
534
576
  for stream_name, referenced_columns in column.referenced_columns.items():
535
- if stream_name not in all_referenced_columns:
536
- all_referenced_columns[stream_name] = []
537
- all_referenced_columns[stream_name] += referenced_columns
577
+ aliases_for_referenced_stream = [j.join_stream_name for j in self.main_part.joins
578
+ if j.join_stream_alias == stream_name]
579
+ all_referenced_columns.setdefault(stream_name, []).extend(referenced_columns)
580
+ # the stream name could be an alias, so we need to check if it's one of the aliases for this part
581
+ for stream_name_for_alias in aliases_for_referenced_stream:
582
+ all_referenced_columns.setdefault(stream_name_for_alias, []).extend(referenced_columns)
583
+ # populate columns_only_referencing_own_stream by following the chain of references until we reach a column that references another stream or has no references
584
+ if self.column_indirectly_references_other_streams(
585
+ [self.main_part] + self.joined_parts, part.stream_name, column.original_name
586
+ ) == False:
587
+ columns_only_referencing_own_stream.setdefault(part.stream_name, []).append(column.original_name)
588
+ else:
589
+ # if the column has no references, it can be included in the initial CTE for its own stream
590
+ # but only if no columns in other streams reference it
591
+ referenced_by_other_columns = False
592
+ for other_column in part.columns:
593
+ if other_column==column:
594
+ continue
595
+ if other_column.referenced_columns:
596
+ for ref_stream, ref_cols in other_column.referenced_columns.items():
597
+ if ref_stream != part.stream_name and column.original_name in ref_cols:
598
+ referenced_by_other_columns = True
599
+ break
600
+ if not referenced_by_other_columns:
601
+ columns_only_referencing_own_stream.setdefault(part.stream_name, []).append(column.original_name)
602
+ # if this part has joins to other streams, we need to include the join columns
538
603
  for join in part.joins:
539
- if join.join_stream_name not in all_referenced_columns:
540
- all_referenced_columns[join.join_stream_name] = []
541
- all_referenced_columns[join.join_stream_name].append(join.join_stream_column)
542
- all_referenced_columns[part.stream_name].append(join.left_column)
543
-
544
-
604
+ all_referenced_columns.setdefault(join.join_stream_name, []).append(join.join_stream_column)
605
+ all_referenced_columns.setdefault(join.join_stream_alias, []).append(join.join_stream_column)
606
+ all_referenced_columns.setdefault(part.stream_name, []).append(join.left_column)
545
607
  ctes = [
546
- self.main_part.cte_text(original_name=True)
608
+ self.main_part.cte_text(original_name=True,include_extra_columns=columns_only_referencing_own_stream.get(self.main_part.stream_name))
547
609
  ] + [
548
610
  part.cte_text(original_name=True,include_only_columns=all_referenced_columns.get(part.stream_name))
549
611
  for part in joined_parts_deduped
@@ -553,9 +615,9 @@ class SnowflakeViewParts(BaseModel):
553
615
  final_cte = f""" OMNATA_FINAL_CTE as (
554
616
  select {', '.join(
555
617
  [
556
- f'"{self.main_part.stream_name}"."{c.original_name}"' for c in self.main_part.direct_columns()
618
+ f'"{self.main_part.stream_name}"."{c.original_name}"' for c in self.main_part.columns if not c.is_join_column or c.original_name in columns_only_referencing_own_stream.get(self.main_part.stream_name,[])
557
619
  ]+[
558
- c.definition(original_name=True) for c in self.main_part.join_columns()
620
+ c.definition(original_name=True) for c in self.main_part.columns if c.is_join_column and c.original_name not in columns_only_referencing_own_stream.get(self.main_part.stream_name,[])
559
621
  ])}
560
622
  from "{self.main_part.stream_name}" """
561
623
  if len(self.main_part.joins) > 0:
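Several of the changes in this hunk replace the repeated "if key not in dict: dict[key] = []" blocks with the setdefault idiom. A tiny self-contained example of the equivalence, with an invented stream and column names:

    # Both snippets accumulate referenced columns per stream; names are illustrative only.
    all_referenced_columns = {}

    # old style
    if "orders" not in all_referenced_columns:
        all_referenced_columns["orders"] = []
    all_referenced_columns["orders"] += ["ORDER_ID"]

    # new style used in the diff: setdefault returns the (possibly newly created) list in place
    all_referenced_columns.setdefault("orders", []).extend(["CUSTOMER_ID"])

    assert all_referenced_columns == {"orders": ["ORDER_ID", "CUSTOMER_ID"]}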
@@ -602,9 +664,12 @@ class SnowflakeViewParts(BaseModel):
602
664
  )
603
665
  joined_parts:List[SnowflakeViewPart] = []
604
666
  # remove the joins from the main part if they are not in the raw stream locations
667
+ original_join_count = len(main_stream_view_part.joins)
605
668
  main_stream_view_part.joins = [join for join in main_stream_view_part.joins
606
669
  if join.join_stream_name in raw_stream_locations
607
670
  and join.join_stream_name in stream_schemas]
671
+ if len(main_stream_view_part.joins) < original_join_count:
672
+ logger.debug(f"Removed {original_join_count - len(main_stream_view_part.joins)} joins from stream: {stream_name} due to missing raw stream locations or schemas")
608
673
 
609
674
  for join in main_stream_view_part.joins:
610
675
  logger.debug(f"Generating view parts for join stream: {join.join_stream_name}")
@@ -617,6 +682,8 @@ class SnowflakeViewParts(BaseModel):
617
682
  column_name_expression=column_name_expression,
618
683
  plugin_app_database=plugin_app_database
619
684
  ))
685
+ if len(main_stream_view_part.joins) == 0:
686
+ logger.debug(f"No joins found for stream: {stream_name}")
620
687
  # For each column, the plugin can advise which fields (of the same stream or joined) are required for the join, which comes through as referenced_columns
621
688
  # on the SnowflakeViewColumn object.
622
689
  # Until this generate function is called with the raw stream names, we don't know which streams the user has actually selected, nor which
@@ -635,7 +702,8 @@ class SnowflakeViewParts(BaseModel):
635
702
 
636
703
  # Process all joins to build the mappings
637
704
  for part in [main_stream_view_part] + joined_parts:
638
- logger.debug(f"Processing joins for stream: {part.stream_name}")
705
+ joined_parts_names = [j.join_stream_name for j in part.joins]
706
+ logger.debug(f"Processing joins for stream: {part.stream_name} (joined streams: {joined_parts_names})")
639
707
  # Make sure the part's stream name is in the mappings
640
708
  if part.stream_name not in stream_to_aliases:
641
709
  stream_to_aliases[part.stream_name] = [part.stream_name]
@@ -745,105 +813,209 @@ class SnowflakeViewParts(BaseModel):
745
813
  # If we get here, no circular references were found
746
814
  logger.debug("No circular references found")
747
815
 
748
- # Now proceed with the actual pruning process
749
- # First, removing unavailable columns from other streams
750
- # then, we can do a final pass and remove columns that reference fields that are not available in the current stream
751
-
752
- # Now proceed with the actual pruning process
753
- # First, removing unavailable columns from other streams
754
- # then, we can do a final pass and remove columns that reference fields that are not available in the current stream
755
-
756
- prune_count = 0
757
- while prune(main_stream_view_part, joined_parts):
758
- prune_count += 1
759
- if prune_count > 10:
760
- raise ValueError("Pruning of columns from the view has entered an infinite loop")
816
+ # Prune columns using graph-based dependency resolution (single pass)
817
+ prune(main_stream_view_part, joined_parts)
761
818
 
762
819
  return cls(main_part=main_stream_view_part, joined_parts=joined_parts)
763
820
 
821
+
822
+ # Helper function to find a view part by stream name
823
+ def find_part(view_part: SnowflakeViewPart, joined_parts: List[SnowflakeViewPart], stream_name: str) -> Optional[SnowflakeViewPart]:
824
+ if stream_name == view_part.stream_name:
825
+ return view_part
826
+ for part in joined_parts:
827
+ if part.stream_name == stream_name:
828
+ return part
829
+ for join in view_part.joins:
830
+ if join.join_stream_alias == stream_name:
831
+ # this is the join, we need to find the actual stream
832
+ for part in joined_parts:
833
+ if part.stream_name == join.join_stream_name:
834
+ return part
835
+ logger.warning(
836
+ f"Join alias {stream_name} maps to stream {join.join_stream_name}, but that stream is not in the joined parts"
837
+ )
838
+ return None
839
+
764
840
  def prune(view_part: SnowflakeViewPart, joined_parts: List[SnowflakeViewPart]) -> bool:
765
841
  """
766
- Prunes columns from view parts that reference fields that don't exist in the referenced streams.
842
+ Prunes columns from view parts using graph-based dependency resolution.
767
843
 
768
- This function handles:
769
- 1. Direct dependencies - removing columns that directly reference non-existent columns
770
- 2. Transitive dependencies - removing columns that depend on columns that were removed
844
+ Uses TopologicalSorter to:
845
+ 1. Build a complete dependency graph of all columns across all parts
846
+ 2. Identify "root" columns that must be kept (in main part or used in joins)
847
+ 3. Traverse dependencies to find all transitively required columns
848
+ 4. Remove columns that aren't needed
771
849
 
772
850
  Returns True if any columns were removed, False otherwise.
773
- Raises ValueError if a cyclic dependency is detected.
774
851
  """
775
- columns_removed = False
776
852
 
777
- # Helper function to find a view part by stream name
778
- def find_part(stream_name: str) -> Optional[SnowflakeViewPart]:
779
- if stream_name == view_part.stream_name:
780
- return view_part
781
- return next((p for p in joined_parts if p.stream_name == stream_name), None)
782
-
783
- # Helper function to check if a column should be kept or removed
784
- def should_keep_column(column: SnowflakeViewColumn, part: SnowflakeViewPart) -> bool:
785
- """
786
- Checks if a column should be kept based on its dependencies.
787
- Returns True if the column should be kept, False if it should be removed.
788
- """
789
- # If no references, keep the column
790
- if not column.referenced_columns:
791
- return True
853
+ all_parts = [view_part] + joined_parts
854
+
855
+ # Build column registry: (stream_name, column_name) -> column object
856
+ all_columns: Dict[Tuple[str, str], SnowflakeViewColumn] = {}
857
+ for part in all_parts:
858
+ for column in part.columns:
859
+ all_columns[(part.stream_name, column.original_name)] = column
860
+
861
+ # Build dependency graph for topological analysis
862
+ # Key: (stream, column), Value: list of (stream, column) dependencies
863
+ # Also track columns with invalid dependencies (reference non-existent columns)
864
+ dependency_graph: Dict[Tuple[str, str], List[Tuple[str, str]]] = {}
865
+ columns_with_invalid_deps: set[Tuple[str, str]] = set()
866
+
867
+ # First pass: build dependency graph and detect direct invalid references
868
+ for part in all_parts:
869
+ for column in part.columns:
870
+ key = (part.stream_name, column.original_name)
871
+ deps = []
872
+ has_invalid_dep = False
792
873
 
793
- # Check each referenced stream and its fields
794
- for ref_stream_name, ref_fields in column.referenced_columns.items():
795
- # Find the referenced part
796
- ref_part = find_part(ref_stream_name)
874
+ if column.referenced_columns:
875
+ for ref_stream_name, ref_fields in column.referenced_columns.items():
876
+ # Resolve stream alias to actual stream name
877
+ resolved_stream = ref_stream_name
878
+ for join in view_part.joins:
879
+ if join.join_stream_alias == ref_stream_name:
880
+ resolved_stream = join.join_stream_name
881
+ break
882
+
883
+ for ref_field in ref_fields:
884
+ dep_key = (resolved_stream, ref_field)
885
+ if dep_key in all_columns:
886
+ deps.append(dep_key)
887
+ else:
888
+ logger.warning(
889
+ f"Column {column.original_name} in {part.stream_name} references "
890
+ f"{ref_field} in {resolved_stream}, which doesn't exist"
891
+ )
892
+ has_invalid_dep = True
797
893
 
798
- # If referenced stream doesn't exist, remove the column
799
- if ref_part is None:
800
- logger.warning(
801
- f"Column {column.name} in stream {part.stream_name} references stream "
802
- f"{ref_stream_name}, but it was not provided"
803
- )
804
- return False
805
-
806
- # Check each referenced field
807
- for ref_field in ref_fields:
808
- # Find the referenced column
809
- ref_column = next((c for c in ref_part.columns if c.original_name == ref_field), None)
894
+ dependency_graph[key] = deps
895
+ if has_invalid_dep:
896
+ columns_with_invalid_deps.add(key)
897
+
898
+ # Second pass: propagate invalidity to columns that depend on invalid columns
899
+ # Keep iterating until no new invalid columns are found
900
+ changed = True
901
+ while changed:
902
+ changed = False
903
+ for col_key, deps in dependency_graph.items():
904
+ if col_key not in columns_with_invalid_deps:
905
+ # Check if any dependency is invalid
906
+ for dep_key in deps:
907
+ if dep_key in columns_with_invalid_deps:
908
+ logger.warning(
909
+ f"Column {col_key[1]} in {col_key[0]} depends on "
910
+ f"{dep_key[1]} in {dep_key[0]}, which has invalid dependencies"
911
+ )
912
+ columns_with_invalid_deps.add(col_key)
913
+ changed = True
914
+ break
915
+
916
+ # Build alias to stream mapping
917
+ alias_to_stream: Dict[str, str] = {}
918
+ for part in all_parts:
919
+ alias_to_stream[part.stream_name] = part.stream_name
920
+ for join in part.joins:
921
+ alias_to_stream[join.join_stream_alias] = join.join_stream_name
922
+ # left_alias might be an alias for a joined stream, resolve it
923
+ if join.left_alias not in alias_to_stream:
924
+ # Try to find the stream for this alias
925
+ for other_part in all_parts:
926
+ if other_part.stream_name == join.left_alias:
927
+ alias_to_stream[join.left_alias] = other_part.stream_name
928
+ break
929
+
930
+ # Identify root columns that must be kept
931
+ needed_columns: set[Tuple[str, str]] = set()
932
+
933
+ # 1. All columns in the main part are needed (except those with invalid dependencies)
934
+ for column in view_part.columns:
935
+ col_key = (view_part.stream_name, column.original_name)
936
+ if col_key not in columns_with_invalid_deps:
937
+ needed_columns.add(col_key)
938
+
939
+ # 2. All columns used in join conditions are needed (except those with invalid dependencies)
940
+ for part in all_parts:
941
+ for join in part.joins:
942
+ # Resolve left_alias to actual stream name
943
+ left_stream = alias_to_stream.get(join.left_alias, join.left_alias)
944
+ left_key = (left_stream, join.left_column)
945
+ right_key = (join.join_stream_name, join.join_stream_column)
946
+ if left_key not in columns_with_invalid_deps:
947
+ needed_columns.add(left_key)
948
+ if right_key not in columns_with_invalid_deps:
949
+ needed_columns.add(right_key)
950
+
951
+ logger.debug(f"Identified {len(needed_columns)} root columns to keep (excluding {len(columns_with_invalid_deps)} with invalid deps)")
952
+
953
+ # 3. Find all transitive dependencies using recursive traversal
954
+ # Skip columns with invalid dependencies and their dependents
955
+ def collect_dependencies(col_key: Tuple[str, str], visited: set[Tuple[str, str]]) -> None:
956
+ """Recursively collect all columns that col_key depends on"""
957
+ if col_key in visited or col_key not in dependency_graph:
958
+ return
959
+ if col_key in columns_with_invalid_deps:
960
+ return # Don't traverse dependencies of invalid columns
961
+ visited.add(col_key)
962
+
963
+ for dep_key in dependency_graph[col_key]:
964
+ if dep_key in all_columns and dep_key not in columns_with_invalid_deps:
965
+ needed_columns.add(dep_key)
966
+ collect_dependencies(dep_key, visited)
967
+
968
+ visited_global: set[Tuple[str, str]] = set()
969
+ for root_col in list(needed_columns):
970
+ collect_dependencies(root_col, visited_global)
971
+
972
+ # Remove columns that are not needed
973
+ columns_removed = False
974
+ for part in all_parts:
975
+ original_count = len(part.columns)
976
+ removed_cols = [col for col in part.columns
977
+ if (part.stream_name, col.original_name) not in needed_columns]
978
+
979
+ # Log warnings for each removed column with the reason
980
+ for col in removed_cols:
981
+ # Determine why the column is being removed
982
+ col_key = (part.stream_name, col.original_name)
983
+ if col.referenced_columns:
984
+ # Check if any referenced columns don't exist
985
+ missing_refs = []
986
+ for ref_stream_name, ref_fields in col.referenced_columns.items():
987
+ resolved_stream = ref_stream_name
988
+ for join in view_part.joins:
989
+ if join.join_stream_alias == ref_stream_name:
990
+ resolved_stream = join.join_stream_name
991
+ break
992
+ for ref_field in ref_fields:
993
+ if (resolved_stream, ref_field) not in all_columns:
994
+ missing_refs.append(f"{ref_field} in {resolved_stream}")
810
995
 
811
- # If referenced column doesn't exist, remove the column
812
- if ref_column is None:
996
+ if missing_refs:
813
997
  logger.warning(
814
- f"Column {column.name} in stream {part.stream_name} references field "
815
- f"{ref_field} in stream {ref_stream_name}, but it was not provided"
998
+ f"Removing column {col.original_name} from {part.stream_name} because it references "
999
+ f"non-existent column(s): {', '.join(missing_refs)}"
816
1000
  )
817
- return False
818
-
819
- # All dependencies are satisfied
820
- return True
821
-
822
- # Process columns for removal
823
- for column in view_part.columns[:]: # Use a copy to allow safe removal
824
- if not should_keep_column(column, view_part):
825
- view_part.columns.remove(column)
1001
+ else:
1002
+ # Column is not needed (not referenced by main part)
1003
+ logger.debug(
1004
+ f"Removing column {col.original_name} from {part.stream_name} because it is not "
1005
+ f"referenced by the main part or any join conditions"
1006
+ )
1007
+ else:
1008
+ logger.debug(
1009
+ f"Removing column {col.original_name} from {part.stream_name} because it is not "
1010
+ f"referenced by the main part or any join conditions"
1011
+ )
1012
+
1013
+ part.columns = [col for col in part.columns
1014
+ if (part.stream_name, col.original_name) in needed_columns]
1015
+
1016
+ if removed_cols:
826
1017
  columns_removed = True
827
1018
 
828
- # Process joined parts
829
- for joined_part in joined_parts:
830
- # We have to avoid pruning columns that are referenced by joins to this stream.
831
- # first, we determine all aliases for this stream (multiple join paths back to the same stream are allowed)
832
- aliases_for_stream = [j.join_stream_alias for j in view_part.joins if j.join_stream_name == joined_part.stream_name]
833
- # now find all joins using this stream as the join stream
834
- columns_used_in_joins = [
835
- j.left_column for j in view_part.joins if j.left_alias in aliases_for_stream
836
- ]
837
- for column in joined_part.columns[:]: # Use a copy to allow safe removal
838
- # First check if the column is a join column
839
- if column.original_name in columns_used_in_joins:
840
- # If it's a join column, we need to keep it
841
- continue
842
-
843
- if not should_keep_column(column, joined_part):
844
- joined_part.columns.remove(column)
845
- columns_removed = True
846
-
847
1019
  return columns_removed
848
1020
 
849
1021
  class JsonSchemaTopLevel(BaseModel):
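Stripped of the Omnata-specific types, the new pruning approach is: index every column by (stream, column), record its dependencies, seed a "needed" set with the main-part columns and the columns used in join conditions, walk dependencies transitively, then drop everything unreachable. A minimal sketch with invented stream and column names (the real implementation above additionally tracks columns with invalid references and resolves stream aliases):

    # Minimal sketch of keep-set pruning over (stream, column) keys; all names are illustrative.
    from typing import Dict, List, Set, Tuple

    Key = Tuple[str, str]  # (stream_name, column_name)

    # column -> columns it depends on
    dependency_graph: Dict[Key, List[Key]] = {
        ("main", "TOTAL"): [("orders", "AMOUNT")],
        ("orders", "AMOUNT"): [],
        ("orders", "UNUSED"): [],
    }

    # roots: everything exposed by the main part plus columns used in join conditions
    needed: Set[Key] = {("main", "TOTAL"), ("orders", "AMOUNT")}

    def collect(key: Key, visited: Set[Key]) -> None:
        """Recursively pull in everything a needed column depends on."""
        if key in visited or key not in dependency_graph:
            return
        visited.add(key)
        for dep in dependency_graph[key]:
            needed.add(dep)
            collect(dep, visited)

    visited: Set[Key] = set()
    for root in list(needed):
        collect(root, visited)

    # anything not reachable from a root is pruned
    pruned = [k for k in dependency_graph if k not in needed]
    print(pruned)  # [('orders', 'UNUSED')]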
@@ -9,9 +9,10 @@ from typing import Dict, List, Optional
9
9
  from snowflake.snowpark import Session
10
10
  from pydantic import ValidationError
11
11
  from snowflake import telemetry
12
- from opentelemetry import trace
12
+ from opentelemetry import trace, metrics
13
13
 
14
14
  tracer = trace.get_tracer('omnata_plugin_runtime')
15
+ meter = metrics.get_meter('omnata_plugin_runtime')
15
16
 
16
17
  class CustomLoggerAdapter(logging.LoggerAdapter):
17
18
  """
@@ -48,7 +48,7 @@ from snowflake.snowpark import Session
48
48
  from snowflake.snowpark.functions import col
49
49
  from tenacity import Retrying, stop_after_attempt, wait_fixed, retry_if_exception_message
50
50
 
51
- from .logging import OmnataPluginLogHandler, logger, tracer
51
+ from .logging import OmnataPluginLogHandler, logger, tracer, meter
52
52
  from opentelemetry import context
53
53
  import math
54
54
  import numpy as np
@@ -1185,41 +1185,52 @@ class InboundSyncRequest(SyncRequest):
1185
1185
  query_id: Optional[str] = None
1186
1186
  ) -> str:
1187
1187
  binding_values = []
1188
- values_clauses = []
1188
+ select_clauses = []
1189
1189
 
1190
1190
  with self._snowflake_query_lock:
1191
1191
  if query_id is None:
1192
1192
  query_id = self._get_query_id_for_now()
1193
1193
  for stream_name, latest_state in stream_states_for_upload.items():
1194
1194
  binding_values.extend([stream_name, query_id, json.dumps(latest_state)])
1195
- values_clauses.append(
1196
- f"(?, ?, PARSE_JSON(?))"
1195
+ select_clauses.append(
1196
+ f"select ?, ?, PARSE_JSON(?)"
1197
1197
  )
1198
1198
  final_query = f"""INSERT INTO {self.state_register_table_name} (STREAM_NAME, QUERY_ID, STATE_VALUE)
1199
- VALUES {','.join(values_clauses)}"""
1199
+ {' union all '.join(select_clauses)}"""
1200
1200
  self._session.sql(final_query, binding_values).collect()
1201
+ streams_included = list(stream_states_for_upload.keys())
1202
+ logger.debug(f"Inserted state for streams: {streams_included} with query ID {query_id}")
1201
1203
 
1202
1204
  def apply_progress_updates(self, ignore_errors:bool = True):
1203
1205
  """
1204
1206
  Sends a message to the plugin with the current progress of the sync run, if it has changed since last time.
1205
1207
  """
1206
- if self._apply_results is not None:
1207
- with self._apply_results_lock:
1208
- new_progress_update = PluginMessageStreamProgressUpdate(
1209
- stream_total_counts=self._stream_record_counts,
1210
- # records could have been marked as completed, but still have results to apply
1211
- completed_streams=[s for s in self._completed_streams if s not in self._apply_results or self._apply_results[s] is None],
1212
- stream_errors=self._omnata_log_handler.stream_global_errors,
1213
- total_records_estimate=self._total_records_estimate
1214
- )
1215
- if self._last_stream_progress_update is None or new_progress_update != self._last_stream_progress_update:
1216
- result = self._plugin_message(
1217
- message=new_progress_update,
1218
- ignore_errors=ignore_errors
1208
+ with self._apply_results_lock:
1209
+ new_progress_update = PluginMessageStreamProgressUpdate(
1210
+ stream_total_counts=self._stream_record_counts,
1211
+ # records could have been marked as completed, but still have results to apply
1212
+ completed_streams=[s for s in self._completed_streams
1213
+ if s not in self._apply_results
1214
+ or self._apply_results[s] is None
1215
+ or len(self._apply_results[s]) == 0],
1216
+ stream_errors=self._omnata_log_handler.stream_global_errors,
1217
+ total_records_estimate=self._total_records_estimate
1219
1218
  )
1220
- if result is None:
1221
- return False
1222
- self._last_stream_progress_update = new_progress_update
1219
+ if self._last_stream_progress_update is None or new_progress_update != self._last_stream_progress_update:
1220
+ result = self._plugin_message(
1221
+ message=new_progress_update,
1222
+ ignore_errors=ignore_errors
1223
+ )
1224
+ if result is None:
1225
+ return False
1226
+ self._last_stream_progress_update = new_progress_update
1227
+ completed_streams_awaiting_results_upload = [
1228
+ s for s in self._completed_streams if s in self._apply_results and self._apply_results[s] is not None
1229
+ ]
1230
+ if len(completed_streams_awaiting_results_upload) > 0:
1231
+ logger.debug(
1232
+ f"Streams marked as completed but awaiting upload: {', '.join(completed_streams_awaiting_results_upload)}"
1233
+ )
1223
1234
  return True
1224
1235
 
1225
1236
  def apply_cancellation(self):
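The state upload switches from a multi-row VALUES clause to a chain of SELECTs joined with UNION ALL, presumably because expressions such as PARSE_JSON are not accepted inside a multi-row VALUES list. A sketch of the query text each variant produces for two streams, with an illustrative table name and stream names:

    # Illustrative reconstruction of the generated SQL for two streams; names are made up.
    streams = ["accounts", "contacts"]
    table = "PLUGIN_DB.STATE_REGISTER"

    # old form: one placeholder triple per stream inside a single VALUES clause
    values_clauses = ["(?, ?, PARSE_JSON(?))" for _ in streams]
    old_query = (
        f"INSERT INTO {table} (STREAM_NAME, QUERY_ID, STATE_VALUE) "
        f"VALUES {','.join(values_clauses)}"
    )

    # new form: one SELECT per stream, combined with UNION ALL
    select_clauses = ["select ?, ?, PARSE_JSON(?)" for _ in streams]
    new_query = (
        f"INSERT INTO {table} (STREAM_NAME, QUERY_ID, STATE_VALUE) "
        f"{' union all '.join(select_clauses)}"
    )

    print(old_query)
    print(new_query)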
@@ -1286,7 +1297,7 @@ class InboundSyncRequest(SyncRequest):
1286
1297
  # if the total exceeds 200MB, we apply the results immediately
1287
1298
  all_df_lists:List[List[pandas.DataFrame]] = list(self._apply_results.values())
1288
1299
  # flatten
1289
- all_dfs:List[pandas.DataFrame] = [x for sublist in all_df_lists for x in sublist]
1300
+ all_dfs:List[pandas.DataFrame] = [x for sublist in all_df_lists for x in sublist if isinstance(x, pandas.DataFrame)]
1290
1301
  combined_length = sum([len(x) for x in all_dfs])
1291
1302
  # first, don't bother if the count is less than 10000, since it's unlikely to be even close
1292
1303
  if combined_length > 10000:
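The new isinstance guard (applied again to the memory-size check in the next hunk) appears to defend against non-DataFrame placeholders such as None sitting in the results queue, which would otherwise raise when len() or memory_usage() is called. A tiny sketch of the filtering, with made-up queue contents:

    # Sketch of defensively flattening queued results; only DataFrame entries are counted.
    import pandas

    all_df_lists = [[pandas.DataFrame({"a": range(3)})], [None]]  # entries may hold None
    all_dfs = [x for sublist in all_df_lists for x in sublist if isinstance(x, pandas.DataFrame)]
    combined_length = sum(len(x) for x in all_dfs)
    print(combined_length)  # 3 -- the None entry is ignored rather than raising a TypeError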
@@ -1336,7 +1347,7 @@ class InboundSyncRequest(SyncRequest):
1336
1347
  combined_length = sum([len(x) for x in all_dfs])
1337
1348
  # first, don't both if the count is less than 10000, since it's unlikely to be even close
1338
1349
  if combined_length > 10000:
1339
- if sum([x.memory_usage(index=True).sum() for x in all_dfs]) > 200000000:
1350
+ if sum([x.memory_usage(index=True).sum() for x in all_dfs if isinstance(x, pandas.DataFrame)]) > 200000000:
1340
1351
  logger.debug(f"Applying criteria deletes queue immediately due to combined dataframe size")
1341
1352
  self.apply_results_queue()
1342
1353
 
@@ -1345,9 +1356,11 @@ class InboundSyncRequest(SyncRequest):
1345
1356
  Marks a stream as completed, this is called automatically per stream when using @managed_inbound_processing.
1346
1357
  If @managed_inbound_processing is not used, call this whenever a stream has finished recieving records.
1347
1358
  """
1348
- self._completed_streams.append(stream_name)
1349
- # dedup just in case it's called twice
1350
- self._completed_streams = list(set(self._completed_streams))
1359
+ logger.debug(f"Marking stream {stream_name} as completed locally")
1360
+ with self._apply_results_lock:
1361
+ self._completed_streams.append(stream_name)
1362
+ # dedup just in case it's called twice
1363
+ self._completed_streams = list(set(self._completed_streams))
1351
1364
 
1352
1365
  def set_stream_record_count(self, stream_name: str, count: int):
1353
1366
  """
@@ -1845,6 +1858,40 @@ class OmnataPlugin(ABC):
1845
1858
  raise NotImplementedError(
1846
1859
  "Your plugin class must implement the inbound_configuration_form method"
1847
1860
  )
1861
+
1862
+ def outbound_tuning_parameters(
1863
+ self, parameters: OutboundSyncConfigurationParameters
1864
+ ) -> OutboundSyncConfigurationForm:
1865
+ """
1866
+ Returns the form definition for declaring outbound tuning parameters.
1867
+
1868
+ The returned form should consist of static fields with default values that represent the
1869
+ plugin's recommended runtime behaviour. This form is optional and is only rendered when a
1870
+ user opts to override those defaults at sync runtime, so it must be safe to fall back to the
1871
+ provided defaults when no tuning parameters are configured.
1872
+
1873
+ :param OutboundSyncConfigurationParameters parameters the current outbound configuration
1874
+ :return: An OutboundSyncConfigurationForm describing the available tuning parameters
1875
+ :rtype: OutboundSyncConfigurationForm
1876
+ """
1877
+ return OutboundSyncConfigurationForm(fields=[])
1878
+
1879
+ def inbound_tuning_parameters(
1880
+ self, parameters: InboundSyncConfigurationParameters
1881
+ ) -> InboundSyncConfigurationForm:
1882
+ """
1883
+ Returns the form definition for declaring inbound tuning parameters.
1884
+
1885
+ The returned form should consist of static fields with default values that represent the
1886
+ plugin's recommended runtime behaviour. This form is optional and is only rendered when a
1887
+ user opts to override those defaults at sync runtime, so it must be safe to fall back to the
1888
+ provided defaults when no tuning parameters are configured.
1889
+
1890
+ :param InboundSyncConfigurationParameters parameters the current inbound configuration
1891
+ :return: An InboundSyncConfigurationForm describing the available tuning parameters
1892
+ :rtype: InboundSyncConfigurationForm
1893
+ """
1894
+ return InboundSyncConfigurationForm(fields=[])
1848
1895
 
1849
1896
  def inbound_stream_list(
1850
1897
  self, parameters: InboundSyncConfigurationParameters
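The new tuning-parameter hooks default to returning an empty form, so existing plugins keep working unchanged. A hypothetical override is sketched below; the module paths and the empty fields list are assumptions, and only the method name, parameter type and return type come from the diff above:

    # Hypothetical plugin override; import paths and field contents are assumptions.
    from omnata_plugin_runtime.omnata_plugin import OmnataPlugin
    from omnata_plugin_runtime.forms import InboundSyncConfigurationForm
    from omnata_plugin_runtime.configuration import InboundSyncConfigurationParameters

    class MyPlugin(OmnataPlugin):
        def inbound_tuning_parameters(
            self, parameters: InboundSyncConfigurationParameters
        ) -> InboundSyncConfigurationForm:
            # Static fields with recommended default values (e.g. page size, concurrency)
            # would be declared here; an empty form means "no overridable tuning".
            return InboundSyncConfigurationForm(fields=[])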
@@ -2283,6 +2330,15 @@ def __managed_inbound_processing_worker(
2283
2330
  try:
2284
2331
  stream: StoredStreamConfiguration = streams_queue.get_nowait()
2285
2332
  logger.debug(f"stream returned from queue: {stream}")
2333
+ sync_request: InboundSyncRequest = cast(
2334
+ InboundSyncRequest, plugin_class_obj._sync_request
2335
+ ) # pylint: disable=protected-access
2336
+ stream_duration_gauge = meter.create_gauge(
2337
+ name="omnata.sync_run.stream_duration",
2338
+ description="The duration of stream processing",
2339
+ unit="s",
2340
+ )
2341
+ start_time = time.time()
2286
2342
  # restore the first argument, was originally the dataframe/generator but now it's the appropriately sized dataframe
2287
2343
  try:
2288
2344
  with tracer.start_as_current_span("managed_inbound_processing") as managed_inbound_processing_span:
@@ -2294,7 +2350,7 @@ def __managed_inbound_processing_worker(
2294
2350
  logger.info(f"worker {worker_index} requested that {stream.stream_name} be not marked as complete")
2295
2351
  else:
2296
2352
  logger.info(f"worker {worker_index} marking stream {stream.stream_name} as complete")
2297
- plugin_class_obj._sync_request.mark_stream_complete(stream.stream_name)
2353
+ sync_request.mark_stream_complete(stream.stream_name)
2298
2354
  except InterruptedWhileWaitingException:
2299
2355
  # If an inbound run is cancelled while waiting for rate limiting, this should mean that
2300
2356
  # the cancellation is handled elsewhere, so we don't need to do anything special here other than stop waiting
@@ -2314,6 +2370,19 @@ def __managed_inbound_processing_worker(
2314
2370
  omnata_plugin_logger.error(f"{type(e).__name__} syncing stream {stream.stream_name}",
2315
2371
  exc_info=True,
2316
2372
  extra={'stream_name':stream.stream_name})
2373
+ finally:
2374
+ duration = time.time() - start_time
2375
+ stream_duration_gauge.set(
2376
+ amount=duration,
2377
+ attributes={
2378
+ "stream_name": stream.stream_name,
2379
+ "sync_run_id": str(sync_request._run_id),
2380
+ "sync_id": str(sync_request._sync_id),
2381
+ "branch_name": str(sync_request._branch_name) if sync_request._branch_name is not None else 'main',
2382
+ "sync_direction": "inbound",
2383
+ "plugin_id": plugin_class_obj.get_manifest().plugin_id,
2384
+ },
2385
+ )
2317
2386
  except queue.Empty:
2318
2387
  logger.debug("streams queue is empty")
2319
2388
  return
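Taken together, the three worker hunks wrap stream processing with a wall-clock timer and record the duration on the gauge in a finally block, so the measurement is emitted even when the stream fails. The essential pattern, with an invented processing function and illustrative attribute values:

    # Minimal sketch of the record-duration-in-finally pattern; attribute values are illustrative.
    import time
    from opentelemetry import metrics

    meter = metrics.get_meter("omnata_plugin_runtime")
    stream_duration_gauge = meter.create_gauge(
        name="omnata.sync_run.stream_duration",
        description="The duration of stream processing",
        unit="s",
    )

    def process_stream(stream_name: str) -> None:
        start_time = time.time()
        try:
            ...  # fetch and apply records for the stream
        finally:
            stream_duration_gauge.set(
                amount=time.time() - start_time,
                attributes={"stream_name": stream_name, "sync_direction": "inbound"},
            )

    process_stream("accounts")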
@@ -282,6 +282,11 @@ class PluginEntrypoint:
282
282
  # token is set. We throw it here as an error since that's currently how it flows back to the engine with a DELAYED state
283
283
  raise DeadlineReachedException()
284
284
  finally:
285
+ # try to upload any remaining results
286
+ try:
287
+ inbound_sync_request.apply_results_queue()
288
+ except Exception as e:
289
+ logger.warning(f"Error uploading remaining results: {str(e)}", exc_info=True)
285
290
  # cancel the thread so we don't leave anything hanging around and cop a nasty error
286
291
  try:
287
292
  inbound_sync_request._thread_cancellation_token.set() # pylint: disable=protected-access
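The new finally clause attempts one last flush of any queued results before the cancellation token is set, and deliberately downgrades failures to a warning so the remaining cleanup still runs. The shape of that best-effort pattern in isolation; the surrounding function is invented for illustration, only the apply_results_queue call and the warning come from the diff:

    # Sketch of a best-effort flush inside a finally block.
    import logging

    logger = logging.getLogger("omnata_plugin_runtime")

    def finish_run(inbound_sync_request) -> None:
        try:
            ...  # main sync work happens here
        finally:
            # try to upload any remaining results, but never let this mask the real outcome
            try:
                inbound_sync_request.apply_results_queue()
            except Exception as e:
                logger.warning(f"Error uploading remaining results: {str(e)}", exc_info=True)
            # further cleanup (e.g. setting the cancellation token) continues after this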