snowpark-connect 0.21.0__py3-none-any.whl → 0.22.1__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release. This version of snowpark-connect might be problematic.
- snowflake/snowpark_connect/config.py +19 -3
- snowflake/snowpark_connect/error/error_utils.py +25 -0
- snowflake/snowpark_connect/expression/map_udf.py +4 -4
- snowflake/snowpark_connect/expression/map_unresolved_function.py +203 -128
- snowflake/snowpark_connect/proto/snowflake_expression_ext_pb2_grpc.py +4 -0
- snowflake/snowpark_connect/proto/snowflake_relation_ext_pb2_grpc.py +4 -0
- snowflake/snowpark_connect/relation/map_aggregate.py +102 -18
- snowflake/snowpark_connect/relation/map_column_ops.py +21 -2
- snowflake/snowpark_connect/relation/map_map_partitions.py +3 -1
- snowflake/snowpark_connect/relation/map_sql.py +18 -191
- snowflake/snowpark_connect/relation/map_udtf.py +4 -4
- snowflake/snowpark_connect/relation/read/map_read_json.py +12 -1
- snowflake/snowpark_connect/relation/read/reader_config.py +1 -0
- snowflake/snowpark_connect/relation/write/map_write.py +68 -24
- snowflake/snowpark_connect/server.py +9 -0
- snowflake/snowpark_connect/type_mapping.py +4 -0
- snowflake/snowpark_connect/utils/describe_query_cache.py +2 -9
- snowflake/snowpark_connect/utils/session.py +0 -4
- snowflake/snowpark_connect/utils/telemetry.py +213 -61
- snowflake/snowpark_connect/utils/udxf_import_utils.py +14 -0
- snowflake/snowpark_connect/version.py +1 -1
- snowflake/snowpark_decoder/__init__.py +0 -0
- snowflake/snowpark_decoder/_internal/proto/generated/DataframeProcessorMsg_pb2.py +36 -0
- snowflake/snowpark_decoder/_internal/proto/generated/DataframeProcessorMsg_pb2.pyi +156 -0
- snowflake/snowpark_decoder/dp_session.py +111 -0
- snowflake/snowpark_decoder/spark_decoder.py +76 -0
- {snowpark_connect-0.21.0.dist-info → snowpark_connect-0.22.1.dist-info}/METADATA +2 -2
- {snowpark_connect-0.21.0.dist-info → snowpark_connect-0.22.1.dist-info}/RECORD +40 -29
- {snowpark_connect-0.21.0.dist-info → snowpark_connect-0.22.1.dist-info}/top_level.txt +1 -0
- spark/__init__.py +0 -0
- spark/connect/__init__.py +0 -0
- spark/connect/envelope_pb2.py +31 -0
- spark/connect/envelope_pb2.pyi +46 -0
- snowflake/snowpark_connect/includes/jars/jackson-mapper-asl-1.9.13.jar +0 -0
- {snowpark_connect-0.21.0.data → snowpark_connect-0.22.1.data}/scripts/snowpark-connect +0 -0
- {snowpark_connect-0.21.0.data → snowpark_connect-0.22.1.data}/scripts/snowpark-session +0 -0
- {snowpark_connect-0.21.0.data → snowpark_connect-0.22.1.data}/scripts/snowpark-submit +0 -0
- {snowpark_connect-0.21.0.dist-info → snowpark_connect-0.22.1.dist-info}/WHEEL +0 -0
- {snowpark_connect-0.21.0.dist-info → snowpark_connect-0.22.1.dist-info}/licenses/LICENSE-binary +0 -0
- {snowpark_connect-0.21.0.dist-info → snowpark_connect-0.22.1.dist-info}/licenses/LICENSE.txt +0 -0
- {snowpark_connect-0.21.0.dist-info → snowpark_connect-0.22.1.dist-info}/licenses/NOTICE-binary +0 -0
@@ -214,27 +214,71 @@ def map_write(request: proto_base.ExecutePlanRequest):
             )
             snowpark_table_name = _spark_to_snowflake(table_name)
 
-            [21 removed lines; content not shown in this diff view]
+            match write_mode:
+                case None | "error" | "errorifexists":
+                    if check_snowflake_table_existence(snowpark_table_name, session):
+                        raise AnalysisException(
+                            f"Table {snowpark_table_name} already exists"
+                        )
+                    create_iceberg_table(
+                        snowpark_table_name=snowpark_table_name,
+                        location=write_op.options.get("location", None),
+                        schema=input_df.schema,
+                        snowpark_session=session,
+                    )
+                    _validate_schema_and_get_writer(
+                        input_df, "append", snowpark_table_name
+                    ).saveAsTable(
+                        table_name=snowpark_table_name,
+                        mode="append",
+                        column_order=_column_order_for_write,
+                    )
+                case "append":
+                    _validate_schema_and_get_writer(
+                        input_df, "append", snowpark_table_name
+                    ).saveAsTable(
+                        table_name=snowpark_table_name,
+                        mode="append",
+                        column_order=_column_order_for_write,
+                    )
+                case "ignore":
+                    if not check_snowflake_table_existence(
+                        snowpark_table_name, session
+                    ):
+                        create_iceberg_table(
+                            snowpark_table_name=snowpark_table_name,
+                            location=write_op.options.get("location", None),
+                            schema=input_df.schema,
+                            snowpark_session=session,
+                        )
+                        _validate_schema_and_get_writer(
+                            input_df, "append", snowpark_table_name
+                        ).saveAsTable(
+                            table_name=snowpark_table_name,
+                            mode="append",
+                            column_order=_column_order_for_write,
+                        )
+                case "overwrite":
+                    if check_snowflake_table_existence(snowpark_table_name, session):
+                        session.sql(f"DELETE FROM {snowpark_table_name}").collect()
+                    else:
+                        create_iceberg_table(
+                            snowpark_table_name=snowpark_table_name,
+                            location=write_op.options.get("location", None),
+                            schema=input_df.schema,
+                            snowpark_session=session,
+                        )
+                    _validate_schema_and_get_writer(
+                        input_df, "append", snowpark_table_name
+                    ).saveAsTable(
+                        table_name=snowpark_table_name,
+                        mode="append",
+                        column_order=_column_order_for_write,
+                    )
+                case _:
+                    raise SnowparkConnectNotImplementedError(
+                        f"Write mode {write_mode} is not supported"
+                    )
         case _:
             snowpark_table_name = _spark_to_snowflake(write_op.table.table_name)
 
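For orientation, the branches above serve the standard DataFrameWriter save modes issued from the Spark Connect client. A minimal sketch follows; the connection URL, table name, and DataFrame are illustrative assumptions, not values taken from this package.

from pyspark.sql import SparkSession

# Assumed local Spark Connect endpoint; Snowpark Connect speaks this protocol.
spark = SparkSession.builder.remote("sc://localhost:15002").getOrCreate()
df = spark.createDataFrame([(1, "a"), (2, "b")], ["id", "val"])

# Each save mode lands in one branch of the new match statement above.
df.write.mode("errorifexists").saveAsTable("demo_table")  # error if the table already exists
df.write.mode("append").saveAsTable("demo_table")         # append to the existing table
df.write.mode("ignore").saveAsTable("demo_table")         # no-op when the table exists
df.write.mode("overwrite").saveAsTable("demo_table")      # delete existing rows, then append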
@@ -299,14 +343,14 @@ def map_write_v2(request: proto_base.ExecutePlanRequest):
             commands_proto.WriteOperationV2.MODE_OVERWRITE,
             commands_proto.WriteOperationV2.MODE_APPEND,
         ):
-            if not
+            if not check_snowflake_table_existence(snowpark_table_name, session):
                 raise AnalysisException(
                     f"[TABLE_OR_VIEW_NOT_FOUND] The table or view `{write_op.table_name}` cannot be found. "
                     f"Verify the spelling and correctness of the schema and catalog.\n"
                 )
 
         if write_op.provider.lower() == "iceberg":
-            if write_mode == "overwrite" and
+            if write_mode == "overwrite" and check_snowflake_table_existence(
                 snowpark_table_name, session
             ):
                 session.sql(f"DELETE FROM {snowpark_table_name}").collect()

@@ -584,7 +628,7 @@ def _truncate_directory(directory_path: Path) -> None:
             shutil.rmtree(file)
 
 
-def
+def check_snowflake_table_existence(
     snowpark_table_name: str,
     snowpark_session: snowpark.Session,
 ):
@@ -981,6 +981,7 @@ def start_session(
     stop_event: threading.Event = None,
     snowpark_session: Optional[snowpark.Session] = None,
     connection_parameters: Optional[Dict[str, str]] = None,
+    max_grpc_message_size: int = _SPARK_CONNECT_GRPC_MAX_MESSAGE_SIZE,
 ) -> threading.Thread | None:
     """
     Starts Spark Connect server connected to Snowflake. No-op if the Server is already running.

@@ -1003,6 +1004,14 @@ def start_session(
         provided, the `snowpark_session` parameter must be None.
     """
     try:
+        # Changing the value of our global variable based on the grpc message size provided by the user.
+        global _SPARK_CONNECT_GRPC_MAX_MESSAGE_SIZE
+        _SPARK_CONNECT_GRPC_MAX_MESSAGE_SIZE = max_grpc_message_size
+
+        from pyspark.sql.connect.client import ChannelBuilder
+
+        ChannelBuilder.MAX_MESSAGE_LENGTH = max_grpc_message_size
+
         if os.environ.get("SPARK_ENV_LOADED"):
             raise RuntimeError(
                 "Snowpark Connect cannot be run inside of a Spark environment"
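A hedged usage sketch of the new parameter; the import path is inferred from the server.py entry in the file list above, and the 256 MB figure is only an example value.

from snowflake.snowpark_connect.server import start_session  # assumed module path

# Raise the gRPC message cap for both the server and the Spark Connect client channel.
start_session(max_grpc_message_size=256 * 1024 * 1024)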
@@ -324,6 +324,8 @@ def cast_to_match_snowpark_type(
             return str(content)
         case snowpark.types.VariantType:
             return str(content)
+        case snowpark.types.TimestampType:
+            return str(content)
         case _:
             raise SnowparkConnectNotImplementedError(
                 f"Unsupported snowpark data type in casting: {data_type}"

@@ -779,6 +781,8 @@ def map_simple_types(simple_type: str) -> snowpark.types.DataType:
             return snowpark.types.TimestampType()
         case "timestamp_ntz":
             return snowpark.types.TimestampType(snowpark.types.TimestampTimeZone.NTZ)
+        case "timestamp_ltz":
+            return snowpark.types.TimestampType(snowpark.types.TimestampTimeZone.LTZ)
         case "day_time_interval":
             # this is not a column type in snowflake so there won't be a dataframe column
             # with this, for now this type won't make any sense
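A standalone sketch of the timestamp portion of this mapping, using the public Snowpark type constructors that appear in the hunk; the helper name is hypothetical.

from snowflake.snowpark import types as T

def to_snowpark_timestamp(simple_type: str) -> T.TimestampType:
    # Hypothetical helper mirroring the cases above; only timestamp variants are handled.
    match simple_type:
        case "timestamp":
            return T.TimestampType()
        case "timestamp_ntz":
            return T.TimestampType(T.TimestampTimeZone.NTZ)
        case "timestamp_ltz":
            return T.TimestampType(T.TimestampTimeZone.LTZ)
        case _:
            raise ValueError(f"unsupported type name: {simple_type}")

print(to_snowpark_timestamp("timestamp_ltz"))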
@@ -131,21 +131,14 @@ def instrument_session_for_describe_cache(session: snowpark.Session):
             logger.debug(f"DDL detected, clearing describe query cache: '{query}'")
             cache.clear()
 
-    def report_query(qid: str, is_internal: bool) -> None:
-        if is_internal:
-            telemetry.report_internal_query()
-        elif qid:
-            telemetry.report_query_id(qid)
-
     def wrap_execute(wrapped_fn):
         def fn(query: str, **kwargs):
             update_cache_for_query(query)
-            is_internal = kwargs.get("_is_internal", False)
             try:
                 result = wrapped_fn(query, **kwargs)
-                report_query(result
+                telemetry.report_query(result, **kwargs)
             except Exception as e:
-                report_query(e
+                telemetry.report_query(e, **kwargs)
                 raise e
             return result
 

@@ -181,7 +181,3 @@ def set_query_tags(spark_tags: Sequence[str]) -> None:
 
     if spark_tags_str != snowpark_session.query_tag:
         snowpark_session.query_tag = spark_tags_str
-
-
-def get_python_udxf_import_files(session: snowpark.Session) -> str:
-    return ",".join([file for file in [*session._python_files, *session._import_files]])
@@ -1,19 +1,22 @@
 #
 # Copyright (c) 2012-2025 Snowflake Computing Inc. All rights reserved.
 #
-
+import functools
 import json
 import os
 import queue
 import threading
+import uuid
 from abc import ABC, abstractmethod
 from collections import defaultdict
+from collections.abc import Iterable
 from contextvars import ContextVar
 from enum import Enum, unique
 from typing import Dict
 
 import google.protobuf.message
 
+from snowflake.connector.cursor import SnowflakeCursor
 from snowflake.connector.telemetry import (
     TelemetryClient as PCTelemetryClient,
     TelemetryData as PCTelemetryData,

@@ -21,7 +24,6 @@ from snowflake.connector.telemetry import (
 )
 from snowflake.connector.time_util import get_time_millis
 from snowflake.snowpark import Session
-from snowflake.snowpark._internal.telemetry import safe_telemetry
 from snowflake.snowpark._internal.utils import get_os_name, get_python_version
 from snowflake.snowpark.version import VERSION as snowpark_version
 from snowflake.snowpark_connect.utils.snowpark_connect_logging import logger

@@ -43,6 +45,7 @@ class TelemetryField(Enum):
     KEY_OS = "operating_system"
     KEY_DATA = "data"
     KEY_START_TIME = "start_time"
+    KEY_EVENT_ID = "event_id"
 
 
 class TelemetryType(Enum):
@@ -107,7 +110,34 @@ REDACTED_PLAN_SUFFIXES = [
 ]
 
 
+def _basic_telemetry_data() -> Dict:
+    return {
+        **STATIC_TELEMETRY_DATA,
+        TelemetryField.KEY_EVENT_ID.value: str(uuid.uuid4()),
+    }
+
+
+def safe(func):
+    """
+    Decorator to safely execute telemetry functions, catching and logging exceptions
+    without affecting the main application flow.
+    """
+
+    @functools.wraps(func)
+    def wrap(*args, **kwargs):
+        try:
+            func(*args, **kwargs)
+        except Exception:
+            # We don't really care if telemetry fails, just want to be safe for the user
+            logger.warning(f"Telemetry operation failed: {func}", exc_info=True)
+
+    return wrap
+
+
 class TelemetrySink(ABC):
+    MAX_BUFFER_ELEMENTS = 20
+    MAX_WAIT_MS = 10000  # 10 seconds
+
     @abstractmethod
     def add_telemetry_data(self, message: dict, timestamp: int) -> None:
         pass
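To illustrate the behavior of the new decorator, here is the same pattern restated as a small, self-contained sketch: a failure inside a decorated function is logged and swallowed rather than propagated to the caller. The function names below are hypothetical.

import functools
import logging

logger = logging.getLogger("telemetry_demo")

def safe(func):
    # Same shape as the decorator added above: never let telemetry errors escape.
    @functools.wraps(func)
    def wrap(*args, **kwargs):
        try:
            func(*args, **kwargs)
        except Exception:
            logger.warning(f"Telemetry operation failed: {func}", exc_info=True)
    return wrap

@safe
def record_event(name: str) -> None:
    raise RuntimeError("telemetry backend unreachable")

record_event("server_started")  # logs a warning and returns None; nothing is raised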
@@ -128,23 +158,44 @@ class NoOpTelemetrySink(TelemetrySink):
 class ClientTelemetrySink(TelemetrySink):
     def __init__(self, telemetry_client: PCTelemetryClient) -> None:
         self._telemetry_client = telemetry_client
+        self._lock = threading.Lock()
+        self._reset()
 
     def add_telemetry_data(self, message: dict, timestamp: int) -> None:
         telemetry_data = PCTelemetryData(message=message, timestamp=timestamp)
         self._telemetry_client.try_add_log_to_batch(telemetry_data)
+        with self._lock:
+            self._events_since_last_flush += 1
+        # flush more often than the underlying telemetry client
+        if self._should_flush():
+            self.flush()
 
     def flush(self) -> None:
+        with self._lock:
+            self._reset()
         self._telemetry_client.send_batch()
 
+    def _should_flush(self) -> bool:
+        current_time = get_time_millis()
+
+        return (
+            self._events_since_last_flush >= TelemetrySink.MAX_BUFFER_ELEMENTS
+            or (current_time - self._last_flush_time) >= TelemetrySink.MAX_WAIT_MS
+        )
+
+    def _reset(self):
+        self._events_since_last_flush = 0
+        self._last_flush_time = get_time_millis()
+
 
 class QueryTelemetrySink(TelemetrySink):
 
-    MAX_BUFFER_SIZE =
-    MAX_WAIT_MS = 10000  # 10 seconds
+    MAX_BUFFER_SIZE = 20 * 1024  # 20KB
     TELEMETRY_JOB_ID = "43e72d9b-56d0-4cdb-a615-6b5b5059d6df"
 
     def __init__(self, session: Session) -> None:
         self._session = session
+        self._lock = threading.Lock()
         self._reset()
 
     def add_telemetry_data(self, message: dict, timestamp: int) -> None:

@@ -152,31 +203,37 @@ class QueryTelemetrySink(TelemetrySink):
 
         # stringify entry, and escape single quotes
         entry_str = json.dumps(telemetry_entry).replace("'", "''")
-        self._buffer.append(entry_str)
-        self._buffer_size += len(entry_str)
 
-
-
-        self._buffer_size
-
-        ):
+        with self._lock:
+            self._buffer.append(entry_str)
+            self._buffer_size += len(entry_str)
+
+        if self._should_flush():
             self.flush()
 
     def flush(self) -> None:
-
-
+        with self._lock:
+            if not self._buffer:
+                return
+            # prefix query with a unique identifier for easier tracking
+            query = f"select '{self.TELEMETRY_JOB_ID}' as scos_telemetry_export, '[{','.join(self._buffer)}]'"
+            self._reset()
 
-        # prefix query with a unique identifier for easier tracking
-        query = f"select '{self.TELEMETRY_JOB_ID}' as scos_telemetry_export, '[{','.join(self._buffer)}]'"
         self._session.sql(query).collect_nowait()
 
-        self._reset()
-
     def _reset(self) -> None:
         self._buffer = []
         self._buffer_size = 0
         self._last_export_time = get_time_millis()
 
+    def _should_flush(self):
+        current_time = get_time_millis()
+        return (
+            self._buffer_size >= QueryTelemetrySink.MAX_BUFFER_SIZE
+            or len(self._buffer) >= TelemetrySink.MAX_BUFFER_ELEMENTS
+            or (current_time - self._last_export_time) >= TelemetrySink.MAX_WAIT_MS
+        )
+
 
 class Telemetry:
     def __init__(self, is_enabled=True) -> None:
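Both sinks share the same flush policy shape: export when the buffer is large enough or old enough. A simplified, self-contained sketch of that policy follows; the constants are copied from the hunk, while the function name and its signature are hypothetical.

import time

MAX_BUFFER_SIZE = 20 * 1024      # bytes of buffered telemetry JSON
MAX_BUFFER_ELEMENTS = 20         # buffered entries
MAX_WAIT_MS = 10_000             # milliseconds since the last export

def should_flush(buffer: list[str], buffer_size: int, last_export_ms: int) -> bool:
    now_ms = int(time.time() * 1000)
    return (
        buffer_size >= MAX_BUFFER_SIZE
        or len(buffer) >= MAX_BUFFER_ELEMENTS
        or (now_ms - last_export_ms) >= MAX_WAIT_MS
    )

# Example: 21 small entries trip the element threshold even though the byte size is tiny.
print(should_flush(["{}"] * 21, buffer_size=42, last_export_ms=int(time.time() * 1000)))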
@@ -185,6 +242,8 @@ class Telemetry:
             "request_summary", default={}
         )
         self._is_enabled = is_enabled
+        self._is_initialized = False
+        self._lock = threading.Lock()
 
         # Async processing setup
         self._message_queue = queue.Queue(maxsize=10000)

@@ -202,6 +261,12 @@ class Telemetry:
         if not self._is_enabled:
             return
 
+        with self._lock:
+            if self._is_initialized:
+                logger.warning("Telemetry is already initialized")
+                return
+            self._is_initialized = True
+
         telemetry = getattr(session._conn._conn, "_telemetry", None)
         if telemetry is None:
             # no telemetry client available, so we export with queries

@@ -210,8 +275,9 @@ class Telemetry:
             self._sink = ClientTelemetrySink(telemetry)
 
         self._start_worker_thread()
+        logger.info(f"Telemetry initialized with {type(self._sink)}")
 
-    @
+    @safe
     def initialize_request_summary(
         self, request: google.protobuf.message.Message
     ) -> None:
@@ -234,8 +300,29 @@ class Telemetry:
             request.plan, REDACTED_PLAN_SUFFIXES
         )
 
-
+    def _not_in_request(self):
+        # we don't want to add things to the summary if it's not initialized
+        return "created_on" not in self._request_summary.get()
+
+    @safe
+    def report_parsed_sql_plan(self, plan: google.protobuf.message.Message) -> None:
+        if self._not_in_request():
+            return
+
+        summary = self._request_summary.get()
+
+        if "parsed_sql_plans" not in summary:
+            summary["parsed_sql_plans"] = []
+
+        summary["parsed_sql_plans"].append(
+            _protobuf_to_json_with_redaction(plan, REDACTED_PLAN_SUFFIXES)
+        )
+
+    @safe
     def report_function_usage(self, function_name: str) -> None:
+        if self._not_in_request():
+            return
+
         summary = self._request_summary.get()
 
         if "used_functions" not in summary:

@@ -243,8 +330,11 @@ class Telemetry:
 
         summary["used_functions"][function_name] += 1
 
-    @
+    @safe
     def report_request_failure(self, e: Exception) -> None:
+        if self._not_in_request():
+            return
+
         summary = self._request_summary.get()
 
         summary["was_successful"] = False
@@ -255,37 +345,78 @@ class Telemetry:
         if error_location:
             summary["error_location"] = error_location
 
-    @
-    def report_config_set(self,
+    @safe
+    def report_config_set(self, pairs: Iterable) -> None:
+        if self._not_in_request():
+            return
+
         summary = self._request_summary.get()
 
         if "config_set" not in summary:
             summary["config_set"] = []
 
-
-
-
-
-
-
+        for p in pairs:
+            summary["config_set"].append(
+                {
+                    "key": p.key,
+                    "value": p.value if p.key in RECORDED_CONFIG_KEYS else "<redacted>",
+                }
+            )
+
+    @safe
+    def report_config_unset(self, keys: Iterable[str]) -> None:
+        if self._not_in_request():
+            return
 
-    @safe_telemetry
-    def report_config_unset(self, key):
         summary = self._request_summary.get()
 
         if "config_unset" not in summary:
             summary["config_unset"] = []
 
-        summary["config_unset"].
+        summary["config_unset"].extend(keys)
+
+    @safe
+    def report_config_get(self, keys: Iterable[str]) -> None:
+        if self._not_in_request():
+            return
 
-    @safe_telemetry
-    def report_config_op_type(self, op_type: str):
         summary = self._request_summary.get()
 
+        if "config_get" not in summary:
+            summary["config_get"] = []
+
+        summary["config_get"].extend(keys)
+
+    @safe
+    def report_config_op_type(self, op_type: str):
+        if self._not_in_request():
+            return
+
+        summary = self._request_summary.get()
         summary["config_op_type"] = op_type
 
-    @
-    def
+    @safe
+    def report_query(
+        self, result: SnowflakeCursor | dict | Exception, **kwargs
+    ) -> None:
+        if result is None or isinstance(result, dict) or self._not_in_request():
+            return
+
+        # SnowflakeCursor and SQL errors will have sfqid
+        # other exceptions will not have it
+        # TODO: handle async queries, but filter out telemetry export queries
+        qid = getattr(result, "sfqid", None)
+
+        if qid is None:
+            logger.warning("Missing query id in result: %s", result)
+
+        is_internal = kwargs.get("_is_internal", False)
+        if is_internal:
+            self._report_internal_query()
+        elif qid:
+            self._report_query_id(qid)
+
+    def _report_query_id(self, query_id: str):
         summary = self._request_summary.get()
 
         if "queries" not in summary:
@@ -293,13 +424,19 @@ class Telemetry:
 
         summary["queries"].append(query_id)
 
-
-    def report_internal_query(self):
+    def _report_internal_query(self):
         summary = self._request_summary.get()
+
+        if "internal_queries" not in summary:
+            summary["internal_queries"] = 0
+
         summary["internal_queries"] += 1
 
-    @
+    @safe
     def report_udf_usage(self, udf_name: str):
+        if self._not_in_request():
+            return
+
         summary = self._request_summary.get()
 
         if "udf_usage" not in summary:

@@ -307,8 +444,10 @@ class Telemetry:
 
         summary["udf_usage"][udf_name] += 1
 
-
-
+    def _report_io(self, op: str, type: str, options: dict | None):
+        if self._not_in_request():
+            return
+
         summary = self._request_summary.get()
 
         if "io" not in summary:

@@ -321,16 +460,18 @@ class Telemetry:
 
         summary["io"].append(io)
 
+    @safe
     def report_io_read(self, type: str, options: dict | None):
-        self.
+        self._report_io("read", type, options)
 
+    @safe
     def report_io_write(self, type: str, options: dict | None):
-        self.
+        self._report_io("write", type, options)
 
-    @
+    @safe
     def send_server_started_telemetry(self):
         message = {
-            **
+            **_basic_telemetry_data(),
             TelemetryField.KEY_TYPE.value: TelemetryType.TYPE_EVENT.value,
             TelemetryType.EVENT_TYPE.value: EventType.SERVER_STARTED.value,
             TelemetryField.KEY_DATA.value: {
@@ -339,17 +480,22 @@ class Telemetry:
         }
         self._send(message)
 
-    @
+    @safe
     def send_request_summary_telemetry(self):
+        if self._not_in_request():
+            logger.warning(
+                "Truing to send request summary telemetry without initializing it"
+            )
+            return
+
         summary = self._request_summary.get()
         message = {
-            **
+            **_basic_telemetry_data(),
             TelemetryField.KEY_TYPE.value: TelemetryType.TYPE_REQUEST_SUMMARY.value,
             TelemetryField.KEY_DATA.value: summary,
         }
         self._send(message)
 
-    @safe_telemetry
     def _send(self, msg: Dict) -> None:
         """Queue a telemetry message for asynchronous processing."""
         if not self._is_enabled:

@@ -385,19 +531,6 @@ class Telemetry:
                 finally:
                     self._message_queue.task_done()
 
-        # Process any remaining messages
-        while not self._message_queue.empty():
-            try:
-                message, timestamp = self._message_queue.get_nowait()
-                self._sink.add_telemetry_data(message, timestamp)
-                self._message_queue.task_done()
-            except Exception:
-                logger.warning(
-                    "Failed to add remaining telemetry messages to sink during shutdown",
-                    exc_info=True,
-                )
-                break
-
         # Flush the sink
         self._sink.flush()
 
@@ -439,6 +572,18 @@ def _error_location(e: Exception) -> Dict | None:
     }
 
 
+def _is_map_field(field_descriptor) -> bool:
+    """
+    Check if a protobuf field is a map.
+    """
+    return (
+        field_descriptor.label == field_descriptor.LABEL_REPEATED
+        and field_descriptor.message_type is not None
+        and field_descriptor.message_type.has_options
+        and field_descriptor.message_type.GetOptions().map_entry
+    )
+
+
 def _protobuf_to_json_with_redaction(
     message: google.protobuf.message.Message, redacted_suffixes: list[str]
 ) -> dict:
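The new _is_map_field helper relies on how protobuf compiles map fields: a map is represented as a repeated message field whose synthetic entry type carries the map_entry option. A small demonstration using a well-known message type (google.protobuf.Struct, whose fields member is declared as map<string, Value>):

from google.protobuf import struct_pb2

fd = struct_pb2.Struct.DESCRIPTOR.fields_by_name["fields"]

# Same condition as the helper above, applied to a known map field.
is_map = (
    fd.label == fd.LABEL_REPEATED
    and fd.message_type is not None
    and fd.message_type.GetOptions().map_entry
)
print(is_map)  # True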
@@ -463,7 +608,9 @@ def _protobuf_to_json_with_redaction(
             return "<redacted>"
 
         # Handle different field types
-        if field_descriptor
+        if _is_map_field(field_descriptor):
+            return dict(value)
+        elif field_descriptor.type == field_descriptor.TYPE_MESSAGE:
             if field_descriptor.label == field_descriptor.LABEL_REPEATED:
                 # Repeated message field
                 return [_protobuf_to_json_recursive(item, field_path) for item in value]

@@ -481,6 +628,11 @@ def _protobuf_to_json_with_redaction(
         msg: google.protobuf.message.Message, current_path: str = ""
     ) -> dict:
         """Recursively convert protobuf message to dict"""
+
+        if not isinstance(msg, google.protobuf.message.Message):
+            logger.warning("Expected a protobuf message, got: %s", type(msg))
+            return {}
+
         result = {}
 
         # Use ListFields() to get all set fields