snowpark-connect 0.25.0__py3-none-any.whl → 0.27.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (41)
  1. snowflake/snowpark_connect/config.py +10 -3
  2. snowflake/snowpark_connect/dataframe_container.py +16 -0
  3. snowflake/snowpark_connect/expression/map_expression.py +15 -0
  4. snowflake/snowpark_connect/expression/map_udf.py +68 -27
  5. snowflake/snowpark_connect/expression/map_unresolved_attribute.py +18 -0
  6. snowflake/snowpark_connect/expression/map_unresolved_function.py +38 -28
  7. snowflake/snowpark_connect/includes/jars/json4s-native_2.12-3.7.0-M11.jar +0 -0
  8. snowflake/snowpark_connect/includes/jars/paranamer-2.8.3.jar +0 -0
  9. snowflake/snowpark_connect/includes/jars/sas-scala-udf_2.12-0.1.0.jar +0 -0
  10. snowflake/snowpark_connect/relation/map_extension.py +9 -7
  11. snowflake/snowpark_connect/relation/map_map_partitions.py +36 -72
  12. snowflake/snowpark_connect/relation/map_relation.py +15 -2
  13. snowflake/snowpark_connect/relation/map_row_ops.py +8 -1
  14. snowflake/snowpark_connect/relation/map_show_string.py +2 -0
  15. snowflake/snowpark_connect/relation/map_sql.py +63 -2
  16. snowflake/snowpark_connect/relation/map_udtf.py +96 -44
  17. snowflake/snowpark_connect/relation/utils.py +44 -0
  18. snowflake/snowpark_connect/relation/write/map_write.py +135 -24
  19. snowflake/snowpark_connect/resources_initializer.py +18 -5
  20. snowflake/snowpark_connect/server.py +12 -2
  21. snowflake/snowpark_connect/utils/artifacts.py +4 -5
  22. snowflake/snowpark_connect/utils/concurrent.py +4 -0
  23. snowflake/snowpark_connect/utils/context.py +41 -1
  24. snowflake/snowpark_connect/utils/external_udxf_cache.py +36 -0
  25. snowflake/snowpark_connect/utils/pandas_udtf_utils.py +86 -2
  26. snowflake/snowpark_connect/utils/scala_udf_utils.py +250 -242
  27. snowflake/snowpark_connect/utils/session.py +4 -0
  28. snowflake/snowpark_connect/utils/udf_utils.py +71 -118
  29. snowflake/snowpark_connect/utils/udtf_helper.py +17 -7
  30. snowflake/snowpark_connect/utils/udtf_utils.py +3 -16
  31. snowflake/snowpark_connect/version.py +2 -3
  32. {snowpark_connect-0.25.0.dist-info → snowpark_connect-0.27.0.dist-info}/METADATA +2 -2
  33. {snowpark_connect-0.25.0.dist-info → snowpark_connect-0.27.0.dist-info}/RECORD +41 -37
  34. {snowpark_connect-0.25.0.data → snowpark_connect-0.27.0.data}/scripts/snowpark-connect +0 -0
  35. {snowpark_connect-0.25.0.data → snowpark_connect-0.27.0.data}/scripts/snowpark-session +0 -0
  36. {snowpark_connect-0.25.0.data → snowpark_connect-0.27.0.data}/scripts/snowpark-submit +0 -0
  37. {snowpark_connect-0.25.0.dist-info → snowpark_connect-0.27.0.dist-info}/WHEEL +0 -0
  38. {snowpark_connect-0.25.0.dist-info → snowpark_connect-0.27.0.dist-info}/licenses/LICENSE-binary +0 -0
  39. {snowpark_connect-0.25.0.dist-info → snowpark_connect-0.27.0.dist-info}/licenses/LICENSE.txt +0 -0
  40. {snowpark_connect-0.25.0.dist-info → snowpark_connect-0.27.0.dist-info}/licenses/NOTICE-binary +0 -0
  41. {snowpark_connect-0.25.0.dist-info → snowpark_connect-0.27.0.dist-info}/top_level.txt +0 -0
snowflake/snowpark_connect/utils/scala_udf_utils.py
@@ -15,10 +15,10 @@ Key components:
  - Type mapping functions for different type systems
  - UDF creation and management utilities
  """
-
+ import re
  from dataclasses import dataclass
  from enum import Enum
- from typing import Callable, List
+ from typing import List, Union

  import snowflake.snowpark.types as snowpark_type
  import snowflake.snowpark_connect.includes.python.pyspark.sql.connect.proto.types_pb2 as types_proto
@@ -130,12 +130,13 @@ class ScalaUDFDef:
  name: str
  signature: Signature
  scala_signature: Signature
+ scala_invocation_args: List[str]
  imports: List[str]
  null_handling: NullHandling = NullHandling.RETURNS_NULL_ON_NULL_INPUT

  # -------------------- DDL Emitter --------------------

- def _gen_body_sql(self) -> str:
+ def _gen_body_scala(self) -> str:
  """
  Generate the Scala code body for the UDF.

@@ -145,36 +146,78 @@ class ScalaUDFDef:
  Returns:
  String containing the complete Scala code for the UDF body
  """
- scala_return_type = self.scala_signature.returns.data_type
- # Convert Array to Seq for Scala compatibility in function signatures
- cast_scala_input_types = (
+ # Convert Array to Seq for Scala compatibility in function signatures.
+ udf_func_input_types = (
  ", ".join(p.data_type for p in self.scala_signature.params)
  ).replace("Array", "Seq")
- scala_arg_and_input_types_str = ", ".join(
+ # Create the Scala arguments and input types string: "arg0: Type0, arg1: Type1, ...".
+ joined_wrapper_arg_and_input_types_str = ", ".join(
  f"{p.name}: {p.data_type}" for p in self.scala_signature.params
  )
- scala_args_str = ", ".join(f"{p.name}" for p in self.scala_signature.params)
- return f"""import org.apache.spark.sql.connect.common.UdfPacket
+ # This is used in defining the input types for the wrapper function. For Maps to work correctly with Scala UDFs,
+ # we need to set the Map types to Map[String, String]. These get cast to the respective original types
+ # when the original UDF function is invoked.
+ wrapper_arg_and_input_types_str = re.sub(
+ pattern=r"Map\[\w+,\s\w+\]",
+ repl="Map[String, String]",
+ string=joined_wrapper_arg_and_input_types_str,
+ )
+ invocation_args = ", ".join(self.scala_invocation_args)
+
+ # Cannot directly return a map from a Scala UDF due to issues with non-String values. Snowflake SQL Scala only
+ # supports Map[String, String] as input types. Therefore, we convert the map to a JSON string before returning.
+ # This is processed as a Variant by SQL.
+ udf_func_return_type = self.scala_signature.returns.data_type
+ is_map_return = udf_func_return_type.startswith("Map")
+ wrapper_return_type = "String" if is_map_return else udf_func_return_type
+
+ # Need to call the map to JSON string converter when a map is returned by the user's function.
+ invoke_udf_func = (
+ f"write(func({invocation_args}))"
+ if is_map_return
+ else f"func({invocation_args})"
+ )
+
+ # The lines of code below are required only when a Map is returned by the UDF. This is needed to serialize the
+ # map output to a JSON string.
+ map_return_imports = (
+ ""
+ if not is_map_return
+ else """
+ import org.json4s._
+ import org.json4s.native.Serialization._
+ import org.json4s.native.Serialization
+ """
+ )
+ map_return_formatter = (
+ ""
+ if not is_map_return
+ else """
+ implicit val formats = Serialization.formats(NoTypeHints)
+ """
+ )

+ return f"""import org.apache.spark.sql.connect.common.UdfPacket
+ {map_return_imports}
  import java.io.{{ByteArrayInputStream, ObjectInputStream}}
  import java.nio.file.{{Files, Paths}}

- object SparkUdf {{
-
- lazy val func: ({cast_scala_input_types}) => {scala_return_type} = {{
+ object __RecreatedSparkUdf {{
+ {map_return_formatter}
+ private lazy val func: ({udf_func_input_types}) => {udf_func_return_type} = {{
  val importDirectory = System.getProperty("com.snowflake.import_directory")
  val fPath = importDirectory + "{self.name}.bin"
  val bytes = Files.readAllBytes(Paths.get(fPath))
  val ois = new ObjectInputStream(new ByteArrayInputStream(bytes))
  try {{
- ois.readObject().asInstanceOf[UdfPacket].function.asInstanceOf[({cast_scala_input_types}) => {scala_return_type}]
+ ois.readObject().asInstanceOf[UdfPacket].function.asInstanceOf[({udf_func_input_types}) => {udf_func_return_type}]
  }} finally {{
  ois.close()
  }}
  }}

- def run({scala_arg_and_input_types_str}): {scala_return_type} = {{
- func({scala_args_str})
+ def __wrapperFunc({wrapper_arg_and_input_types_str}): {wrapper_return_type} = {{
+ {invoke_udf_func}
  }}
  }}
  """
@@ -210,14 +253,14 @@ LANGUAGE SCALA
  RUNTIME_VERSION = 2.12
  PACKAGES = ('com.snowflake:snowpark:latest')
  {imports_sql}
- HANDLER = 'SparkUdf.run'
+ HANDLER = '__RecreatedSparkUdf.__wrapperFunc'
  AS
  $$
- {self._gen_body_sql()}
+ {self._gen_body_scala()}
  $$;"""


- def build_scala_udf_imports(session, payload, udf_name):
+ def build_scala_udf_imports(session, payload, udf_name, is_map_return) -> List[str]:
  """
  Build the list of imports needed for the Scala UDF.

@@ -230,6 +273,7 @@ def build_scala_udf_imports(session, payload, udf_name):
  session: Snowpark session
  payload: Binary payload containing the serialized UDF
  udf_name: Name of the UDF (used for the binary file name)
+ is_map_return: Indicates if the UDF returns a Map (affects imports)

  Returns:
  List of JAR file paths to be imported by the UDF
@@ -254,14 +298,30 @@ def build_scala_udf_imports(session, payload, udf_name):
  if RESOURCE_PATH not in row[0]:
  # Remove the stage path since it is not properly formatted.
  user_jars.append(row[0][row[0].find("/") :])
+
+ # Jars used when the return type is a Map.
+ map_jars = (
+ []
+ if not is_map_return
+ else [
+ f"{stage_resource_path}/json4s-core_2.12-3.7.0-M11.jar",
+ f"{stage_resource_path}/json4s-native_2.12-3.7.0-M11.jar",
+ f"{stage_resource_path}/paranamer-2.8.3.jar",
+ ]
+ )
+
  # Format the user jars to be used in the IMPORTS clause of the stored procedure.
- return [
- closure_binary_file,
- f"{stage_resource_path}/spark-connect-client-jvm_2.12-3.5.6.jar",
- f"{stage_resource_path}/spark-common-utils_2.12-3.5.6.jar",
- f"{stage_resource_path}/spark-sql_2.12-3.5.6.jar",
- f"{stage_resource_path}/json4s-ast_2.12-3.7.0-M11.jar",
- ] + [f"{stage + jar}" for jar in user_jars]
+ return (
+ [
+ closure_binary_file,
+ f"{stage_resource_path}/spark-connect-client-jvm_2.12-3.5.6.jar",
+ f"{stage_resource_path}/spark-common-utils_2.12-3.5.6.jar",
+ f"{stage_resource_path}/spark-sql_2.12-3.5.6.jar",
+ f"{stage_resource_path}/json4s-ast_2.12-3.7.0-M11.jar",
+ ]
+ + map_jars
+ + [f"{stage + jar}" for jar in user_jars]
+ )


  def create_scala_udf(pciudf: ProcessCommonInlineUserDefinedFunction) -> ScalaUdf:
@@ -298,49 +358,48 @@ def create_scala_udf(pciudf: ProcessCommonInlineUserDefinedFunction) -> ScalaUdf
  cached_udf = session._udfs[udf_name]
  return ScalaUdf(cached_udf.name, cached_udf.input_types, cached_udf.return_type)

- imports = build_scala_udf_imports(session, pciudf._payload, udf_name)
-
- def _build_params(
- pciudf: ProcessCommonInlineUserDefinedFunction,
- snowpark_type_mapper: Callable[[snowpark_type.DataType], str],
- spark_type_mapper: Callable[[types_proto.DataType], str],
- ) -> List[Param]:
- """
- Build the parameter list for the UDF signature.
-
- Args:
- pciudf: The UDF definition object
- mapper: Function to map Snowpark types to target type system
+ # In case the Scala UDF was created with `spark.udf.register`, the Spark Scala input types (from protobuf) are
+ # stored in pciudf.scala_input_types.
+ # We cannot rely solely on the inputTypes field from the Scala UDF or the Snowpark input types, since:
+ # - spark.udf.register arguments come from the inputTypes field
+ # - UDFs created with a data type (like below) do not populate the inputTypes field. This requires the input types
+ # inferred by Snowpark. e.g.: udf((i: Long) => (i + 1).toInt, IntegerType)
+ input_types = (
+ pciudf._scala_input_types if pciudf._scala_input_types else pciudf._input_types
+ )

- Returns:
- List of Param objects representing the function parameters
- """
- if not pciudf._scala_input_types:
- return (
- [
- Param(name=f"arg{i}", data_type=snowpark_type_mapper(input_type))
- for i, input_type in enumerate(pciudf._input_types)
- ]
- if pciudf._input_types
- else []
+ scala_input_params: List[Param] = []
+ sql_input_params: List[Param] = []
+ scala_invocation_args: List[str] = [] # arguments passed into the udf function
+ if input_types: # input_types can be None when no arguments are provided
+ for i, input_type in enumerate(input_types):
+ param_name = "arg" + str(i)
+ # Create the Scala arguments and input types string: "arg0: Type0, arg1: Type1, ...".
+ scala_input_params.append(
+ Param(param_name, map_type_to_scala_type(input_type))
+ )
+ # Create the Snowflake SQL arguments and input types string: "arg0 TYPE0, arg1 TYPE1, ...".
+ sql_input_params.append(
+ Param(param_name, map_type_to_snowflake_type(input_type))
+ )
+ # In the case of Map input types, we need to cast the argument to the correct type in Scala.
+ # Snowflake SQL Scala can only handle MAP[VARCHAR, VARCHAR] as input types.
+ scala_invocation_args.append(
+ cast_scala_map_args_from_given_type(param_name, input_type)
  )
- else:
- return [
- Param(name=f"arg{i}", data_type=spark_type_mapper(input_type))
- for i, input_type in enumerate(pciudf._scala_input_types)
- ]

- # Create the Scala arguments and input types string: "arg0: Type0, arg1: Type1, ...".
- # In case the Scala UDF was created with `spark.udf.register`, the Spark Scala input types (from protobuf) are
- # stored in pciudf.scala_input_types.
- sql_input_params = _build_params(
- pciudf, map_snowpark_type_to_snowflake_type, map_spark_type_to_snowflake_type
+ scala_return_type = map_type_to_scala_type(pciudf._original_return_type)
+ # If the SQL return type is a MAP, change this to VARIANT because of issues with Scala UDFs.
+ sql_return_type = map_type_to_snowflake_type(pciudf._original_return_type)
+ imports = build_scala_udf_imports(
+ session,
+ pciudf._payload,
+ udf_name,
+ is_map_return=sql_return_type.startswith("MAP"),
  )
- sql_return_type = map_snowpark_type_to_snowflake_type(pciudf._return_type)
- scala_input_params = _build_params(
- pciudf, map_snowpark_type_to_scala_type, map_spark_type_to_scala_type
+ sql_return_type = (
+ "VARIANT" if sql_return_type.startswith("MAP") else sql_return_type
  )
- scala_return_type = map_snowpark_type_to_scala_type(pciudf._return_type)

  udf_def = ScalaUDFDef(
  name=udf_name,
@@ -351,6 +410,7 @@ def create_scala_udf(pciudf: ProcessCommonInlineUserDefinedFunction) -> ScalaUdf
  scala_signature=Signature(
  params=scala_input_params, returns=ReturnType(scala_return_type)
  ),
+ scala_invocation_args=scala_invocation_args,
  )
  create_udf_sql = udf_def.to_create_function_sql()
  logger.info(f"Creating Scala UDF: {create_udf_sql}")
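
The return-type handling in the hunks above can be traced with the new unified mapper: a Map return first maps to a Snowflake MAP type, which sets is_map_return (and pulls the json4s jars into the IMPORTS list) and is then rewritten to VARIANT for the CREATE FUNCTION statement. A hedged sketch with an illustrative return type, assuming the helper added in this diff is importable from snowflake.snowpark_connect.utils.scala_udf_utils:

# Hedged sketch, not part of the package: trace the Map -> VARIANT return path.
import snowflake.snowpark.types as snowpark_type
from snowflake.snowpark_connect.utils.scala_udf_utils import map_type_to_snowflake_type

ret = snowpark_type.MapType(snowpark_type.StringType(), snowpark_type.LongType())
sql_return_type = map_type_to_snowflake_type(ret)
print(sql_return_type)                     # MAP(VARCHAR, BIGINT)
is_map_return = sql_return_type.startswith("MAP")
print("VARIANT" if is_map_return else sql_return_type)  # VARIANT
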
@@ -358,56 +418,60 @@ def create_scala_udf(pciudf: ProcessCommonInlineUserDefinedFunction) -> ScalaUdf
  return ScalaUdf(udf_name, pciudf._input_types, pciudf._return_type)


- def map_snowpark_type_to_scala_type(t: snowpark_type.DataType) -> str:
- """
- Maps a Snowpark type to a Scala type string.
-
- Converts Snowpark DataType objects to their corresponding Scala type names.
- This mapping is used when generating Scala code for UDFs.
-
- Args:
- t: Snowpark DataType to convert
-
- Returns:
- String representation of the corresponding Scala type
-
- Raises:
- ValueError: If the Snowpark type is not supported
- """
- match type(t):
- case snowpark_type.ArrayType:
- return f"Array[{map_snowpark_type_to_scala_type(t.element_type)}]"
- case snowpark_type.BinaryType:
+ def map_type_to_scala_type(
+ t: Union[snowpark_type.DataType, types_proto.DataType]
+ ) -> str:
+ """Maps a Snowpark or Spark protobuf type to a Scala type string."""
+ if not t:
+ return "String"
+ is_snowpark_type = isinstance(t, snowpark_type.DataType)
+ condition = type(t) if is_snowpark_type else t.WhichOneof("kind")
+ match condition:
+ case snowpark_type.ArrayType | "array":
+ return (
+ f"Array[{map_type_to_scala_type(t.element_type)}]"
+ if is_snowpark_type
+ else f"Array[{map_type_to_scala_type(t.array.element_type)}]"
+ )
+ case snowpark_type.BinaryType | "binary":
  return "Array[Byte]"
- case snowpark_type.BooleanType:
+ case snowpark_type.BooleanType | "boolean":
  return "Boolean"
- case snowpark_type.ByteType:
+ case snowpark_type.ByteType | "byte":
  return "Byte"
- case snowpark_type.DateType:
+ case snowpark_type.DateType | "date":
  return "java.sql.Date"
- case snowpark_type.DecimalType:
+ case snowpark_type.DecimalType | "decimal":
  return "java.math.BigDecimal"
- case snowpark_type.DoubleType:
+ case snowpark_type.DoubleType | "double":
  return "Double"
- case snowpark_type.FloatType:
+ case snowpark_type.FloatType | "float":
  return "Float"
  case snowpark_type.GeographyType:
  return "Geography"
- case snowpark_type.IntegerType:
+ case snowpark_type.IntegerType | "integer":
  return "Int"
- case snowpark_type.LongType:
+ case snowpark_type.LongType | "long":
  return "Long"
- case snowpark_type.MapType: # can also map to OBJECT in Snowflake
- key_type = map_snowpark_type_to_scala_type(t.key_type)
- value_type = map_snowpark_type_to_scala_type(t.value_type)
+ case snowpark_type.MapType | "map": # can also map to OBJECT in Snowflake
+ key_type = (
+ map_type_to_scala_type(t.key_type)
+ if is_snowpark_type
+ else map_type_to_scala_type(t.map.key_type)
+ )
+ value_type = (
+ map_type_to_scala_type(t.value_type)
+ if is_snowpark_type
+ else map_type_to_scala_type(t.map.value_type)
+ )
  return f"Map[{key_type}, {value_type}]"
- case snowpark_type.NullType:
+ case snowpark_type.NullType | "null":
  return "String" # cannot set the return type to Null in Snowpark Scala UDFs
- case snowpark_type.ShortType:
+ case snowpark_type.ShortType | "short":
  return "Short"
- case snowpark_type.StringType:
+ case snowpark_type.StringType | "string" | "char" | "varchar":
  return "String"
- case snowpark_type.TimestampType:
+ case snowpark_type.TimestampType | "timestamp" | "timestamp_ntz":
  return "java.sql.Timestamp"
  case snowpark_type.VariantType:
  return "Variant"
@@ -415,61 +479,65 @@ def map_snowpark_type_to_scala_type(t: snowpark_type.DataType) -> str:
  raise ValueError(f"Unsupported Snowpark type: {t}")


- def map_snowpark_type_to_snowflake_type(t: snowpark_type.DataType) -> str:
- """
- Maps a Snowpark type to a Snowflake type string.
-
- Converts Snowpark DataType objects to their corresponding Snowflake SQL type names.
- This mapping is used when generating CREATE FUNCTION SQL statements.
-
- Args:
- t: Snowpark DataType to convert
-
- Returns:
- String representation of the corresponding Snowflake type
-
- Raises:
- ValueError: If the Snowpark type is not supported
- """
- match type(t):
- case snowpark_type.ArrayType:
- return f"ARRAY({map_snowpark_type_to_snowflake_type(t.element_type)})"
- case snowpark_type.BinaryType:
+ def map_type_to_snowflake_type(
+ t: Union[snowpark_type.DataType, types_proto.DataType]
+ ) -> str:
+ """Maps a Snowpark or Spark protobuf type to a Snowflake type string."""
+ if not t:
+ return "VARCHAR"
+ is_snowpark_type = isinstance(t, snowpark_type.DataType)
+ condition = type(t) if is_snowpark_type else t.WhichOneof("kind")
+ match condition:
+ case snowpark_type.ArrayType | "array":
+ return (
+ f"ARRAY({map_type_to_snowflake_type(t.element_type)})"
+ if is_snowpark_type
+ else f"ARRAY({map_type_to_snowflake_type(t.array.element_type)})"
+ )
+ case snowpark_type.BinaryType | "binary":
  return "BINARY"
- case snowpark_type.BooleanType:
+ case snowpark_type.BooleanType | "boolean":
  return "BOOLEAN"
- case snowpark_type.ByteType:
+ case snowpark_type.ByteType | "byte":
  return "TINYINT"
- case snowpark_type.DateType:
+ case snowpark_type.DateType | "date":
  return "DATE"
- case snowpark_type.DecimalType:
+ case snowpark_type.DecimalType | "decimal":
  return "NUMBER"
- case snowpark_type.DoubleType:
+ case snowpark_type.DoubleType | "double":
  return "DOUBLE"
- case snowpark_type.FloatType:
+ case snowpark_type.FloatType | "float":
  return "FLOAT"
  case snowpark_type.GeographyType:
  return "GEOGRAPHY"
- case snowpark_type.IntegerType:
+ case snowpark_type.IntegerType | "integer":
  return "INT"
- case snowpark_type.LongType:
+ case snowpark_type.LongType | "long":
  return "BIGINT"
- case snowpark_type.MapType:
+ case snowpark_type.MapType | "map":
  # Maps to OBJECT in Snowflake if key and value types are not specified.
- key_type = map_snowpark_type_to_snowflake_type(t.key_type)
- value_type = map_snowpark_type_to_snowflake_type(t.value_type)
+ key_type = (
+ map_type_to_snowflake_type(t.key_type)
+ if is_snowpark_type
+ else map_type_to_snowflake_type(t.map.key_type)
+ )
+ value_type = (
+ map_type_to_snowflake_type(t.value_type)
+ if is_snowpark_type
+ else map_type_to_snowflake_type(t.map.value_type)
+ )
  return (
  f"MAP({key_type}, {value_type})"
  if key_type and value_type
  else "OBJECT"
  )
- case snowpark_type.NullType:
+ case snowpark_type.NullType | "null":
  return "VARCHAR"
- case snowpark_type.ShortType:
+ case snowpark_type.ShortType | "short":
  return "SMALLINT"
- case snowpark_type.StringType:
+ case snowpark_type.StringType | "string" | "char" | "varchar":
  return "VARCHAR"
- case snowpark_type.TimestampType:
+ case snowpark_type.TimestampType | "timestamp" | "timestamp_ntz":
  return "TIMESTAMP"
  case snowpark_type.VariantType:
  return "VARIANT"
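
The two replacement mappers above accept either a Snowpark DataType instance or a Spark Connect protobuf DataType, dispatching on type(t) for the former and t.WhichOneof("kind") for the latter. A hedged sketch of calling each flavor, assuming the new helpers and the vendored types_pb2 module are importable exactly as named in this diff:

# Hedged sketch: one helper serves both type systems.
import snowflake.snowpark.types as snowpark_type
import snowflake.snowpark_connect.includes.python.pyspark.sql.connect.proto.types_pb2 as types_proto
from snowflake.snowpark_connect.utils.scala_udf_utils import (
    map_type_to_scala_type,
    map_type_to_snowflake_type,
)

# Snowpark type object: dispatched on type(t).
print(map_type_to_snowflake_type(snowpark_type.LongType()))  # BIGINT
# Spark Connect protobuf type: dispatched on t.WhichOneof("kind").
proto_long = types_proto.DataType(long=types_proto.DataType.Long())
print(map_type_to_scala_type(proto_long))  # Long
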
@@ -477,112 +545,52 @@ def map_snowpark_type_to_snowflake_type(t: snowpark_type.DataType) -> str:
  raise ValueError(f"Unsupported Snowpark type: {t}")


- def map_spark_type_to_scala_type(t: types_proto.DataType) -> str:
- """
- Maps a Spark DataType (from protobuf) to a Scala type string.
-
- Converts Spark protobuf DataType objects to their corresponding Scala type names.
- This mapping is used when working with Spark Connect protobuf types.
-
- Args:
- t: Spark protobuf DataType to convert
-
- Returns:
- String representation of the corresponding Scala type
-
- Raises:
- ValueError: If the Spark type is not supported
- """
- match t.WhichOneof("kind"):
- case "array":
- return f"Array[{map_spark_type_to_scala_type(t.array.element_type)}]"
- case "binary":
- return "Array[Byte]"
- case "boolean":
- return "Boolean"
- case "byte":
- return "Byte"
- case "date":
- return "java.sql.Date"
- case "decimal":
- return "java.math.BigDecimal"
- case "double":
- return "Double"
- case "float":
- return "Float"
- case "integer":
- return "Int"
- case "long":
- return "Long"
- case "map":
- key_type = map_spark_type_to_scala_type(t.map.key_type)
- value_type = map_spark_type_to_scala_type(t.map.value_type)
- return f"Map[{key_type}, {value_type}]"
- case "null":
- return "String" # cannot set the return type to Null in Snowpark Scala UDFs
- case "short":
- return "Short"
- case "string" | "char" | "varchar":
- return "String"
- case "timestamp" | "timestamp_ntz":
- return "java.sql.Timestamp"
- case _:
- raise ValueError(f"Unsupported Spark type: {t}")
-
-
- def map_spark_type_to_snowflake_type(t: types_proto.DataType) -> str:
- """
- Maps a Spark DataType (from protobuf) to a Snowflake type string.
-
- Converts Spark protobuf DataType objects to their corresponding Snowflake SQL type names.
- This mapping is used when working with Spark Connect protobuf types in Snowflake UDFs.
-
- Args:
- t: Spark protobuf DataType to convert
-
- Returns:
- String representation of the corresponding Snowflake type
-
- Raises:
- ValueError: If the Spark type is not supported
- """
- match t.WhichOneof("kind"):
- case "array":
- return f"ARRAY({map_spark_type_to_snowflake_type(t.array.element_type)})"
- case "binary":
- return "BINARY"
- case "boolean":
- return "BOOLEAN"
- case "byte":
- return "TINYINT"
- case "date":
- return "DATE"
- case "decimal":
- return "NUMBER"
- case "double":
- return "DOUBLE"
- case "float":
- return "FLOAT"
- case "integer":
- return "INT"
- case "long":
- return "BIGINT"
- case "map":
- # Maps to OBJECT in Snowflake if key and value types are not specified.
- key_type = map_spark_type_to_snowflake_type(t.map.key_type)
- value_type = map_spark_type_to_snowflake_type(t.map.value_type)
- return (
- f"MAP({key_type}, {value_type})"
- if key_type and value_type
- else "OBJECT"
- )
- case "null":
- return "VARCHAR"
- case "short":
- return "SMALLINT"
- case "string" | "char" | "varchar":
- return "VARCHAR"
- case "timestamp" | "timestamp_ntz":
- return "TIMESTAMP"
- case _:
- raise ValueError(f"Unsupported Spark type: {t}")
+ def cast_scala_map_args_from_given_type(
+ arg_name: str, input_type: Union[snowpark_type.DataType, types_proto.DataType]
+ ) -> str:
+ """If the input_type is a Map, cast the argument arg_name to a Map[key_type, value_type] in Scala."""
+ is_snowpark_type = isinstance(input_type, snowpark_type.DataType)
+
+ def convert_from_string_to_type(
+ arg_name: str, t: Union[snowpark_type.DataType, types_proto.DataType]
+ ) -> str:
+ """Convert the string argument arg_name to the specified type t in Scala."""
+ condition = type(t) if is_snowpark_type else t.WhichOneof("kind")
+ match condition:
+ case snowpark_type.BinaryType | "binary":
+ return arg_name + ".getBytes()"
+ case snowpark_type.BooleanType | "boolean":
+ return arg_name + ".toBoolean"
+ case snowpark_type.ByteType | "byte":
+ return arg_name + ".getBytes().head" # TODO: verify if this is correct
+ case snowpark_type.DateType | "date":
+ return f"java.sql.Date.valueOf({arg_name})"
+ case snowpark_type.DecimalType | "decimal":
+ return f"new BigDecimal({arg_name})"
+ case snowpark_type.DoubleType | "double":
+ return arg_name + ".toDouble"
+ case snowpark_type.FloatType | "float":
+ return arg_name + ".toFloat"
+ case snowpark_type.IntegerType | "integer":
+ return arg_name + ".toInt"
+ case snowpark_type.LongType | "long":
+ return arg_name + ".toLong"
+ case snowpark_type.ShortType | "short":
+ return arg_name + ".toShort"
+ case snowpark_type.StringType | "string" | "char" | "varchar":
+ return arg_name
+ case snowpark_type.TimestampType | "timestamp" | "timestamp_ntz":
+ return f"java.sql.Timestamp.valueOf({arg_name})"
+ case _:
+ raise ValueError(f"Unsupported Snowpark type: {t}")
+
+ if (is_snowpark_type and isinstance(input_type, snowpark_type.MapType)) or (
+ not is_snowpark_type and input_type.WhichOneof("kind") == "map"
+ ):
+ key_type = input_type.key_type if is_snowpark_type else input_type.map.key_type
+ value_type = (
+ input_type.value_type if is_snowpark_type else input_type.map.value_type
+ )
+ return f"{arg_name}.map {{ case (k, v) => ({convert_from_string_to_type('k', key_type)}, {convert_from_string_to_type('v', value_type)})}}"
+ else:
+ return arg_name
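
cast_scala_map_args_from_given_type only rewrites map-typed arguments; any other argument is passed through by name. A hedged example of the Scala expressions it emits, assuming the helper is importable from snowflake.snowpark_connect.utils.scala_udf_utils and using illustrative argument types:

# Hedged sketch: Scala cast expressions emitted for the generated wrapper body.
import snowflake.snowpark.types as snowpark_type
from snowflake.snowpark_connect.utils.scala_udf_utils import (
    cast_scala_map_args_from_given_type,
)

map_arg = snowpark_type.MapType(snowpark_type.StringType(), snowpark_type.LongType())
print(cast_scala_map_args_from_given_type("arg0", map_arg))
# arg0.map { case (k, v) => (k, v.toLong)}
print(cast_scala_map_args_from_given_type("arg1", snowpark_type.LongType()))
# arg1
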
@@ -14,6 +14,9 @@ from snowflake.snowpark_connect.constants import DEFAULT_CONNECTION_NAME
  from snowflake.snowpark_connect.utils.describe_query_cache import (
  instrument_session_for_describe_cache,
  )
+ from snowflake.snowpark_connect.utils.external_udxf_cache import (
+ init_external_udxf_cache,
+ )
  from snowflake.snowpark_connect.utils.snowpark_connect_logging import logger
  from snowflake.snowpark_connect.utils.telemetry import telemetry
  from snowflake.snowpark_connect.utils.udf_cache import init_builtin_udf_cache
@@ -63,6 +66,7 @@ def configure_snowpark_session(session: snowpark.Session):

  # built-in udf cache
  init_builtin_udf_cache(session)
+ init_external_udxf_cache(session)

  # Set experimental parameters (warnings globally suppressed)
  session.ast_enabled = False