snowpark-connect 0.25.0__py3-none-any.whl → 0.26.0__py3-none-any.whl

This diff shows the contents of publicly available package versions as released to a supported registry. It is provided for informational purposes only and reflects the changes between the two versions as they appear in their public registry.

Potentially problematic release.



Files changed (32)
  1. snowflake/snowpark_connect/config.py +10 -0
  2. snowflake/snowpark_connect/dataframe_container.py +16 -0
  3. snowflake/snowpark_connect/expression/map_udf.py +68 -27
  4. snowflake/snowpark_connect/expression/map_unresolved_function.py +22 -21
  5. snowflake/snowpark_connect/includes/jars/json4s-native_2.12-3.7.0-M11.jar +0 -0
  6. snowflake/snowpark_connect/includes/jars/paranamer-2.8.3.jar +0 -0
  7. snowflake/snowpark_connect/includes/jars/sas-scala-udf_2.12-0.1.0.jar +0 -0
  8. snowflake/snowpark_connect/relation/map_map_partitions.py +9 -4
  9. snowflake/snowpark_connect/relation/map_relation.py +12 -1
  10. snowflake/snowpark_connect/relation/map_row_ops.py +8 -1
  11. snowflake/snowpark_connect/relation/map_udtf.py +96 -44
  12. snowflake/snowpark_connect/relation/utils.py +44 -0
  13. snowflake/snowpark_connect/relation/write/map_write.py +113 -22
  14. snowflake/snowpark_connect/resources_initializer.py +18 -5
  15. snowflake/snowpark_connect/server.py +8 -1
  16. snowflake/snowpark_connect/utils/concurrent.py +4 -0
  17. snowflake/snowpark_connect/utils/external_udxf_cache.py +36 -0
  18. snowflake/snowpark_connect/utils/scala_udf_utils.py +250 -242
  19. snowflake/snowpark_connect/utils/session.py +4 -0
  20. snowflake/snowpark_connect/utils/udf_utils.py +7 -17
  21. snowflake/snowpark_connect/utils/udtf_utils.py +3 -16
  22. snowflake/snowpark_connect/version.py +1 -1
  23. {snowpark_connect-0.25.0.dist-info → snowpark_connect-0.26.0.dist-info}/METADATA +1 -1
  24. {snowpark_connect-0.25.0.dist-info → snowpark_connect-0.26.0.dist-info}/RECORD +32 -28
  25. {snowpark_connect-0.25.0.data → snowpark_connect-0.26.0.data}/scripts/snowpark-connect +0 -0
  26. {snowpark_connect-0.25.0.data → snowpark_connect-0.26.0.data}/scripts/snowpark-session +0 -0
  27. {snowpark_connect-0.25.0.data → snowpark_connect-0.26.0.data}/scripts/snowpark-submit +0 -0
  28. {snowpark_connect-0.25.0.dist-info → snowpark_connect-0.26.0.dist-info}/WHEEL +0 -0
  29. {snowpark_connect-0.25.0.dist-info → snowpark_connect-0.26.0.dist-info}/licenses/LICENSE-binary +0 -0
  30. {snowpark_connect-0.25.0.dist-info → snowpark_connect-0.26.0.dist-info}/licenses/LICENSE.txt +0 -0
  31. {snowpark_connect-0.25.0.dist-info → snowpark_connect-0.26.0.dist-info}/licenses/NOTICE-binary +0 -0
  32. {snowpark_connect-0.25.0.dist-info → snowpark_connect-0.26.0.dist-info}/top_level.txt +0 -0
snowflake/snowpark_connect/config.py

@@ -22,6 +22,9 @@ from snowflake.snowpark.exceptions import SnowparkSQLException
 from snowflake.snowpark.types import TimestampTimeZone, TimestampType
 from snowflake.snowpark_connect.utils.concurrent import SynchronizedDict
 from snowflake.snowpark_connect.utils.context import get_session_id
+from snowflake.snowpark_connect.utils.external_udxf_cache import (
+    clear_external_udxf_cache,
+)
 from snowflake.snowpark_connect.utils.session import get_or_create_snowpark_session
 from snowflake.snowpark_connect.utils.snowpark_connect_logging import logger
 from snowflake.snowpark_connect.utils.telemetry import (
@@ -136,6 +139,9 @@ class GlobalConfig:
         "spark.sql.parser.quotedRegexColumnNames": "false",
         # custom configs
         "snowpark.connect.version": ".".join(map(str, sas_version)),
+        # Control whether repartition(n) on a DataFrame forces splitting into n files during writes
+        # This matches spark behavior more closely, but introduces overhead.
+        "snowflake.repartition.for.writes": "false",
     }
 
     boolean_config_list = [
@@ -148,6 +154,7 @@ class GlobalConfig:
         "spark.sql.legacy.allowHashOnMapType",
         "spark.Catalog.databaseFilterInformationSchema",
         "spark.sql.parser.quotedRegexColumnNames",
+        "snowflake.repartition.for.writes",
     ]
 
     int_config_list = [
@@ -592,6 +599,9 @@ def parse_imports(session: snowpark.Session, imports: str | None) -> None:
     if not imports:
         return
 
+    # UDF needs to be recreated to include new imports
+    clear_external_udxf_cache(session)
+
    for udf_import in imports.strip("[] ").split(","):
        session.add_import(udf_import)

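The new snowflake.repartition.for.writes key is registered as a boolean config and defaults to "false". A minimal client-side sketch of opting in (the Spark Connect endpoint and output path below are placeholders, not taken from this diff):

    from pyspark.sql import SparkSession

    # Hypothetical endpoint; use your own snowpark-connect server address.
    spark = SparkSession.builder.remote("sc://localhost:15002").getOrCreate()

    # Opt in to the 0.26.0 behavior: repartition(n) before a write is treated
    # as a hint to split the output into n files.
    spark.conf.set("snowflake.repartition.for.writes", "true")

    df = spark.range(100)
    df.repartition(4).write.mode("overwrite").parquet("/tmp/out")  # may emit up to 4 files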
snowflake/snowpark_connect/dataframe_container.py

@@ -29,6 +29,7 @@ class DataFrameContainer:
         table_name: str | None = None,
         alias: str | None = None,
         cached_schema_getter: Callable[[], StructType] | None = None,
+        partition_hint: int | None = None,
     ) -> None:
         """
         Initialize a new DataFrameContainer.
@@ -39,11 +40,13 @@
             table_name: Optional table name for the DataFrame
             alias: Optional alias for the DataFrame
             cached_schema_getter: Optional function to get cached schema
+            partition_hint: Optional partition count from repartition() operations
         """
         self._dataframe = dataframe
         self._column_map = self._create_default_column_map(column_map)
         self._table_name = table_name
         self._alias = alias
+        self._partition_hint = partition_hint
 
         if cached_schema_getter is not None:
             self._apply_cached_schema_getter(cached_schema_getter)
@@ -62,6 +65,7 @@
         table_name: str | None = None,
         alias: str | None = None,
         cached_schema_getter: Callable[[], StructType] | None = None,
+        partition_hint: int | None = None,
     ) -> DataFrameContainer:
         """
         Create a new container with complete column mapping configuration.
@@ -78,6 +82,7 @@
             table_name: Optional table name
             alias: Optional alias
             cached_schema_getter: Optional function to get cached schema
+            partition_hint: Optional partition count from repartition() operations
 
         Returns:
             A new DataFrameContainer instance
@@ -123,6 +128,7 @@
             table_name=table_name,
             alias=alias,
             cached_schema_getter=final_schema_getter,
+            partition_hint=partition_hint,
         )
 
     @property
@@ -163,6 +169,16 @@
         """Set the alias name."""
         self._alias = value
 
+    @property
+    def partition_hint(self) -> int | None:
+        """Get the partition hint count."""
+        return self._partition_hint
+
+    @partition_hint.setter
+    def partition_hint(self, value: int | None) -> None:
+        """Set the partition hint count."""
+        self._partition_hint = value
+
     def _create_default_column_map(
         self, column_map: ColumnNameMap | None
     ) -> ColumnNameMap:
snowflake/snowpark_connect/expression/map_udf.py

@@ -13,6 +13,10 @@ from snowflake.snowpark_connect.config import global_config
 from snowflake.snowpark_connect.expression.typer import ExpressionTyper
 from snowflake.snowpark_connect.type_mapping import proto_to_snowpark_type
 from snowflake.snowpark_connect.typed_column import TypedColumn
+from snowflake.snowpark_connect.utils.external_udxf_cache import (
+    cache_external_udf,
+    get_external_udf_from_cache,
+)
 from snowflake.snowpark_connect.utils.session import get_or_create_snowpark_session
 from snowflake.snowpark_connect.utils.udf_helper import (
     SnowparkUDF,
@@ -30,6 +34,39 @@ from snowflake.snowpark_connect.utils.udxf_import_utils import (
 )
 
 
+def cache_external_udf_wrapper(from_register_udf: bool):
+    def outer_wrapper(wrapper_func):
+        def wrapper(
+            udf_proto: expressions_proto.CommonInlineUserDefinedFunction,
+        ) -> SnowparkUDF | None:
+            udf_hash = hash(str(udf_proto))
+            cached_udf = get_external_udf_from_cache(udf_hash)
+
+            if cached_udf:
+                session = get_or_create_snowpark_session()
+                function_type = udf_proto.WhichOneof("function")
+                # TODO: Align this with SNOW-2316798 after merge
+                match function_type:
+                    case "scalar_scala_udf":
+                        session._udfs[cached_udf.name] = cached_udf
+                    case "python_udf" if from_register_udf:
+                        session._udfs[udf_proto.function_name.lower()] = cached_udf
+                    case "python_udf":
+                        pass
+                    case _:
+                        raise ValueError(f"Unsupported UDF type: {function_type}")
+
+                return cached_udf
+
+            snowpark_udf = wrapper_func(udf_proto)
+            cache_external_udf(udf_hash, snowpark_udf)
+            return snowpark_udf
+
+        return wrapper
+
+    return outer_wrapper
+
+
 def process_udf_return_type(
     return_type: types_proto.DataType,
 ) -> tuple[snowpark.types.DataType, snowpark.types.DataType]:
@@ -49,6 +86,7 @@ def process_udf_return_type(
     return original_snowpark_type, original_snowpark_type
 
 
+@cache_external_udf_wrapper(from_register_udf=True)
 def register_udf(
     udf_proto: expressions_proto.CommonInlineUserDefinedFunction,
 ) -> SnowparkUDF:
@@ -84,12 +122,10 @@ def register_udf(
         return_type=udf._return_type,
         original_return_type=original_return_type,
     )
-    # the create udf does register the udf but this seems to be for the client side check
-    # TODO: check if this is needed
+    session._udfs[udf_proto.function_name.lower()] = udf
+    # scala udfs can be also accessed using `udf.name`
     if udf_processor._function_type == "scalar_scala_udf":
         session._udfs[udf.name] = udf
-    else:
-        session._udfs[udf_proto.function_name.lower()] = udf
     return udf
 
 
@@ -114,29 +150,34 @@ def map_common_inline_user_defined_udf(
         udf_proto.scalar_scala_udf.outputType
     )
 
-    session = get_or_create_snowpark_session()
-    kwargs = {
-        "common_inline_user_defined_function": udf_proto,
-        "input_types": input_types,
-        "called_from": "map_common_inline_user_defined_udf",
-        "return_type": processed_return_type,
-        "udf_packages": global_config.get("snowpark.connect.udf.packages", ""),
-        "udf_imports": get_python_udxf_import_files(session),
-        "original_return_type": original_return_type,
-    }
-    if require_creating_udf_in_sproc(udf_proto):
-        snowpark_udf = process_udf_in_sproc(**kwargs)
-    else:
-        udf_processor = ProcessCommonInlineUserDefinedFunction(**kwargs)
-        udf = udf_processor.create_udf()
-        snowpark_udf = SnowparkUDF(
-            name=udf.name,
-            input_types=udf._input_types,
-            return_type=udf._return_type,
-            original_return_type=original_return_type,
-        )
-        if udf_processor._function_type == "scalar_scala_udf":
-            session._udfs[udf.name] = snowpark_udf
+    @cache_external_udf_wrapper(from_register_udf=False)
+    def get_snowpark_udf(
+        udf_proto: expressions_proto.CommonInlineUserDefinedFunction,
+    ) -> SnowparkUDF:
+        session = get_or_create_snowpark_session()
+        kwargs = {
+            "common_inline_user_defined_function": udf_proto,
+            "input_types": input_types,
+            "called_from": "map_common_inline_user_defined_udf",
+            "return_type": processed_return_type,
+            "udf_packages": global_config.get("snowpark.connect.udf.packages", ""),
+            "udf_imports": get_python_udxf_import_files(session),
+            "original_return_type": original_return_type,
+        }
+        if require_creating_udf_in_sproc(udf_proto):
+            snowpark_udf = process_udf_in_sproc(**kwargs)
+        else:
+            udf_processor = ProcessCommonInlineUserDefinedFunction(**kwargs)
+            udf = udf_processor.create_udf()
+            snowpark_udf = SnowparkUDF(
+                name=udf.name,
+                input_types=udf._input_types,
+                return_type=udf._return_type,
+                original_return_type=original_return_type,
+            )
+        return snowpark_udf
+
+    snowpark_udf = get_snowpark_udf(udf_proto)
 
     udf_call_expr = snowpark_fn.call_udf(snowpark_udf.name, *snowpark_udf_args)
 
     # If the original return type was MapType or StructType but we converted it to VariantType,
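Both register_udf and the nested get_snowpark_udf above are wrapped by cache_external_udf_wrapper, which keys already-created UDFs by a hash of the serialized proto so that repeated occurrences of the same inline UDF skip re-registration. A simplified, self-contained sketch of that memoization pattern (illustrative only; the real cache lives in snowflake/snowpark_connect/utils/external_udxf_cache.py and the names below are made up):

    # Illustrative stand-in for the proto-hash cache used above.
    _udxf_cache: dict[int, object] = {}

    def cached_by_proto(create_fn):
        def wrapper(proto):
            key = hash(str(proto))       # same idea as hash(str(udf_proto)) in the diff
            if key in _udxf_cache:
                return _udxf_cache[key]  # reuse the previously registered object
            obj = create_fn(proto)       # otherwise create/register it once
            _udxf_cache[key] = obj
            return obj
        return wrapper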
snowflake/snowpark_connect/expression/map_unresolved_function.py

@@ -476,11 +476,8 @@ def map_unresolved_function(
         return TypedColumn(result, lambda: expected_types)
 
     match function_name:
-        case func_name if (
-            get_is_evaluating_sql() and func_name.lower() in session._udfs
-        ):
-            # TODO: In Spark, UDFs can override built-in functions in SQL,
-            # but not in DataFrame ops.
+        case func_name if func_name.lower() in session._udfs:
+            # In Spark, UDFs can override built-in functions
             udf = session._udfs[func_name.lower()]
             result_exp = snowpark_fn.call_udf(
                 udf.name,
@@ -6479,6 +6476,18 @@ def map_unresolved_function(
             if pattern_value is None:
                 return snowpark_fn.lit(None)
 
+            # Optimization: treat escaped regex that resolves to a pure literal delimiter
+            # - Single char: "\\."
+            # - Multi char: e.g., "\\.505\\."
+            if re.fullmatch(r"(?:\\.)+", pattern_value):
+                literal_delim = re.sub(r"\\(.)", r"\1", pattern_value)
+                return snowpark_fn.when(
+                    limit <= 0,
+                    snowpark_fn.split(
+                        str_, snowpark_fn.lit(literal_delim)
+                    ).cast(result_type),
+                ).otherwise(native_split)
+
             is_regexp = re.match(
                 ".*[\\[\\.\\]\\*\\?\\+\\^\\$\\{\\}\\|\\(\\)\\\\].*",
                 pattern_value,
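The new branch only fires when the whole pattern is a run of backslash-escaped characters; the escapes are then stripped and the remainder is handed to split as a plain literal delimiter. A standalone illustration of that detection using the same two regexes (plain Python, not package code):

    import re

    def literal_delimiter(pattern_value):
        """Return the unescaped delimiter if the pattern is purely escaped characters, else None."""
        if re.fullmatch(r"(?:\\.)+", pattern_value):
            return re.sub(r"\\(.)", r"\1", pattern_value)
        return None

    print(literal_delimiter("\\."))     # "."  -> split on a literal dot
    print(literal_delimiter("\\.\\,"))  # ".," -> multi-character literal delimiter
    print(literal_delimiter("a.*b"))    # None -> falls through to the regex-based path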
@@ -8285,15 +8294,6 @@ def map_unresolved_function(
                 ),
             )
             result_type = BinaryType()
-        case udf_name if udf_name.lower() in session._udfs:
-            # TODO: In Spark, UDFs can override built-in functions in SQL,
-            # but not in DataFrame ops.
-            udf = session._udfs[udf_name.lower()]
-            result_exp = snowpark_fn.call_udf(
-                udf.name,
-                *(snowpark_fn.cast(arg, VariantType()) for arg in snowpark_args),
-            )
-            result_type = udf.return_type
         case udtf_name if udtf_name.lower() in session._udtfs:
             udtf, spark_col_names = session._udtfs[udtf_name.lower()]
             result_exp = snowpark_fn.call_table_function(
@@ -9623,13 +9623,14 @@ def _get_decimal_division_result_exp(
     snowpark_args: list[Column],
     spark_function_name: str,
 ) -> Column:
-    if isinstance(other_type, DecimalType) and overflow_detected:
-        if global_config.spark_sql_ansi_enabled:
-            raise ArithmeticException(
-                f'[NUMERIC_VALUE_OUT_OF_RANGE] {spark_function_name} cannot be represented as Decimal({result_type.precision}, {result_type.scale}). If necessary set "spark.sql.ansi.enabled" to "false" to bypass this error, and return NULL instead.'
-            )
-        else:
-            result_exp = snowpark_fn.lit(None).cast(result_type)
+    if (
+        isinstance(other_type, DecimalType)
+        and overflow_detected
+        and global_config.spark_sql_ansi_enabled
+    ):
+        raise ArithmeticException(
+            f'[NUMERIC_VALUE_OUT_OF_RANGE] {spark_function_name} cannot be represented as Decimal({result_type.precision}, {result_type.scale}). If necessary set "spark.sql.ansi.enabled" to "false" to bypass this error, and return NULL instead.'
+        )
     else:
         dividend = snowpark_args[0].cast(DoubleType())
         divisor = snowpark_args[1]
snowflake/snowpark_connect/relation/map_map_partitions.py

@@ -46,9 +46,10 @@ def map_map_partitions(
     udf_check(udf_proto)
 
     # Check if this is mapInArrow (eval_type == 207)
-    eval_type = udf_proto.python_udf.eval_type
-
-    if eval_type == MAP_IN_ARROW_EVAL_TYPE:
+    if (
+        udf_proto.WhichOneof("function") == "python_udf"
+        and udf_proto.python_udf.eval_type == MAP_IN_ARROW_EVAL_TYPE
+    ):
         return _map_in_arrow_with_pandas_udtf(input_container, udf_proto)
     else:
         return _map_partitions_with_udf(input_df, udf_proto)
@@ -126,7 +127,11 @@
         "udf_name": "spark_map_partitions_udf",
         "input_column_names": input_column_names,
         "replace": True,
-        "return_type": proto_to_snowpark_type(udf_proto.python_udf.output_type),
+        "return_type": proto_to_snowpark_type(
+            udf_proto.python_udf.output_type
+            if udf_proto.WhichOneof("function") == "python_udf"
+            else udf_proto.scalar_scala_udf.outputType
+        ),
         "udf_packages": global_config.get("snowpark.connect.udf.packages", ""),
         "udf_imports": get_python_udxf_import_files(input_df.session),
     }
snowflake/snowpark_connect/relation/map_relation.py

@@ -90,6 +90,7 @@ def map_relation(
             table_name=copy.deepcopy(cached_container.table_name),
             alias=cached_container.alias,
             cached_schema_getter=lambda: cached_df.schema,
+            partition_hint=cached_container.partition_hint,
         )
         # If we don't make a copy of the df._output, the expression IDs for attributes in Snowpark DataFrames will differ from those stored in the cache,
         # leading to errors during query execution.
@@ -189,13 +190,23 @@
         case "read":
             result = read.map_read(rel)
         case "repartition":
-            # TODO: Snowpark df identity transform with annotation
+            # Preserve partition hint for file output control
+            # This handles both repartition(n) with shuffle=True and coalesce(n) with shuffle=False
             result = map_relation(rel.repartition.input)
+            if rel.repartition.num_partitions > 0:
+                result.partition_hint = rel.repartition.num_partitions
         case "repartition_by_expression":
             # This is a no-op operation in SAS as Snowpark doesn't have the concept of partitions.
             # All the data in the dataframe will be treated as a single partition, and this will not
             # have any side effects.
             result = map_relation(rel.repartition_by_expression.input)
+            # Only preserve partition hint if num_partitions is explicitly specified and > 0
+            # Column-based repartitioning without count should clear any existing partition hints
+            if rel.repartition_by_expression.num_partitions > 0:
+                result.partition_hint = rel.repartition_by_expression.num_partitions
+            else:
+                # Column-based repartitioning clears partition hint (resets to default behavior)
+                result.partition_hint = None
         case "replace":
             result = map_row_ops.map_replace(rel)
         case "sample":
snowflake/snowpark_connect/relation/map_row_ops.py

@@ -553,7 +553,14 @@ def map_filter(
         rel.filter.condition, input_container.column_map, typer
     )
 
-    result = input_df.filter(condition.col)
+    if rel.filter.input.WhichOneof("rel_type") == "subquery_alias":
+        # map_subquery_alias does not actually wrap the DataFrame in an alias or subquery.
+        # Apparently, there are cases (e.g., TpcdsQ53) where this is required, without it, we get
+        # SQL compilation error.
+        # To mitigate it, we are doing .select("*"), .alias() introduces additional describe queries
+        result = input_df.select("*").filter(condition.col)
+    else:
+        result = input_df.filter(condition.col)
 
     return DataFrameContainer(
         result,
snowflake/snowpark_connect/relation/map_udtf.py

@@ -31,6 +31,10 @@ from snowflake.snowpark_connect.type_mapping import (
     proto_to_snowpark_type,
 )
 from snowflake.snowpark_connect.utils.context import push_udtf_context
+from snowflake.snowpark_connect.utils.external_udxf_cache import (
+    cache_external_udtf,
+    get_external_udtf_from_cache,
+)
 from snowflake.snowpark_connect.utils.session import get_or_create_snowpark_session
 from snowflake.snowpark_connect.utils.udtf_helper import (
     SnowparkUDTF,
@@ -44,6 +48,34 @@ from snowflake.snowpark_connect.utils.udxf_import_utils import (
 )
 
 
+def cache_external_udtf_wrapper(from_register_udtf: bool):
+    def outer_wrapper(wrapper_func):
+        def wrapper(
+            udtf_proto: relation_proto.CommonInlineUserDefinedTableFunction,
+            spark_column_names,
+        ) -> SnowparkUDTF | None:
+            udf_hash = hash(str(udtf_proto))
+            cached_udtf = get_external_udtf_from_cache(udf_hash)
+
+            if cached_udtf:
+                if from_register_udtf:
+                    session = get_or_create_snowpark_session()
+                    session._udtfs[udtf_proto.function_name.lower()] = (
+                        cached_udtf,
+                        spark_column_names,
+                    )
+
+                return cached_udtf
+
+            snowpark_udf = wrapper_func(udtf_proto, spark_column_names)
+            cache_external_udtf(udf_hash, snowpark_udf)
+            return snowpark_udf
+
+        return wrapper
+
+    return outer_wrapper
+
+
 def build_expected_types_from_parsed(
     parsed_return: types_proto.DataType,
 ) -> List[Tuple[str, Any]]:
@@ -165,26 +197,37 @@ def register_udtf(
     ) = process_return_type(python_udft.return_type)
     function_name = udtf_proto.function_name
 
-    kwargs = {
-        "session": session,
-        "udtf_proto": udtf_proto,
-        "expected_types": expected_types,
-        "output_schema": output_schema,
-        "packages": global_config.get("snowpark.connect.udf.packages", ""),
-        "imports": get_python_udxf_import_files(session),
-        "called_from": "register_udtf",
-        "is_arrow_enabled": is_arrow_enabled_in_udtf(),
-        "is_spark_compatible_udtf_mode_enabled": is_spark_compatible_udtf_mode_enabled(),
-    }
-
-    if require_creating_udtf_in_sproc(udtf_proto):
-        snowpark_udtf = create_udtf_in_sproc(**kwargs)
-    else:
-        udtf = create_udtf(**kwargs)
-        snowpark_udtf = SnowparkUDTF(
-            name=udtf.name, input_types=udtf._input_types, output_schema=output_schema
-        )
+    @cache_external_udtf_wrapper(from_register_udtf=True)
+    def _register_udtf(
+        udtf_proto: relation_proto.CommonInlineUserDefinedTableFunction,
+        spark_column_names,
+    ):
+        kwargs = {
+            "session": session,
+            "udtf_proto": udtf_proto,
+            "expected_types": expected_types,
+            "output_schema": output_schema,
+            "packages": global_config.get("snowpark.connect.udf.packages", ""),
+            "imports": get_python_udxf_import_files(session),
+            "called_from": "register_udtf",
+            "is_arrow_enabled": is_arrow_enabled_in_udtf(),
+            "is_spark_compatible_udtf_mode_enabled": is_spark_compatible_udtf_mode_enabled(),
+        }
+
+        if require_creating_udtf_in_sproc(udtf_proto):
+            snowpark_udtf = create_udtf_in_sproc(**kwargs)
+        else:
+            udtf = create_udtf(**kwargs)
+            snowpark_udtf = SnowparkUDTF(
+                name=udtf.name,
+                input_types=udtf._input_types,
+                output_schema=output_schema,
+            )
+
+        return snowpark_udtf
 
+    snowpark_udtf = _register_udtf(udtf_proto, spark_column_names)
+    # We have to update cached _udtfs here, because function could have been cached in map_common_inline_user_defined_table_function
     session._udtfs[function_name.lower()] = (snowpark_udtf, spark_column_names)
     return snowpark_udtf
 
@@ -213,32 +256,41 @@ def map_common_inline_user_defined_table_function(
         spark_column_names,
     ) = process_return_type(python_udft.return_type)
 
-    kwargs = {
-        "session": session,
-        "udtf_proto": rel,
-        "expected_types": expected_types,
-        "output_schema": output_schema,
-        "packages": global_config.get("snowpark.connect.udf.packages", ""),
-        "imports": get_python_udxf_import_files(session),
-        "called_from": "map_common_inline_user_defined_table_function",
-        "is_arrow_enabled": is_arrow_enabled_in_udtf(),
-        "is_spark_compatible_udtf_mode_enabled": is_spark_compatible_udtf_mode_enabled(),
-    }
-
-    if require_creating_udtf_in_sproc(rel):
-        snowpark_udtf_or_error = create_udtf_in_sproc(**kwargs)
-        if isinstance(snowpark_udtf_or_error, str):
-            raise PythonException(snowpark_udtf_or_error)
-        snowpark_udtf = snowpark_udtf_or_error
-    else:
-        udtf_or_error = create_udtf(**kwargs)
-        if isinstance(udtf_or_error, str):
-            raise PythonException(udtf_or_error)
-        udtf = udtf_or_error
-        snowpark_udtf = SnowparkUDTF(
-            name=udtf.name, input_types=udtf._input_types, output_schema=output_schema
-        )
+    @cache_external_udtf_wrapper(from_register_udtf=False)
+    def _get_udtf(
+        udtf_proto: relation_proto.CommonInlineUserDefinedTableFunction,
+        spark_column_names,
+    ):
+        kwargs = {
+            "session": session,
+            "udtf_proto": udtf_proto,
+            "expected_types": expected_types,
+            "output_schema": output_schema,
+            "packages": global_config.get("snowpark.connect.udf.packages", ""),
+            "imports": get_python_udxf_import_files(session),
+            "called_from": "map_common_inline_user_defined_table_function",
+            "is_arrow_enabled": is_arrow_enabled_in_udtf(),
+            "is_spark_compatible_udtf_mode_enabled": is_spark_compatible_udtf_mode_enabled(),
+        }
+
+        if require_creating_udtf_in_sproc(udtf_proto):
+            snowpark_udtf_or_error = create_udtf_in_sproc(**kwargs)
+            if isinstance(snowpark_udtf_or_error, str):
+                raise PythonException(snowpark_udtf_or_error)
+            snowpark_udtf = snowpark_udtf_or_error
+        else:
+            udtf_or_error = create_udtf(**kwargs)
+            if isinstance(udtf_or_error, str):
+                raise PythonException(udtf_or_error)
+            udtf = udtf_or_error
+            snowpark_udtf = SnowparkUDTF(
+                name=udtf.name,
+                input_types=udtf._input_types,
+                output_schema=output_schema,
+            )
+        return snowpark_udtf
 
+    snowpark_udtf = _get_udtf(rel, spark_column_names)
     column_map = ColumnNameMap([], [])
     snowpark_udtf_args = []
 
snowflake/snowpark_connect/relation/utils.py

@@ -6,6 +6,7 @@ import random
 import re
 import string
 import time
+import uuid
 from typing import Sequence
 
 import pyspark.sql.connect.proto.relations_pb2 as relation_proto
@@ -153,6 +154,49 @@ def random_string(
     return "".join([prefix, random_part, suffix])
 
 
+def generate_spark_compatible_filename(
+    task_id: int = 0,
+    attempt_number: int = 0,
+    compression: str = None,
+    format_ext: str = "parquet",
+) -> str:
+    """Generate a Spark-compatible filename following the convention:
+    part-<task-id>-<uuid>-c<attempt-number>.<compression>.<format>
+
+    Args:
+        task_id: Task ID (usually 0 for single partition)
+        attempt_number: Attempt number (usually 0)
+        compression: Compression type (e.g., 'snappy', 'gzip', 'none')
+        format_ext: File format extension (e.g., 'parquet', 'csv', 'json')
+
+    Returns:
+        A filename string following Spark's naming convention
+    """
+    # Generate a UUID for uniqueness
+    file_uuid = str(uuid.uuid4())
+
+    # Format task ID with leading zeros (5 digits)
+    formatted_task_id = f"{task_id:05d}"
+
+    # Format attempt number with leading zeros (3 digits)
+    formatted_attempt = f"{attempt_number:03d}"
+
+    # Build the base filename
+    base_name = f"part-{formatted_task_id}-{file_uuid}-c{formatted_attempt}"
+
+    # Add compression if specified and not 'none'
+    if compression and compression.lower() not in ("none", "uncompressed"):
+        compression_part = f".{compression.lower()}"
+    else:
+        compression_part = ""
+
+    # Add format extension if specified
+    if format_ext:
+        return f"{base_name}{compression_part}.{format_ext}"
+    else:
+        return f"{base_name}{compression_part}"
+
+
 def _normalize_query_for_semantic_hash(query_str: str) -> str:
     """
     Normalize a query string for semantic comparison by extracting original names from
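The new generate_spark_compatible_filename helper documents the part-<task-id>-<uuid>-c<attempt-number>[.<compression>].<format> convention. A quick standalone check of the resulting shape (the formatting is re-implemented locally here rather than imported from the package):

    import uuid

    task_id, attempt_number, compression, format_ext = 0, 0, "snappy", "parquet"
    name = f"part-{task_id:05d}-{uuid.uuid4()}-c{attempt_number:03d}.{compression}.{format_ext}"
    print(name)  # e.g. part-00000-8d0f9f3a-...-c000.snappy.parquet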