snowpark-connect 0.25.0__py3-none-any.whl → 0.26.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of snowpark-connect might be problematic.
- snowflake/snowpark_connect/config.py +10 -0
- snowflake/snowpark_connect/dataframe_container.py +16 -0
- snowflake/snowpark_connect/expression/map_udf.py +68 -27
- snowflake/snowpark_connect/expression/map_unresolved_function.py +22 -21
- snowflake/snowpark_connect/includes/jars/json4s-native_2.12-3.7.0-M11.jar +0 -0
- snowflake/snowpark_connect/includes/jars/paranamer-2.8.3.jar +0 -0
- snowflake/snowpark_connect/includes/jars/sas-scala-udf_2.12-0.1.0.jar +0 -0
- snowflake/snowpark_connect/relation/map_map_partitions.py +9 -4
- snowflake/snowpark_connect/relation/map_relation.py +12 -1
- snowflake/snowpark_connect/relation/map_row_ops.py +8 -1
- snowflake/snowpark_connect/relation/map_udtf.py +96 -44
- snowflake/snowpark_connect/relation/utils.py +44 -0
- snowflake/snowpark_connect/relation/write/map_write.py +113 -22
- snowflake/snowpark_connect/resources_initializer.py +18 -5
- snowflake/snowpark_connect/server.py +8 -1
- snowflake/snowpark_connect/utils/concurrent.py +4 -0
- snowflake/snowpark_connect/utils/external_udxf_cache.py +36 -0
- snowflake/snowpark_connect/utils/scala_udf_utils.py +250 -242
- snowflake/snowpark_connect/utils/session.py +4 -0
- snowflake/snowpark_connect/utils/udf_utils.py +7 -17
- snowflake/snowpark_connect/utils/udtf_utils.py +3 -16
- snowflake/snowpark_connect/version.py +1 -1
- {snowpark_connect-0.25.0.dist-info → snowpark_connect-0.26.0.dist-info}/METADATA +1 -1
- {snowpark_connect-0.25.0.dist-info → snowpark_connect-0.26.0.dist-info}/RECORD +32 -28
- {snowpark_connect-0.25.0.data → snowpark_connect-0.26.0.data}/scripts/snowpark-connect +0 -0
- {snowpark_connect-0.25.0.data → snowpark_connect-0.26.0.data}/scripts/snowpark-session +0 -0
- {snowpark_connect-0.25.0.data → snowpark_connect-0.26.0.data}/scripts/snowpark-submit +0 -0
- {snowpark_connect-0.25.0.dist-info → snowpark_connect-0.26.0.dist-info}/WHEEL +0 -0
- {snowpark_connect-0.25.0.dist-info → snowpark_connect-0.26.0.dist-info}/licenses/LICENSE-binary +0 -0
- {snowpark_connect-0.25.0.dist-info → snowpark_connect-0.26.0.dist-info}/licenses/LICENSE.txt +0 -0
- {snowpark_connect-0.25.0.dist-info → snowpark_connect-0.26.0.dist-info}/licenses/NOTICE-binary +0 -0
- {snowpark_connect-0.25.0.dist-info → snowpark_connect-0.26.0.dist-info}/top_level.txt +0 -0
snowflake/snowpark_connect/utils/scala_udf_utils.py

@@ -15,10 +15,10 @@ Key components:
 - Type mapping functions for different type systems
 - UDF creation and management utilities
 """
-
+import re
 from dataclasses import dataclass
 from enum import Enum
-from typing import
+from typing import List, Union

 import snowflake.snowpark.types as snowpark_type
 import snowflake.snowpark_connect.includes.python.pyspark.sql.connect.proto.types_pb2 as types_proto
@@ -130,12 +130,13 @@ class ScalaUDFDef:
 name: str
 signature: Signature
 scala_signature: Signature
+scala_invocation_args: List[str]
 imports: List[str]
 null_handling: NullHandling = NullHandling.RETURNS_NULL_ON_NULL_INPUT

 # -------------------- DDL Emitter --------------------

-def
+def _gen_body_scala(self) -> str:
 """
 Generate the Scala code body for the UDF.

@@ -145,36 +146,78 @@ class ScalaUDFDef:
 Returns:
 String containing the complete Scala code for the UDF body
 """
-
-
-cast_scala_input_types = (
+# Convert Array to Seq for Scala compatibility in function signatures.
+udf_func_input_types = (
 ", ".join(p.data_type for p in self.scala_signature.params)
 ).replace("Array", "Seq")
-
+# Create the Scala arguments and input types string: "arg0: Type0, arg1: Type1, ...".
+joined_wrapper_arg_and_input_types_str = ", ".join(
 f"{p.name}: {p.data_type}" for p in self.scala_signature.params
 )
-
-
+# This is used in defining the input types for the wrapper function. For Maps to work correctly with Scala UDFs,
+# we need to set the Map types to Map[String, String]. These get cast to the respective original types
+# when the original UDF function is invoked.
+wrapper_arg_and_input_types_str = re.sub(
+pattern=r"Map\[\w+,\s\w+\]",
+repl="Map[String, String]",
+string=joined_wrapper_arg_and_input_types_str,
+)
+invocation_args = ", ".join(self.scala_invocation_args)
+
+# Cannot directly return a map from a Scala UDF due to issues with non-String values. Snowflake SQL Scala only
+# supports Map[String, String] as input types. Therefore, we convert the map to a JSON string before returning.
+# This is processed as a Variant by SQL.
+udf_func_return_type = self.scala_signature.returns.data_type
+is_map_return = udf_func_return_type.startswith("Map")
+wrapper_return_type = "String" if is_map_return else udf_func_return_type
+
+# Need to call the map to JSON string converter when a map is returned by the user's function.
+invoke_udf_func = (
+f"write(func({invocation_args}))"
+if is_map_return
+else f"func({invocation_args})"
+)
+
+# The lines of code below are required only when a Map is returned by the UDF. This is needed to serialize the
+# map output to a JSON string.
+map_return_imports = (
+""
+if not is_map_return
+else """
+import org.json4s._
+import org.json4s.native.Serialization._
+import org.json4s.native.Serialization
+"""
+)
+map_return_formatter = (
+""
+if not is_map_return
+else """
+implicit val formats = Serialization.formats(NoTypeHints)
+"""
+)

+return f"""import org.apache.spark.sql.connect.common.UdfPacket
+{map_return_imports}
 import java.io.{{ByteArrayInputStream, ObjectInputStream}}
 import java.nio.file.{{Files, Paths}}

-object
-
-lazy val func: ({
+object __RecreatedSparkUdf {{
+{map_return_formatter}
+private lazy val func: ({udf_func_input_types}) => {udf_func_return_type} = {{
 val importDirectory = System.getProperty("com.snowflake.import_directory")
 val fPath = importDirectory + "{self.name}.bin"
 val bytes = Files.readAllBytes(Paths.get(fPath))
 val ois = new ObjectInputStream(new ByteArrayInputStream(bytes))
 try {{
-ois.readObject().asInstanceOf[UdfPacket].function.asInstanceOf[({
+ois.readObject().asInstanceOf[UdfPacket].function.asInstanceOf[({udf_func_input_types}) => {udf_func_return_type}]
 }} finally {{
 ois.close()
 }}
 }}

-def
-
+def __wrapperFunc({wrapper_arg_and_input_types_str}): {wrapper_return_type} = {{
+{invoke_udf_func}
 }}
 }}
 """
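The wrapper signature rewrite above comes down to a single regular-expression substitution. A minimal sketch of just that step, using a made-up Scala parameter list (only the pattern and replacement are taken from the diff):

    import re

    # Hypothetical wrapper signature string; the regex and replacement mirror the diff above.
    sig = "arg0: Map[String, Int], arg1: Long"
    wrapper_sig = re.sub(r"Map\[\w+,\s\w+\]", "Map[String, String]", sig)
    print(wrapper_sig)  # arg0: Map[String, String], arg1: Long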
@@ -210,14 +253,14 @@ LANGUAGE SCALA
 RUNTIME_VERSION = 2.12
 PACKAGES = ('com.snowflake:snowpark:latest')
 {imports_sql}
-HANDLER = '
+HANDLER = '__RecreatedSparkUdf.__wrapperFunc'
 AS
 $$
-{self.
+{self._gen_body_scala()}
 $$;"""


-def build_scala_udf_imports(session, payload, udf_name):
+def build_scala_udf_imports(session, payload, udf_name, is_map_return) -> List[str]:
 """
 Build the list of imports needed for the Scala UDF.

@@ -230,6 +273,7 @@ def build_scala_udf_imports(session, payload, udf_name):
 session: Snowpark session
 payload: Binary payload containing the serialized UDF
 udf_name: Name of the UDF (used for the binary file name)
+is_map_return: Indicates if the UDF returns a Map (affects imports)

 Returns:
 List of JAR file paths to be imported by the UDF
@@ -254,14 +298,30 @@ def build_scala_udf_imports(session, payload, udf_name):
 if RESOURCE_PATH not in row[0]:
 # Remove the stage path since it is not properly formatted.
 user_jars.append(row[0][row[0].find("/") :])
+
+# Jars used when the return type is a Map.
+map_jars = (
+[]
+if not is_map_return
+else [
+f"{stage_resource_path}/json4s-core_2.12-3.7.0-M11.jar",
+f"{stage_resource_path}/json4s-native_2.12-3.7.0-M11.jar",
+f"{stage_resource_path}/paranamer-2.8.3.jar",
+]
+)
+
 # Format the user jars to be used in the IMPORTS clause of the stored procedure.
-return
-[… 6 removed lines not shown in the source diff]
+return (
+[
+closure_binary_file,
+f"{stage_resource_path}/spark-connect-client-jvm_2.12-3.5.6.jar",
+f"{stage_resource_path}/spark-common-utils_2.12-3.5.6.jar",
+f"{stage_resource_path}/spark-sql_2.12-3.5.6.jar",
+f"{stage_resource_path}/json4s-ast_2.12-3.7.0-M11.jar",
+]
++ map_jars
++ [f"{stage + jar}" for jar in user_jars]
+)


 def create_scala_udf(pciudf: ProcessCommonInlineUserDefinedFunction) -> ScalaUdf:
@@ -298,49 +358,48 @@ def create_scala_udf(pciudf: ProcessCommonInlineUserDefinedFunction) -> ScalaUdf
 cached_udf = session._udfs[udf_name]
 return ScalaUdf(cached_udf.name, cached_udf.input_types, cached_udf.return_type)

-[… 10 removed lines not shown in the source diff]
-Args:
-pciudf: The UDF definition object
-mapper: Function to map Snowpark types to target type system
+# In case the Scala UDF was created with `spark.udf.register`, the Spark Scala input types (from protobuf) are
+# stored in pciudf.scala_input_types.
+# We cannot rely solely on the inputTypes field from the Scala UDF or the Snowpark input types, since:
+# - spark.udf.register arguments come from the inputTypes field
+# - UDFs created with a data type (like below) do not populate the inputTypes field. This requires the input types
+# inferred by Snowpark. e.g.: udf((i: Long) => (i + 1).toInt, IntegerType)
+input_types = (
+pciudf._scala_input_types if pciudf._scala_input_types else pciudf._input_types
+)

-[… 11 removed lines not shown in the source diff]
+scala_input_params: List[Param] = []
+sql_input_params: List[Param] = []
+scala_invocation_args: List[str] = [] # arguments passed into the udf function
+if input_types: # input_types can be None when no arguments are provided
+for i, input_type in enumerate(input_types):
+param_name = "arg" + str(i)
+# Create the Scala arguments and input types string: "arg0: Type0, arg1: Type1, ...".
+scala_input_params.append(
+Param(param_name, map_type_to_scala_type(input_type))
+)
+# Create the Snowflake SQL arguments and input types string: "arg0 TYPE0, arg1 TYPE1, ...".
+sql_input_params.append(
+Param(param_name, map_type_to_snowflake_type(input_type))
+)
+# In the case of Map input types, we need to cast the argument to the correct type in Scala.
+# Snowflake SQL Scala can only handle MAP[VARCHAR, VARCHAR] as input types.
+scala_invocation_args.append(
+cast_scala_map_args_from_given_type(param_name, input_type)
 )
-else:
-return [
-Param(name=f"arg{i}", data_type=spark_type_mapper(input_type))
-for i, input_type in enumerate(pciudf._scala_input_types)
-]

-
-#
-
-
-
+scala_return_type = map_type_to_scala_type(pciudf._original_return_type)
+# If the SQL return type is a MAP, change this to VARIANT because of issues with Scala UDFs.
+sql_return_type = map_type_to_snowflake_type(pciudf._original_return_type)
+imports = build_scala_udf_imports(
+session,
+pciudf._payload,
+udf_name,
+is_map_return=sql_return_type.startswith("MAP"),
 )
-sql_return_type =
-
-pciudf, map_snowpark_type_to_scala_type, map_spark_type_to_scala_type
+sql_return_type = (
+"VARIANT" if sql_return_type.startswith("MAP") else sql_return_type
 )
-scala_return_type = map_snowpark_type_to_scala_type(pciudf._return_type)

 udf_def = ScalaUDFDef(
 name=udf_name,
@@ -351,6 +410,7 @@ def create_scala_udf(pciudf: ProcessCommonInlineUserDefinedFunction) -> ScalaUdf
 scala_signature=Signature(
 params=scala_input_params, returns=ReturnType(scala_return_type)
 ),
+scala_invocation_args=scala_invocation_args,
 )
 create_udf_sql = udf_def.to_create_function_sql()
 logger.info(f"Creating Scala UDF: {create_udf_sql}")
@@ -358,56 +418,60 @@ def create_scala_udf(pciudf: ProcessCommonInlineUserDefinedFunction) -> ScalaUdf
 return ScalaUdf(udf_name, pciudf._input_types, pciudf._return_type)


-def
-[… 15 removed lines not shown in the source diff]
-match type(t):
-case snowpark_type.ArrayType:
-return f"Array[{map_snowpark_type_to_scala_type(t.element_type)}]"
-case snowpark_type.BinaryType:
+def map_type_to_scala_type(
+t: Union[snowpark_type.DataType, types_proto.DataType]
+) -> str:
+"""Maps a Snowpark or Spark protobuf type to a Scala type string."""
+if not t:
+return "String"
+is_snowpark_type = isinstance(t, snowpark_type.DataType)
+condition = type(t) if is_snowpark_type else t.WhichOneof("kind")
+match condition:
+case snowpark_type.ArrayType | "array":
+return (
+f"Array[{map_type_to_scala_type(t.element_type)}]"
+if is_snowpark_type
+else f"Array[{map_type_to_scala_type(t.array.element_type)}]"
+)
+case snowpark_type.BinaryType | "binary":
 return "Array[Byte]"
-case snowpark_type.BooleanType:
+case snowpark_type.BooleanType | "boolean":
 return "Boolean"
-case snowpark_type.ByteType:
+case snowpark_type.ByteType | "byte":
 return "Byte"
-case snowpark_type.DateType:
+case snowpark_type.DateType | "date":
 return "java.sql.Date"
-case snowpark_type.DecimalType:
+case snowpark_type.DecimalType | "decimal":
 return "java.math.BigDecimal"
-case snowpark_type.DoubleType:
+case snowpark_type.DoubleType | "double":
 return "Double"
-case snowpark_type.FloatType:
+case snowpark_type.FloatType | "float":
 return "Float"
 case snowpark_type.GeographyType:
 return "Geography"
-case snowpark_type.IntegerType:
+case snowpark_type.IntegerType | "integer":
 return "Int"
-case snowpark_type.LongType:
+case snowpark_type.LongType | "long":
 return "Long"
-case snowpark_type.MapType: # can also map to OBJECT in Snowflake
-key_type =
-
+case snowpark_type.MapType | "map": # can also map to OBJECT in Snowflake
+key_type = (
+map_type_to_scala_type(t.key_type)
+if is_snowpark_type
+else map_type_to_scala_type(t.map.key_type)
+)
+value_type = (
+map_type_to_scala_type(t.value_type)
+if is_snowpark_type
+else map_type_to_scala_type(t.map.value_type)
+)
 return f"Map[{key_type}, {value_type}]"
-case snowpark_type.NullType:
+case snowpark_type.NullType | "null":
 return "String" # cannot set the return type to Null in Snowpark Scala UDFs
-case snowpark_type.ShortType:
+case snowpark_type.ShortType | "short":
 return "Short"
-case snowpark_type.StringType:
+case snowpark_type.StringType | "string" | "char" | "varchar":
 return "String"
-case snowpark_type.TimestampType:
+case snowpark_type.TimestampType | "timestamp" | "timestamp_ntz":
 return "java.sql.Timestamp"
 case snowpark_type.VariantType:
 return "Variant"
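The rewritten mapper accepts either a Snowpark type or a Spark Connect protobuf type and dispatches on whichever it receives. A minimal usage sketch for the Snowpark path; the import path is assumed from the file listing above:

    import snowflake.snowpark.types as snowpark_type
    from snowflake.snowpark_connect.utils.scala_udf_utils import map_type_to_scala_type

    # A Snowpark MapType becomes a Scala Map with recursively mapped key/value types.
    t = snowpark_type.MapType(snowpark_type.StringType(), snowpark_type.LongType())
    print(map_type_to_scala_type(t))  # Map[String, Long]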
@@ -415,61 +479,65 @@ def map_snowpark_type_to_scala_type(t: snowpark_type.DataType) -> str:
 raise ValueError(f"Unsupported Snowpark type: {t}")


-def
-[… 15 removed lines not shown in the source diff]
-match type(t):
-case snowpark_type.ArrayType:
-return f"ARRAY({map_snowpark_type_to_snowflake_type(t.element_type)})"
-case snowpark_type.BinaryType:
+def map_type_to_snowflake_type(
+t: Union[snowpark_type.DataType, types_proto.DataType]
+) -> str:
+"""Maps a Snowpark or Spark protobuf type to a Snowflake type string."""
+if not t:
+return "VARCHAR"
+is_snowpark_type = isinstance(t, snowpark_type.DataType)
+condition = type(t) if is_snowpark_type else t.WhichOneof("kind")
+match condition:
+case snowpark_type.ArrayType | "array":
+return (
+f"ARRAY({map_type_to_snowflake_type(t.element_type)})"
+if is_snowpark_type
+else f"ARRAY({map_type_to_snowflake_type(t.array.element_type)})"
+)
+case snowpark_type.BinaryType | "binary":
 return "BINARY"
-case snowpark_type.BooleanType:
+case snowpark_type.BooleanType | "boolean":
 return "BOOLEAN"
-case snowpark_type.ByteType:
+case snowpark_type.ByteType | "byte":
 return "TINYINT"
-case snowpark_type.DateType:
+case snowpark_type.DateType | "date":
 return "DATE"
-case snowpark_type.DecimalType:
+case snowpark_type.DecimalType | "decimal":
 return "NUMBER"
-case snowpark_type.DoubleType:
+case snowpark_type.DoubleType | "double":
 return "DOUBLE"
-case snowpark_type.FloatType:
+case snowpark_type.FloatType | "float":
 return "FLOAT"
 case snowpark_type.GeographyType:
 return "GEOGRAPHY"
-case snowpark_type.IntegerType:
+case snowpark_type.IntegerType | "integer":
 return "INT"
-case snowpark_type.LongType:
+case snowpark_type.LongType | "long":
 return "BIGINT"
-case snowpark_type.MapType:
+case snowpark_type.MapType | "map":
 # Maps to OBJECT in Snowflake if key and value types are not specified.
-key_type =
-
+key_type = (
+map_type_to_snowflake_type(t.key_type)
+if is_snowpark_type
+else map_type_to_snowflake_type(t.map.key_type)
+)
+value_type = (
+map_type_to_snowflake_type(t.value_type)
+if is_snowpark_type
+else map_type_to_snowflake_type(t.map.value_type)
+)
 return (
 f"MAP({key_type}, {value_type})"
 if key_type and value_type
 else "OBJECT"
 )
-case snowpark_type.NullType:
+case snowpark_type.NullType | "null":
 return "VARCHAR"
-case snowpark_type.ShortType:
+case snowpark_type.ShortType | "short":
 return "SMALLINT"
-case snowpark_type.StringType:
+case snowpark_type.StringType | "string" | "char" | "varchar":
 return "VARCHAR"
-case snowpark_type.TimestampType:
+case snowpark_type.TimestampType | "timestamp" | "timestamp_ntz":
 return "TIMESTAMP"
 case snowpark_type.VariantType:
 return "VARIANT"
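The Snowflake-type mapper follows the same dual-input pattern; for protobuf input the match is driven by WhichOneof("kind"). A sketch of that path, assuming the vendored types_pb2 module exposes the standard Spark Connect DataType message (import paths assumed from the file listing above):

    import snowflake.snowpark_connect.includes.python.pyspark.sql.connect.proto.types_pb2 as types_proto
    from snowflake.snowpark_connect.utils.scala_udf_utils import map_type_to_snowflake_type

    # WhichOneof("kind") resolves to "string", which maps to VARCHAR.
    t = types_proto.DataType(string=types_proto.DataType.String())
    print(map_type_to_snowflake_type(t))  # VARCHAR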
@@ -477,112 +545,52 @@ def map_snowpark_type_to_snowflake_type(t: snowpark_type.DataType) -> str:
 raise ValueError(f"Unsupported Snowpark type: {t}")


-def
-[… 48 removed lines not shown in the source diff]
-case _:
-raise ValueError(f"Unsupported Spark type: {t}")
-
-
-def map_spark_type_to_snowflake_type(t: types_proto.DataType) -> str:
-"""
-Maps a Spark DataType (from protobuf) to a Snowflake type string.
-
-Converts Spark protobuf DataType objects to their corresponding Snowflake SQL type names.
-This mapping is used when working with Spark Connect protobuf types in Snowflake UDFs.
-
-Args:
-t: Spark protobuf DataType to convert
-
-Returns:
-String representation of the corresponding Snowflake type
-
-Raises:
-ValueError: If the Spark type is not supported
-"""
-match t.WhichOneof("kind"):
-case "array":
-return f"ARRAY({map_spark_type_to_snowflake_type(t.array.element_type)})"
-case "binary":
-return "BINARY"
-case "boolean":
-return "BOOLEAN"
-case "byte":
-return "TINYINT"
-case "date":
-return "DATE"
-case "decimal":
-return "NUMBER"
-case "double":
-return "DOUBLE"
-case "float":
-return "FLOAT"
-case "integer":
-return "INT"
-case "long":
-return "BIGINT"
-case "map":
-# Maps to OBJECT in Snowflake if key and value types are not specified.
-key_type = map_spark_type_to_snowflake_type(t.map.key_type)
-value_type = map_spark_type_to_snowflake_type(t.map.value_type)
-return (
-f"MAP({key_type}, {value_type})"
-if key_type and value_type
-else "OBJECT"
-)
-case "null":
-return "VARCHAR"
-case "short":
-return "SMALLINT"
-case "string" | "char" | "varchar":
-return "VARCHAR"
-case "timestamp" | "timestamp_ntz":
-return "TIMESTAMP"
-case _:
-raise ValueError(f"Unsupported Spark type: {t}")
+def cast_scala_map_args_from_given_type(
+arg_name: str, input_type: Union[snowpark_type.DataType, types_proto.DataType]
+) -> str:
+"""If the input_type is a Map, cast the argument arg_name to a Map[key_type, value_type] in Scala."""
+is_snowpark_type = isinstance(input_type, snowpark_type.DataType)
+
+def convert_from_string_to_type(
+arg_name: str, t: Union[snowpark_type.DataType, types_proto.DataType]
+) -> str:
+"""Convert the string argument arg_name to the specified type t in Scala."""
+condition = type(t) if is_snowpark_type else t.WhichOneof("kind")
+match condition:
+case snowpark_type.BinaryType | "binary":
+return arg_name + ".getBytes()"
+case snowpark_type.BooleanType | "boolean":
+return arg_name + ".toBoolean"
+case snowpark_type.ByteType | "byte":
+return arg_name + ".getBytes().head" # TODO: verify if this is correct
+case snowpark_type.DateType | "date":
+return f"java.sql.Date.valueOf({arg_name})"
+case snowpark_type.DecimalType | "decimal":
+return f"new BigDecimal({arg_name})"
+case snowpark_type.DoubleType | "double":
+return arg_name + ".toDouble"
+case snowpark_type.FloatType | "float":
+return arg_name + ".toFloat"
+case snowpark_type.IntegerType | "integer":
+return arg_name + ".toInt"
+case snowpark_type.LongType | "long":
+return arg_name + ".toLong"
+case snowpark_type.ShortType | "short":
+return arg_name + ".toShort"
+case snowpark_type.StringType | "string" | "char" | "varchar":
+return arg_name
+case snowpark_type.TimestampType | "timestamp" | "timestamp_ntz":
+return "java.sql.Timestamp.valueOf({arg_name})"
+case _:
+raise ValueError(f"Unsupported Snowpark type: {t}")
+
+if (is_snowpark_type and isinstance(input_type, snowpark_type.MapType)) or (
+not is_snowpark_type and input_type.WhichOneof("kind") == "map"
+):
+key_type = input_type.key_type if is_snowpark_type else input_type.map.key_type
+value_type = (
+input_type.value_type if is_snowpark_type else input_type.map.value_type
+)
+return f"{arg_name}.map {{ case (k, v) => ({convert_from_string_to_type('k', key_type)}, {convert_from_string_to_type('v', value_type)})}}"
+else:
+return arg_name
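The new helper emits the Scala expression that converts a Map argument (delivered to the UDF as Map[String, String]) back to its declared key and value types. A hypothetical sketch; the argument names are made up and the import path is assumed from the file listing above:

    import snowflake.snowpark.types as snowpark_type
    from snowflake.snowpark_connect.utils.scala_udf_utils import (
        cast_scala_map_args_from_given_type,
    )

    # Map input: keys stay strings, values get a .toInt cast in the generated Scala.
    t = snowpark_type.MapType(snowpark_type.StringType(), snowpark_type.IntegerType())
    print(cast_scala_map_args_from_given_type("arg0", t))
    # arg0.map { case (k, v) => (k, v.toInt)}

    # Non-map inputs pass through unchanged.
    print(cast_scala_map_args_from_given_type("arg1", snowpark_type.LongType()))  # arg1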
snowflake/snowpark_connect/utils/session.py

@@ -14,6 +14,9 @@ from snowflake.snowpark_connect.constants import DEFAULT_CONNECTION_NAME
 from snowflake.snowpark_connect.utils.describe_query_cache import (
 instrument_session_for_describe_cache,
 )
+from snowflake.snowpark_connect.utils.external_udxf_cache import (
+init_external_udxf_cache,
+)
 from snowflake.snowpark_connect.utils.snowpark_connect_logging import logger
 from snowflake.snowpark_connect.utils.telemetry import telemetry
 from snowflake.snowpark_connect.utils.udf_cache import init_builtin_udf_cache
@@ -63,6 +66,7 @@ def configure_snowpark_session(session: snowpark.Session):

 # built-in udf cache
 init_builtin_udf_cache(session)
+init_external_udxf_cache(session)

 # Set experimental parameters (warnings globally suppressed)
 session.ast_enabled = False