snowpark-connect 0.21.0__py3-none-any.whl → 0.22.1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

This version of snowpark-connect has been flagged as potentially problematic.

Files changed (41)
  1. snowflake/snowpark_connect/config.py +19 -3
  2. snowflake/snowpark_connect/error/error_utils.py +25 -0
  3. snowflake/snowpark_connect/expression/map_udf.py +4 -4
  4. snowflake/snowpark_connect/expression/map_unresolved_function.py +203 -128
  5. snowflake/snowpark_connect/proto/snowflake_expression_ext_pb2_grpc.py +4 -0
  6. snowflake/snowpark_connect/proto/snowflake_relation_ext_pb2_grpc.py +4 -0
  7. snowflake/snowpark_connect/relation/map_aggregate.py +102 -18
  8. snowflake/snowpark_connect/relation/map_column_ops.py +21 -2
  9. snowflake/snowpark_connect/relation/map_map_partitions.py +3 -1
  10. snowflake/snowpark_connect/relation/map_sql.py +18 -191
  11. snowflake/snowpark_connect/relation/map_udtf.py +4 -4
  12. snowflake/snowpark_connect/relation/read/map_read_json.py +12 -1
  13. snowflake/snowpark_connect/relation/read/reader_config.py +1 -0
  14. snowflake/snowpark_connect/relation/write/map_write.py +68 -24
  15. snowflake/snowpark_connect/server.py +9 -0
  16. snowflake/snowpark_connect/type_mapping.py +4 -0
  17. snowflake/snowpark_connect/utils/describe_query_cache.py +2 -9
  18. snowflake/snowpark_connect/utils/session.py +0 -4
  19. snowflake/snowpark_connect/utils/telemetry.py +213 -61
  20. snowflake/snowpark_connect/utils/udxf_import_utils.py +14 -0
  21. snowflake/snowpark_connect/version.py +1 -1
  22. snowflake/snowpark_decoder/__init__.py +0 -0
  23. snowflake/snowpark_decoder/_internal/proto/generated/DataframeProcessorMsg_pb2.py +36 -0
  24. snowflake/snowpark_decoder/_internal/proto/generated/DataframeProcessorMsg_pb2.pyi +156 -0
  25. snowflake/snowpark_decoder/dp_session.py +111 -0
  26. snowflake/snowpark_decoder/spark_decoder.py +76 -0
  27. {snowpark_connect-0.21.0.dist-info → snowpark_connect-0.22.1.dist-info}/METADATA +2 -2
  28. {snowpark_connect-0.21.0.dist-info → snowpark_connect-0.22.1.dist-info}/RECORD +40 -29
  29. {snowpark_connect-0.21.0.dist-info → snowpark_connect-0.22.1.dist-info}/top_level.txt +1 -0
  30. spark/__init__.py +0 -0
  31. spark/connect/__init__.py +0 -0
  32. spark/connect/envelope_pb2.py +31 -0
  33. spark/connect/envelope_pb2.pyi +46 -0
  34. snowflake/snowpark_connect/includes/jars/jackson-mapper-asl-1.9.13.jar +0 -0
  35. {snowpark_connect-0.21.0.data → snowpark_connect-0.22.1.data}/scripts/snowpark-connect +0 -0
  36. {snowpark_connect-0.21.0.data → snowpark_connect-0.22.1.data}/scripts/snowpark-session +0 -0
  37. {snowpark_connect-0.21.0.data → snowpark_connect-0.22.1.data}/scripts/snowpark-submit +0 -0
  38. {snowpark_connect-0.21.0.dist-info → snowpark_connect-0.22.1.dist-info}/WHEEL +0 -0
  39. {snowpark_connect-0.21.0.dist-info → snowpark_connect-0.22.1.dist-info}/licenses/LICENSE-binary +0 -0
  40. {snowpark_connect-0.21.0.dist-info → snowpark_connect-0.22.1.dist-info}/licenses/LICENSE.txt +0 -0
  41. {snowpark_connect-0.21.0.dist-info → snowpark_connect-0.22.1.dist-info}/licenses/NOTICE-binary +0 -0
snowflake/snowpark_connect/relation/write/map_write.py

@@ -214,27 +214,71 @@ def map_write(request: proto_base.ExecutePlanRequest):
                 )
                 snowpark_table_name = _spark_to_snowflake(table_name)
 
-                if write_mode == "overwrite":
-                    if check_snowflake_table_existance(snowpark_table_name, session):
-                        session.sql(f"DELETE FROM {snowpark_table_name}").collect()
-                    write_mode = "append"
-
-                if write_mode in (None, "", "overwrite"):
-                    create_iceberg_table(
-                        snowpark_table_name=snowpark_table_name,
-                        location=write_op.options.get("location", None),
-                        schema=input_df.schema,
-                        snowpark_session=session,
-                    )
-                    write_mode = "append"
-
-                _validate_schema_and_get_writer(
-                    input_df, write_mode, snowpark_table_name
-                ).saveAsTable(
-                    table_name=snowpark_table_name,
-                    mode=write_mode,
-                    column_order=_column_order_for_write,
-                )
+                match write_mode:
+                    case None | "error" | "errorifexists":
+                        if check_snowflake_table_existence(snowpark_table_name, session):
+                            raise AnalysisException(
+                                f"Table {snowpark_table_name} already exists"
+                            )
+                        create_iceberg_table(
+                            snowpark_table_name=snowpark_table_name,
+                            location=write_op.options.get("location", None),
+                            schema=input_df.schema,
+                            snowpark_session=session,
+                        )
+                        _validate_schema_and_get_writer(
+                            input_df, "append", snowpark_table_name
+                        ).saveAsTable(
+                            table_name=snowpark_table_name,
+                            mode="append",
+                            column_order=_column_order_for_write,
+                        )
+                    case "append":
+                        _validate_schema_and_get_writer(
+                            input_df, "append", snowpark_table_name
+                        ).saveAsTable(
+                            table_name=snowpark_table_name,
+                            mode="append",
+                            column_order=_column_order_for_write,
+                        )
+                    case "ignore":
+                        if not check_snowflake_table_existence(
+                            snowpark_table_name, session
+                        ):
+                            create_iceberg_table(
+                                snowpark_table_name=snowpark_table_name,
+                                location=write_op.options.get("location", None),
+                                schema=input_df.schema,
+                                snowpark_session=session,
+                            )
+                            _validate_schema_and_get_writer(
+                                input_df, "append", snowpark_table_name
+                            ).saveAsTable(
+                                table_name=snowpark_table_name,
+                                mode="append",
+                                column_order=_column_order_for_write,
+                            )
+                    case "overwrite":
+                        if check_snowflake_table_existence(snowpark_table_name, session):
+                            session.sql(f"DELETE FROM {snowpark_table_name}").collect()
+                        else:
+                            create_iceberg_table(
+                                snowpark_table_name=snowpark_table_name,
+                                location=write_op.options.get("location", None),
+                                schema=input_df.schema,
+                                snowpark_session=session,
+                            )
+                        _validate_schema_and_get_writer(
+                            input_df, "append", snowpark_table_name
+                        ).saveAsTable(
+                            table_name=snowpark_table_name,
+                            mode="append",
+                            column_order=_column_order_for_write,
+                        )
+                    case _:
+                        raise SnowparkConnectNotImplementedError(
+                            f"Write mode {write_mode} is not supported"
+                        )
             case _:
                 snowpark_table_name = _spark_to_snowflake(write_op.table.table_name)
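The rewritten map_write branch above dispatches explicitly on the Spark save mode instead of coercing every mode to an append up front. As a rough, hedged sketch of the semantics each case encodes (the helper callables here are hypothetical placeholders, not this package's API):

def save_table(mode, table, df, *, table_exists, create_table, truncate_table, append_rows):
    """Sketch of Spark save-mode dispatch; every branch ends in a plain append."""
    if mode in (None, "error", "errorifexists"):
        if table_exists(table):
            raise ValueError(f"Table {table} already exists")
        create_table(table, df.schema)
        append_rows(table, df)
    elif mode == "append":
        append_rows(table, df)
    elif mode == "ignore":
        if not table_exists(table):  # if the table already exists, silently do nothing
            create_table(table, df.schema)
            append_rows(table, df)
    elif mode == "overwrite":
        if table_exists(table):
            truncate_table(table)  # keep the table definition, drop its rows
        else:
            create_table(table, df.schema)
        append_rows(table, df)
    else:
        raise NotImplementedError(f"Write mode {mode} is not supported")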
 
@@ -299,14 +343,14 @@ def map_write_v2(request: proto_base.ExecutePlanRequest):
         commands_proto.WriteOperationV2.MODE_OVERWRITE,
         commands_proto.WriteOperationV2.MODE_APPEND,
     ):
-        if not check_snowflake_table_existance(snowpark_table_name, session):
+        if not check_snowflake_table_existence(snowpark_table_name, session):
             raise AnalysisException(
                 f"[TABLE_OR_VIEW_NOT_FOUND] The table or view `{write_op.table_name}` cannot be found. "
                 f"Verify the spelling and correctness of the schema and catalog.\n"
             )
 
     if write_op.provider.lower() == "iceberg":
-        if write_mode == "overwrite" and check_snowflake_table_existance(
+        if write_mode == "overwrite" and check_snowflake_table_existence(
             snowpark_table_name, session
         ):
             session.sql(f"DELETE FROM {snowpark_table_name}").collect()
@@ -584,7 +628,7 @@ def _truncate_directory(directory_path: Path) -> None:
             shutil.rmtree(file)
 
 
-def check_snowflake_table_existance(
+def check_snowflake_table_existence(
    snowpark_table_name: str,
    snowpark_session: snowpark.Session,
):
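Only the helper's name changes here (the "existance" typo becomes "existence"); its body is outside this hunk. For orientation, a hedged sketch of what such a check can look like with Snowpark, an assumption rather than the package's actual implementation:

import snowflake.snowpark as snowpark


def table_exists(snowpark_table_name: str, snowpark_session: snowpark.Session) -> bool:
    # Hypothetical probe: SHOW TABLES LIKE takes an unqualified name,
    # so strip any database/schema prefix and surrounding quotes first.
    unqualified = snowpark_table_name.split(".")[-1].strip('"')
    rows = snowpark_session.sql(f"SHOW TABLES LIKE '{unqualified}'").collect()
    return len(rows) > 0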
snowflake/snowpark_connect/server.py

@@ -981,6 +981,7 @@ def start_session(
     stop_event: threading.Event = None,
     snowpark_session: Optional[snowpark.Session] = None,
     connection_parameters: Optional[Dict[str, str]] = None,
+    max_grpc_message_size: int = _SPARK_CONNECT_GRPC_MAX_MESSAGE_SIZE,
 ) -> threading.Thread | None:
     """
     Starts Spark Connect server connected to Snowflake. No-op if the Server is already running.
@@ -1003,6 +1004,14 @@ def start_session(
         provided, the `snowpark_session` parameter must be None.
     """
     try:
+        # Changing the value of our global variable based on the grpc message size provided by the user.
+        global _SPARK_CONNECT_GRPC_MAX_MESSAGE_SIZE
+        _SPARK_CONNECT_GRPC_MAX_MESSAGE_SIZE = max_grpc_message_size
+
+        from pyspark.sql.connect.client import ChannelBuilder
+
+        ChannelBuilder.MAX_MESSAGE_LENGTH = max_grpc_message_size
+
         if os.environ.get("SPARK_ENV_LOADED"):
             raise RuntimeError(
                 "Snowpark Connect cannot be run inside of a Spark environment"
snowflake/snowpark_connect/type_mapping.py

@@ -324,6 +324,8 @@ def cast_to_match_snowpark_type(
            return str(content)
        case snowpark.types.VariantType:
            return str(content)
+       case snowpark.types.TimestampType:
+           return str(content)
        case _:
            raise SnowparkConnectNotImplementedError(
                f"Unsupported snowpark data type in casting: {data_type}"
@@ -779,6 +781,8 @@ def map_simple_types(simple_type: str) -> snowpark.types.DataType:
            return snowpark.types.TimestampType()
        case "timestamp_ntz":
            return snowpark.types.TimestampType(snowpark.types.TimestampTimeZone.NTZ)
+       case "timestamp_ltz":
+           return snowpark.types.TimestampType(snowpark.types.TimestampTimeZone.LTZ)
        case "day_time_interval":
            # this is not a column type in snowflake so there won't be a dataframe column
            # with this, for now this type won't make any sense
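With the new branches, "timestamp", "timestamp_ntz", and "timestamp_ltz" all resolve to Snowpark TimestampType with the matching timezone flavor, and timestamp values gain a string cast path. A small hedged check (the import path for map_simple_types is an assumption):

import snowflake.snowpark as snowpark
from snowflake.snowpark_connect.type_mapping import map_simple_types  # assumed import path

ltz = map_simple_types("timestamp_ltz")
ntz = map_simple_types("timestamp_ntz")
assert isinstance(ltz, snowpark.types.TimestampType)
assert isinstance(ntz, snowpark.types.TimestampType)
print(ltz, ntz)  # both TimestampType, with LTZ and NTZ timezone flavors respectively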
snowflake/snowpark_connect/utils/describe_query_cache.py

@@ -131,21 +131,14 @@ def instrument_session_for_describe_cache(session: snowpark.Session):
            logger.debug(f"DDL detected, clearing describe query cache: '{query}'")
            cache.clear()
 
-    def report_query(qid: str, is_internal: bool) -> None:
-        if is_internal:
-            telemetry.report_internal_query()
-        elif qid:
-            telemetry.report_query_id(qid)
-
     def wrap_execute(wrapped_fn):
         def fn(query: str, **kwargs):
             update_cache_for_query(query)
-            is_internal = kwargs.get("_is_internal", False)
             try:
                 result = wrapped_fn(query, **kwargs)
-                report_query(result.sfqid, is_internal)
+                telemetry.report_query(result, **kwargs)
             except Exception as e:
-                report_query(e.sfqid, is_internal)
+                telemetry.report_query(e, **kwargs)
                 raise e
             return result
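The internal/external classification moves out of the cache instrumentation: the wrapper now hands the cursor (or the raised exception) plus the original kwargs straight to Telemetry.report_query. The generic wrap-and-report pattern, as a standalone sketch:

def instrument_execute(execute_fn, telemetry):
    """Wrap a cursor-execute style function so every query is reported once."""

    def wrapped(query: str, **kwargs):
        try:
            result = execute_fn(query, **kwargs)
            telemetry.report_query(result, **kwargs)  # cursors carry sfqid
        except Exception as e:
            telemetry.report_query(e, **kwargs)  # SQL errors also carry sfqid
            raise
        return result

    return wrapped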
 
snowflake/snowpark_connect/utils/session.py

@@ -181,7 +181,3 @@ def set_query_tags(spark_tags: Sequence[str]) -> None:
 
     if spark_tags_str != snowpark_session.query_tag:
         snowpark_session.query_tag = spark_tags_str
-
-
-def get_python_udxf_import_files(session: snowpark.Session) -> str:
-    return ",".join([file for file in [*session._python_files, *session._import_files]])

snowflake/snowpark_connect/utils/telemetry.py

@@ -1,19 +1,22 @@
 #
 # Copyright (c) 2012-2025 Snowflake Computing Inc. All rights reserved.
 #
-
+import functools
 import json
 import os
 import queue
 import threading
+import uuid
 from abc import ABC, abstractmethod
 from collections import defaultdict
+from collections.abc import Iterable
 from contextvars import ContextVar
 from enum import Enum, unique
 from typing import Dict
 
 import google.protobuf.message
 
+from snowflake.connector.cursor import SnowflakeCursor
 from snowflake.connector.telemetry import (
     TelemetryClient as PCTelemetryClient,
     TelemetryData as PCTelemetryData,
@@ -21,7 +24,6 @@ from snowflake.connector.telemetry import (
 )
 from snowflake.connector.time_util import get_time_millis
 from snowflake.snowpark import Session
-from snowflake.snowpark._internal.telemetry import safe_telemetry
 from snowflake.snowpark._internal.utils import get_os_name, get_python_version
 from snowflake.snowpark.version import VERSION as snowpark_version
 from snowflake.snowpark_connect.utils.snowpark_connect_logging import logger
@@ -43,6 +45,7 @@ class TelemetryField(Enum):
     KEY_OS = "operating_system"
     KEY_DATA = "data"
     KEY_START_TIME = "start_time"
+    KEY_EVENT_ID = "event_id"
 
 
 class TelemetryType(Enum):
@@ -107,7 +110,34 @@ REDACTED_PLAN_SUFFIXES = [
 ]
 
 
+def _basic_telemetry_data() -> Dict:
+    return {
+        **STATIC_TELEMETRY_DATA,
+        TelemetryField.KEY_EVENT_ID.value: str(uuid.uuid4()),
+    }
+
+
+def safe(func):
+    """
+    Decorator to safely execute telemetry functions, catching and logging exceptions
+    without affecting the main application flow.
+    """
+
+    @functools.wraps(func)
+    def wrap(*args, **kwargs):
+        try:
+            func(*args, **kwargs)
+        except Exception:
+            # We don't really care if telemetry fails, just want to be safe for the user
+            logger.warning(f"Telemetry operation failed: {func}", exc_info=True)
+
+    return wrap
+
+
 class TelemetrySink(ABC):
+    MAX_BUFFER_ELEMENTS = 20
+    MAX_WAIT_MS = 10000  # 10 seconds
+
     @abstractmethod
     def add_telemetry_data(self, message: dict, timestamp: int) -> None:
         pass
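The new module-level safe decorator replaces Snowpark's internal safe_telemetry: any exception raised inside a decorated telemetry method is logged and swallowed so it can never break a user request. A usage sketch, assuming safe is importable from the patched module:

from snowflake.snowpark_connect.utils.telemetry import safe  # assumed importable


@safe
def flaky_report(events: list, name: str) -> None:
    events.append(name)
    raise RuntimeError("telemetry backend unavailable")  # simulated failure


events = []
flaky_report(events, "server_started")  # the error is logged, not propagated
assert events == ["server_started"]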
@@ -128,23 +158,44 @@ class NoOpTelemetrySink(TelemetrySink):
 class ClientTelemetrySink(TelemetrySink):
     def __init__(self, telemetry_client: PCTelemetryClient) -> None:
         self._telemetry_client = telemetry_client
+        self._lock = threading.Lock()
+        self._reset()
 
     def add_telemetry_data(self, message: dict, timestamp: int) -> None:
         telemetry_data = PCTelemetryData(message=message, timestamp=timestamp)
         self._telemetry_client.try_add_log_to_batch(telemetry_data)
+        with self._lock:
+            self._events_since_last_flush += 1
+        # flush more often than the underlying telemetry client
+        if self._should_flush():
+            self.flush()
 
     def flush(self) -> None:
+        with self._lock:
+            self._reset()
         self._telemetry_client.send_batch()
 
+    def _should_flush(self) -> bool:
+        current_time = get_time_millis()
+
+        return (
+            self._events_since_last_flush >= TelemetrySink.MAX_BUFFER_ELEMENTS
+            or (current_time - self._last_flush_time) >= TelemetrySink.MAX_WAIT_MS
+        )
+
+    def _reset(self):
+        self._events_since_last_flush = 0
+        self._last_flush_time = get_time_millis()
+
 
 
 class QueryTelemetrySink(TelemetrySink):
 
-    MAX_BUFFER_SIZE = 100 * 1024  # 100KB
-    MAX_WAIT_MS = 10000  # 10 seconds
+    MAX_BUFFER_SIZE = 20 * 1024  # 20KB
     TELEMETRY_JOB_ID = "43e72d9b-56d0-4cdb-a615-6b5b5059d6df"
 
     def __init__(self, session: Session) -> None:
         self._session = session
+        self._lock = threading.Lock()
         self._reset()
 
     def add_telemetry_data(self, message: dict, timestamp: int) -> None:
@@ -152,31 +203,37 @@ class QueryTelemetrySink(TelemetrySink):
 
         # stringify entry, and escape single quotes
         entry_str = json.dumps(telemetry_entry).replace("'", "''")
-        self._buffer.append(entry_str)
-        self._buffer_size += len(entry_str)
 
-        current_time = get_time_millis()
-        if (
-            self._buffer_size > QueryTelemetrySink.MAX_BUFFER_SIZE
-            or (current_time - self._last_export_time) > QueryTelemetrySink.MAX_WAIT_MS
-        ):
+        with self._lock:
+            self._buffer.append(entry_str)
+            self._buffer_size += len(entry_str)
+
+        if self._should_flush():
            self.flush()
 
     def flush(self) -> None:
-        if not self._buffer:
-            return
+        with self._lock:
+            if not self._buffer:
+                return
+            # prefix query with a unique identifier for easier tracking
+            query = f"select '{self.TELEMETRY_JOB_ID}' as scos_telemetry_export, '[{','.join(self._buffer)}]'"
+            self._reset()
 
-        # prefix query with a unique identifier for easier tracking
-        query = f"select '{self.TELEMETRY_JOB_ID}' as scos_telemetry_export, '[{','.join(self._buffer)}]'"
         self._session.sql(query).collect_nowait()
 
-        self._reset()
-
     def _reset(self) -> None:
         self._buffer = []
         self._buffer_size = 0
         self._last_export_time = get_time_millis()
 
+    def _should_flush(self):
+        current_time = get_time_millis()
+        return (
+            self._buffer_size >= QueryTelemetrySink.MAX_BUFFER_SIZE
+            or len(self._buffer) >= TelemetrySink.MAX_BUFFER_ELEMENTS
+            or (current_time - self._last_export_time) >= TelemetrySink.MAX_WAIT_MS
+        )
+
 
 class Telemetry:
     def __init__(self, is_enabled=True) -> None:
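Both sinks now share the same flush policy: export as soon as the buffer is big enough (QueryTelemetrySink only), holds enough events, or has waited long enough. A standalone sketch of that check, using the constants from the diff above:

import time

MAX_BUFFER_SIZE = 20 * 1024  # bytes, QueryTelemetrySink only
MAX_BUFFER_ELEMENTS = 20     # events, both sinks
MAX_WAIT_MS = 10_000         # milliseconds, both sinks


def should_flush(buffer_bytes: int, buffer_len: int, last_flush_ms: int) -> bool:
    now_ms = int(time.time() * 1000)
    return (
        buffer_bytes >= MAX_BUFFER_SIZE
        or buffer_len >= MAX_BUFFER_ELEMENTS
        or (now_ms - last_flush_ms) >= MAX_WAIT_MS
    )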
@@ -185,6 +242,8 @@ class Telemetry:
             "request_summary", default={}
         )
         self._is_enabled = is_enabled
+        self._is_initialized = False
+        self._lock = threading.Lock()
 
         # Async processing setup
         self._message_queue = queue.Queue(maxsize=10000)
@@ -202,6 +261,12 @@ class Telemetry:
         if not self._is_enabled:
             return
 
+        with self._lock:
+            if self._is_initialized:
+                logger.warning("Telemetry is already initialized")
+                return
+            self._is_initialized = True
+
         telemetry = getattr(session._conn._conn, "_telemetry", None)
         if telemetry is None:
             # no telemetry client available, so we export with queries
@@ -210,8 +275,9 @@ class Telemetry:
             self._sink = ClientTelemetrySink(telemetry)
 
         self._start_worker_thread()
+        logger.info(f"Telemetry initialized with {type(self._sink)}")
 
-    @safe_telemetry
+    @safe
     def initialize_request_summary(
         self, request: google.protobuf.message.Message
     ) -> None:
@@ -234,8 +300,29 @@ class Telemetry:
             request.plan, REDACTED_PLAN_SUFFIXES
         )
 
-    @safe_telemetry
+    def _not_in_request(self):
+        # we don't want to add things to the summary if it's not initialized
+        return "created_on" not in self._request_summary.get()
+
+    @safe
+    def report_parsed_sql_plan(self, plan: google.protobuf.message.Message) -> None:
+        if self._not_in_request():
+            return
+
+        summary = self._request_summary.get()
+
+        if "parsed_sql_plans" not in summary:
+            summary["parsed_sql_plans"] = []
+
+        summary["parsed_sql_plans"].append(
+            _protobuf_to_json_with_redaction(plan, REDACTED_PLAN_SUFFIXES)
+        )
+
+    @safe
     def report_function_usage(self, function_name: str) -> None:
+        if self._not_in_request():
+            return
+
         summary = self._request_summary.get()
 
         if "used_functions" not in summary:
@@ -243,8 +330,11 @@ class Telemetry:
         summary["used_functions"][function_name] += 1
 
 
-    @safe_telemetry
+    @safe
     def report_request_failure(self, e: Exception) -> None:
+        if self._not_in_request():
+            return
+
         summary = self._request_summary.get()
 
         summary["was_successful"] = False
@@ -255,37 +345,78 @@ class Telemetry:
         if error_location:
             summary["error_location"] = error_location
 
-    @safe_telemetry
-    def report_config_set(self, key, value):
+    @safe
+    def report_config_set(self, pairs: Iterable) -> None:
+        if self._not_in_request():
+            return
+
         summary = self._request_summary.get()
 
         if "config_set" not in summary:
             summary["config_set"] = []
 
-        summary["config_set"].append(
-            {
-                "key": key,
-                "value": value if key in RECORDED_CONFIG_KEYS else "<redacted>",
-            }
-        )
+        for p in pairs:
+            summary["config_set"].append(
+                {
+                    "key": p.key,
+                    "value": p.value if p.key in RECORDED_CONFIG_KEYS else "<redacted>",
+                }
+            )
+
+    @safe
+    def report_config_unset(self, keys: Iterable[str]) -> None:
+        if self._not_in_request():
+            return
 
-    @safe_telemetry
-    def report_config_unset(self, key):
         summary = self._request_summary.get()
 
         if "config_unset" not in summary:
             summary["config_unset"] = []
 
-        summary["config_unset"].append(key)
+        summary["config_unset"].extend(keys)
+
+    @safe
+    def report_config_get(self, keys: Iterable[str]) -> None:
+        if self._not_in_request():
+            return
 
-    @safe_telemetry
-    def report_config_op_type(self, op_type: str):
         summary = self._request_summary.get()
 
+        if "config_get" not in summary:
+            summary["config_get"] = []
+
+        summary["config_get"].extend(keys)
+
+    @safe
+    def report_config_op_type(self, op_type: str):
+        if self._not_in_request():
+            return
+
+        summary = self._request_summary.get()
         summary["config_op_type"] = op_type
 
-    @safe_telemetry
-    def report_query_id(self, query_id: str):
+    @safe
+    def report_query(
+        self, result: SnowflakeCursor | dict | Exception, **kwargs
+    ) -> None:
+        if result is None or isinstance(result, dict) or self._not_in_request():
+            return
+
+        # SnowflakeCursor and SQL errors will have sfqid
+        # other exceptions will not have it
+        # TODO: handle async queries, but filter out telemetry export queries
+        qid = getattr(result, "sfqid", None)
+
+        if qid is None:
+            logger.warning("Missing query id in result: %s", result)
+
+        is_internal = kwargs.get("_is_internal", False)
+        if is_internal:
+            self._report_internal_query()
+        elif qid:
+            self._report_query_id(qid)
+
+    def _report_query_id(self, query_id: str):
         summary = self._request_summary.get()
 
         if "queries" not in summary:
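report_query replaces the old report_query_id/report_internal_query pair and inspects whatever the connector hands back: dict results are ignored, the _is_internal kwarg routes to an internal-query counter, and anything exposing sfqid (cursors, SQL errors) is recorded by query id. A condensed standalone sketch of that dispatch:

def classify_query_report(result, **kwargs):
    """Return which bucket a query result/error would be reported under."""
    if result is None or isinstance(result, dict):
        return "skip"  # nothing useful to record
    if kwargs.get("_is_internal", False):
        return "internal"  # bumps the internal_queries counter
    qid = getattr(result, "sfqid", None)  # cursors and SQL errors carry sfqid
    return f"query:{qid}" if qid else "skip"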
@@ -293,13 +424,19 @@ class Telemetry:
 
         summary["queries"].append(query_id)
 
-    @safe_telemetry
-    def report_internal_query(self):
+    def _report_internal_query(self):
         summary = self._request_summary.get()
+
+        if "internal_queries" not in summary:
+            summary["internal_queries"] = 0
+
         summary["internal_queries"] += 1
 
-    @safe_telemetry
+    @safe
     def report_udf_usage(self, udf_name: str):
+        if self._not_in_request():
+            return
+
         summary = self._request_summary.get()
 
         if "udf_usage" not in summary:
@@ -307,8 +444,10 @@ class Telemetry:
 
         summary["udf_usage"][udf_name] += 1
 
-    @safe_telemetry
-    def report_io(self, op: str, type: str, options: dict | None):
+    def _report_io(self, op: str, type: str, options: dict | None):
+        if self._not_in_request():
+            return
+
         summary = self._request_summary.get()
 
         if "io" not in summary:
@@ -321,16 +460,18 @@ class Telemetry:
         summary["io"].append(io)
 
 
+    @safe
     def report_io_read(self, type: str, options: dict | None):
-        self.report_io("read", type, options)
+        self._report_io("read", type, options)
 
+    @safe
     def report_io_write(self, type: str, options: dict | None):
-        self.report_io("write", type, options)
+        self._report_io("write", type, options)
 
-    @safe_telemetry
+    @safe
     def send_server_started_telemetry(self):
         message = {
-            **STATIC_TELEMETRY_DATA,
+            **_basic_telemetry_data(),
             TelemetryField.KEY_TYPE.value: TelemetryType.TYPE_EVENT.value,
             TelemetryType.EVENT_TYPE.value: EventType.SERVER_STARTED.value,
             TelemetryField.KEY_DATA.value: {
@@ -339,17 +480,22 @@ class Telemetry:
         }
         self._send(message)
 
-    @safe_telemetry
+    @safe
     def send_request_summary_telemetry(self):
+        if self._not_in_request():
+            logger.warning(
+                "Truing to send request summary telemetry without initializing it"
+            )
+            return
+
         summary = self._request_summary.get()
         message = {
-            **STATIC_TELEMETRY_DATA,
+            **_basic_telemetry_data(),
             TelemetryField.KEY_TYPE.value: TelemetryType.TYPE_REQUEST_SUMMARY.value,
             TelemetryField.KEY_DATA.value: summary,
         }
         self._send(message)
 
-    @safe_telemetry
     def _send(self, msg: Dict) -> None:
         """Queue a telemetry message for asynchronous processing."""
         if not self._is_enabled:
@@ -385,19 +531,6 @@ class Telemetry:
             finally:
                 self._message_queue.task_done()
 
-        # Process any remaining messages
-        while not self._message_queue.empty():
-            try:
-                message, timestamp = self._message_queue.get_nowait()
-                self._sink.add_telemetry_data(message, timestamp)
-                self._message_queue.task_done()
-            except Exception:
-                logger.warning(
-                    "Failed to add remaining telemetry messages to sink during shutdown",
-                    exc_info=True,
-                )
-                break
-
         # Flush the sink
         self._sink.flush()
 
@@ -439,6 +572,18 @@ def _error_location(e: Exception) -> Dict | None:
     }
 
 
+def _is_map_field(field_descriptor) -> bool:
+    """
+    Check if a protobuf field is a map.
+    """
+    return (
+        field_descriptor.label == field_descriptor.LABEL_REPEATED
+        and field_descriptor.message_type is not None
+        and field_descriptor.message_type.has_options
+        and field_descriptor.message_type.GetOptions().map_entry
+    )
+
+
 def _protobuf_to_json_with_redaction(
     message: google.protobuf.message.Message, redacted_suffixes: list[str]
 ) -> dict:
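Protobuf encodes map<K, V> fields as repeated, auto-generated entry messages with the map_entry option set; _is_map_field detects that so redaction can emit a plain dict instead of recursing into the synthetic entry type. The same check, run against a well-known message type that contains a map field:

from google.protobuf.struct_pb2 import Struct  # google.protobuf.Struct: map<string, Value> fields

fd = Struct.DESCRIPTOR.fields_by_name["fields"]
is_map = (
    fd.label == fd.LABEL_REPEATED
    and fd.message_type is not None
    and fd.message_type.has_options
    and fd.message_type.GetOptions().map_entry
)
assert is_map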
@@ -463,7 +608,9 @@ def _protobuf_to_json_with_redaction(
            return "<redacted>"
 
        # Handle different field types
-       if field_descriptor.type == field_descriptor.TYPE_MESSAGE:
+       if _is_map_field(field_descriptor):
+           return dict(value)
+       elif field_descriptor.type == field_descriptor.TYPE_MESSAGE:
            if field_descriptor.label == field_descriptor.LABEL_REPEATED:
                # Repeated message field
                return [_protobuf_to_json_recursive(item, field_path) for item in value]
@@ -481,6 +628,11 @@ def _protobuf_to_json_with_redaction(
        msg: google.protobuf.message.Message, current_path: str = ""
    ) -> dict:
        """Recursively convert protobuf message to dict"""
+
+       if not isinstance(msg, google.protobuf.message.Message):
+           logger.warning("Expected a protobuf message, got: %s", type(msg))
+           return {}
+
        result = {}
 
        # Use ListFields() to get all set fields
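The added isinstance guard keeps the recursive converter from blowing up when a non-message value (for example a map entry already turned into a dict) reaches it. The overall suffix-based redaction walk it feeds can be pictured on plain dicts like this (standalone sketch, not the package's function):

def redact(obj, redacted_suffixes, path=""):
    """Replace any value whose dotted path ends with a redacted suffix."""
    if isinstance(obj, dict):
        redacted = {}
        for key, value in obj.items():
            field_path = f"{path}.{key}" if path else key
            if any(field_path.endswith(suffix) for suffix in redacted_suffixes):
                redacted[key] = "<redacted>"
            else:
                redacted[key] = redact(value, redacted_suffixes, field_path)
        return redacted
    if isinstance(obj, list):
        return [redact(item, redacted_suffixes, path) for item in obj]
    return obj


print(redact({"sql": {"query": "select 1"}, "plan_id": 7}, ["sql.query"]))
# {'sql': {'query': '<redacted>'}, 'plan_id': 7}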