dataproc-spark-connect 0.8.3-py2.py3-none-any.whl → 1.0.0-py2.py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -14,6 +14,7 @@

  import atexit
  import datetime
+ import functools
  import json
  import logging
  import os
@@ -24,8 +25,9 @@ import threading
  import time
  import uuid
  import tqdm
+ from packaging import version
  from types import MethodType
- from typing import Any, cast, ClassVar, Dict, Optional, Union
+ from typing import Any, cast, ClassVar, Dict, Iterable, Optional, Union

  from google.api_core import retry
  from google.api_core.client_options import ClientOptions
@@ -43,12 +45,14 @@ from google.cloud.dataproc_spark_connect.pypi_artifacts import PyPiArtifacts
  from google.cloud.dataproc_v1 import (
  AuthenticationConfig,
  CreateSessionRequest,
+ DeleteSessionRequest,
  GetSessionRequest,
  Session,
  SessionControllerClient,
  TerminateSessionRequest,
  )
  from google.cloud.dataproc_v1.types import sessions
+ from google.cloud.dataproc_spark_connect import environment
  from pyspark.sql.connect.session import SparkSession
  from pyspark.sql.utils import to_str

@@ -56,6 +60,16 @@ from pyspark.sql.utils import to_str
  logging.basicConfig(level=logging.INFO)
  logger = logging.getLogger(__name__)

+ # System labels that should not be overridden by user
+ SYSTEM_LABELS = {
+ "dataproc-session-client",
+ "goog-colab-notebook-id",
+ }
+
+ _DATAPROC_SESSIONS_BASE_URL = (
+ "https://console.cloud.google.com/dataproc/interactive"
+ )
+

  def _is_valid_label_value(value: str) -> bool:
  """
@@ -77,6 +91,22 @@ def _is_valid_label_value(value: str) -> bool:
  return bool(re.match(pattern, value))


+ def _is_valid_session_id(session_id: str) -> bool:
+ """
+ Validates if a string complies with Google Cloud session ID format.
+ - Must be 4-63 characters
+ - Only lowercase letters, numbers, and dashes are allowed
+ - Must start with a lowercase letter
+ - Cannot end with a dash
+ """
+ if not session_id:
+ return False
+
+ # The pattern is sufficient for validation and already enforces length constraints.
+ pattern = r"^[a-z][a-z0-9-]{2,61}[a-z0-9]$"
+ return bool(re.match(pattern, session_id))
+
+
  class DataprocSparkSession(SparkSession):
  """The entry point to programming Spark with the Dataset and DataFrame API.

@@ -96,13 +126,16 @@ class DataprocSparkSession(SparkSession):
  ... ) # doctest: +SKIP
  """

- _DEFAULT_RUNTIME_VERSION = "2.3"
+ _DEFAULT_RUNTIME_VERSION = "3.0"
+ _MIN_RUNTIME_VERSION = "3.0"

  _active_s8s_session_uuid: ClassVar[Optional[str]] = None
  _project_id = None
  _region = None
  _client_options = None
  _active_s8s_session_id: ClassVar[Optional[str]] = None
+ _active_session_uses_custom_id: ClassVar[bool] = False
+ _execution_progress_bar = dict()

  class Builder(SparkSession.Builder):

@@ -110,6 +143,7 @@ class DataprocSparkSession(SparkSession):
  self._options: Dict[str, Any] = {}
  self._channel_builder: Optional[DataprocChannelBuilder] = None
  self._dataproc_config: Optional[Session] = None
+ self._custom_session_id: Optional[str] = None
  self._project_id = os.getenv("GOOGLE_CLOUD_PROJECT")
  self._region = os.getenv("GOOGLE_CLOUD_REGION")
  self._client_options = ClientOptions(
@@ -118,6 +152,18 @@ class DataprocSparkSession(SparkSession):
  f"{self._region}-dataproc.googleapis.com",
  )
  )
+ self._session_controller_client: Optional[
+ SessionControllerClient
+ ] = None
+
+ @property
+ def session_controller_client(self) -> SessionControllerClient:
+ """Get or create a SessionControllerClient instance."""
+ if self._session_controller_client is None:
+ self._session_controller_client = SessionControllerClient(
+ client_options=self._client_options
+ )
+ return self._session_controller_client

  def projectId(self, project_id):
  self._project_id = project_id
@@ -131,12 +177,106 @@ class DataprocSparkSession(SparkSession):
  )
  return self

+ def dataprocSessionId(self, session_id: str):
+ """
+ Set a custom session ID for creating or reusing sessions.
+
+ The session ID must:
+ - Be 4-63 characters long
+ - Start with a lowercase letter
+ - Contain only lowercase letters, numbers, and hyphens
+ - Not end with a hyphen
+
+ Args:
+ session_id: The custom session ID to use
+
+ Returns:
+ This Builder instance for method chaining
+
+ Raises:
+ ValueError: If the session ID format is invalid
+ """
+ if not _is_valid_session_id(session_id):
+ raise ValueError(
+ f"Invalid session ID: '{session_id}'. "
+ "Session ID must be 4-63 characters, start with a lowercase letter, "
+ "contain only lowercase letters, numbers, and hyphens, "
+ "and not end with a hyphen."
+ )
+ self._custom_session_id = session_id
+ return self
+
  def dataprocSessionConfig(self, dataproc_config: Session):
+ self._dataproc_config = dataproc_config
+ for k, v in dataproc_config.runtime_config.properties.items():
+ self._options[cast(str, k)] = to_str(v)
+ return self
+
+ @property
+ def dataproc_config(self):
  with self._lock:
- self._dataproc_config = dataproc_config
- for k, v in dataproc_config.runtime_config.properties.items():
- self._options[cast(str, k)] = to_str(v)
- return self
+ self._dataproc_config = self._dataproc_config or Session()
+ return self._dataproc_config
+
+ def runtimeVersion(self, version: str):
+ self.dataproc_config.runtime_config.version = version
+ return self
+
+ def serviceAccount(self, account: str):
+ self.dataproc_config.environment_config.execution_config.service_account = (
+ account
+ )
+ return self
+
+ def subnetwork(self, subnet: str):
+ self.dataproc_config.environment_config.execution_config.subnetwork_uri = (
+ subnet
+ )
+ return self
+
+ def ttl(self, duration: datetime.timedelta):
+ """Set the time-to-live (TTL) for the session using a timedelta object."""
+ return self.ttlSeconds(int(duration.total_seconds()))
+
+ def ttlSeconds(self, seconds: int):
+ """Set the time-to-live (TTL) for the session in seconds."""
+ self.dataproc_config.environment_config.execution_config.ttl = {
+ "seconds": seconds
+ }
+ return self
+
+ def idleTtl(self, duration: datetime.timedelta):
+ """Set the idle time-to-live (idle TTL) for the session using a timedelta object."""
+ return self.idleTtlSeconds(int(duration.total_seconds()))
+
+ def idleTtlSeconds(self, seconds: int):
+ """Set the idle time-to-live (idle TTL) for the session in seconds."""
+ self.dataproc_config.environment_config.execution_config.idle_ttl = {
+ "seconds": seconds
+ }
+ return self
+
+ def sessionTemplate(self, template: str):
+ self.dataproc_config.session_template = template
+ return self
+
+ def label(self, key: str, value: str):
+ """Add a single label to the session."""
+ return self.labels({key: value})
+
+ def labels(self, labels: Dict[str, str]):
+ # Filter out system labels and warn user
+ filtered_labels = {}
+ for key, value in labels.items():
+ if key in SYSTEM_LABELS:
+ logger.warning(
+ f"Label '{key}' is a system label and cannot be overridden by user. Ignoring."
+ )
+ else:
+ filtered_labels[key] = value
+
+ self.dataproc_config.labels.update(filtered_labels)
+ return self

  def remote(self, url: Optional[str] = None) -> "SparkSession.Builder":
  if url:
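
Note (not part of the diff): the hunk above adds a fluent configuration surface to DataprocSparkSession.Builder. A rough, illustrative sketch of how these new methods might chain together follows; the import path, project ID, service account, and label values are placeholders and assumptions rather than values taken from the package, and the region is assumed to come from the GOOGLE_CLOUD_REGION environment variable read in Builder.__init__:

    import datetime
    # Assumed public import path for the client class.
    from google.cloud.dataproc_spark_connect import DataprocSparkSession

    spark = (
        DataprocSparkSession.builder
        .projectId("my-project")                        # placeholder project ID
        .dataprocSessionId("team-shared-session")       # custom, reusable session ID (validated)
        .runtimeVersion("3.0")                          # minimum supported runtime in 1.0.0
        .serviceAccount("sa@my-project.iam.gserviceaccount.com")  # placeholder service account
        .ttl(datetime.timedelta(hours=4))               # total session time-to-live
        .idleTtl(datetime.timedelta(minutes=30))        # idle time-to-live
        .label("team", "data-eng")                      # user label; system labels are ignored
        .getOrCreate()
    )
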
@@ -175,7 +315,11 @@ class DataprocSparkSession(SparkSession):
  assert self._channel_builder is not None
  session = DataprocSparkSession(connection=self._channel_builder)

+ # Register handler for Cell Execution Progress bar
+ session._register_progress_execution_handler()
+
  DataprocSparkSession._set_default_and_active_session(session)
+
  return session

  def __create(self) -> "DataprocSparkSession":
@@ -190,7 +334,16 @@ class DataprocSparkSession(SparkSession):

  dataproc_config: Session = self._get_dataproc_config()

- session_id = self.generate_dataproc_session_id()
+ # Check runtime version compatibility before creating session
+ self._check_runtime_compatibility(dataproc_config)
+
+ # Use custom session ID if provided, otherwise generate one
+ session_id = (
+ self._custom_session_id
+ if self._custom_session_id
+ else self.generate_dataproc_session_id()
+ )
+
  dataproc_config.name = f"projects/{self._project_id}/locations/{self._region}/sessions/{session_id}"
  logger.debug(
  f"Dataproc Session configuration:\n{dataproc_config}"
@@ -205,6 +358,10 @@ class DataprocSparkSession(SparkSession):

  logger.debug("Creating Dataproc Session")
  DataprocSparkSession._active_s8s_session_id = session_id
+ # Track whether this session uses a custom ID (unmanaged) or auto-generated ID (managed)
+ DataprocSparkSession._active_session_uses_custom_id = (
+ self._custom_session_id is not None
+ )
  s8s_creation_start_time = time.time()

  stop_create_session_pbar_event = threading.Event()
@@ -258,8 +415,7 @@ class DataprocSparkSession(SparkSession):
  client_options=self._client_options
  ).create_session(session_request)
  self._display_session_link_on_creation(session_id)
- # TODO: Add the 'View Session Details' button once the UI changes are done.
- # self._display_view_session_details_button(session_id)
+ self._display_view_session_details_button(session_id)
  create_session_pbar_thread.start()
  session_response: Session = operation.result(
  polling=retry.Retry(
@@ -296,6 +452,7 @@ class DataprocSparkSession(SparkSession):
  if create_session_pbar_thread.is_alive():
  create_session_pbar_thread.join()
  DataprocSparkSession._active_s8s_session_id = None
+ DataprocSparkSession._active_session_uses_custom_id = False
  raise DataprocSparkConnectException(
  f"Error while creating Dataproc Session: {e.message}"
  )
@@ -304,6 +461,7 @@ class DataprocSparkSession(SparkSession):
  if create_session_pbar_thread.is_alive():
  create_session_pbar_thread.join()
  DataprocSparkSession._active_s8s_session_id = None
+ DataprocSparkSession._active_session_uses_custom_id = False
  raise RuntimeError(
  f"Error while creating Dataproc Session"
  ) from e
@@ -317,16 +475,43 @@ class DataprocSparkSession(SparkSession):
  session_response, dataproc_config.name
  )

+ def _wait_for_session_available(
+ self, session_name: str, timeout: int = 300
+ ) -> Session:
+ start_time = time.time()
+ while time.time() - start_time < timeout:
+ try:
+ session = self.session_controller_client.get_session(
+ name=session_name
+ )
+ if "Spark Connect Server" in session.runtime_info.endpoints:
+ return session
+ time.sleep(5)
+ except Exception as e:
+ logger.warning(
+ f"Error while polling for Spark Connect endpoint: {e}"
+ )
+ time.sleep(5)
+ raise RuntimeError(
+ f"Spark Connect endpoint not available for session {session_name} after {timeout} seconds."
+ )
+
  def _display_session_link_on_creation(self, session_id):
- session_url = f"https://console.cloud.google.com/dataproc/interactive/{self._region}/{session_id}?project={self._project_id}"
+ session_url = f"{_DATAPROC_SESSIONS_BASE_URL}/{self._region}/{session_id}?project={self._project_id}"
  plain_message = f"Creating Dataproc Session: {session_url}"
- html_element = f"""
+ if environment.is_colab_enterprise():
+ html_element = f"""
  <div>
  <p>Creating Dataproc Spark Session<p>
- <p><a href="{session_url}">Dataproc Session</a></p>
  </div>
- """
-
+ """
+ else:
+ html_element = f"""
+ <div>
+ <p>Creating Dataproc Spark Session<p>
+ <p><a href="{session_url}">Dataproc Session</a></p>
+ </div>
+ """
  self._output_element_or_message(plain_message, html_element)

  def _print_session_created_message(self):
@@ -345,16 +530,19 @@ class DataprocSparkSession(SparkSession):
  :param html_element: HTML element to display for interactive IPython
  environment
  """
+ # Don't print any output (Rich or Plain) for non-interactive
+ if not environment.is_interactive():
+ return
+
+ if environment.is_interactive_terminal():
+ print(plain_message)
+ return
+
  try:
  from IPython.display import display, HTML
- from IPython.core.interactiveshell import InteractiveShell

- if not InteractiveShell.initialized():
- raise DataprocSparkConnectException(
- "Not in an Interactive IPython Environment"
- )
  display(HTML(html_element))
- except (ImportError, DataprocSparkConnectException):
+ except ImportError:
  print(plain_message)

  def _get_exiting_active_session(
@@ -375,11 +563,13 @@ class DataprocSparkSession(SparkSession):

  if session_response is not None:
  print(
- f"Using existing Dataproc Session (configuration changes may not be applied): https://console.cloud.google.com/dataproc/interactive/{self._region}/{s8s_session_id}?project={self._project_id}"
+ f"Using existing Dataproc Session (configuration changes may not be applied): {_DATAPROC_SESSIONS_BASE_URL}/{self._region}/{s8s_session_id}?project={self._project_id}"
  )
- # TODO: Add the 'View Session Details' button once the UI changes are done.
- # self._display_view_session_details_button(s8s_session_id)
+ self._display_view_session_details_button(s8s_session_id)
  if session is None:
+ session_response = self._wait_for_session_available(
+ session_name
+ )
  session = self.__create_spark_connect_session_from_s8s(
  session_response, session_name
  )
@@ -395,17 +585,59 @@ class DataprocSparkSession(SparkSession):

  def getOrCreate(self) -> "DataprocSparkSession":
  with DataprocSparkSession._lock:
+ if environment.is_dataproc_batch():
+ # For Dataproc batch workloads, connect to the already initialized local SparkSession
+ from pyspark.sql import SparkSession as PySparkSQLSession
+
+ session = PySparkSQLSession.builder.getOrCreate()
+ return session # type: ignore
+
+ if self._project_id is None:
+ raise DataprocSparkConnectException(
+ f"Error while creating Dataproc Session: project ID is not set"
+ )
+
+ if self._region is None:
+ raise DataprocSparkConnectException(
+ f"Error while creating Dataproc Session: location is not set"
+ )
+
+ # Handle custom session ID by setting it early and letting existing logic handle it
+ if self._custom_session_id:
+ self._handle_custom_session_id()
+
  session = self._get_exiting_active_session()
  if session is None:
  session = self.__create()
+
+ # Register this session as the instantiated SparkSession for compatibility
+ # with tools and libraries that expect SparkSession._instantiatedSession
+ from pyspark.sql import SparkSession as PySparkSQLSession
+
+ PySparkSQLSession._instantiatedSession = session
+
  return session

+ def _handle_custom_session_id(self):
+ """Handle custom session ID by checking if it exists and setting _active_s8s_session_id."""
+ session_response = self._get_session_by_id(self._custom_session_id)
+ if session_response is not None:
+ # Found an active session with the custom ID, set it as the active session
+ DataprocSparkSession._active_s8s_session_id = (
+ self._custom_session_id
+ )
+ # Mark that this session uses a custom ID
+ DataprocSparkSession._active_session_uses_custom_id = True
+ else:
+ # No existing session found, clear any existing active session ID
+ # so we'll create a new one with the custom ID
+ DataprocSparkSession._active_s8s_session_id = None
+
  def _get_dataproc_config(self):
- dataproc_config = Session()
- if self._dataproc_config:
- dataproc_config = self._dataproc_config
- for k, v in self._options.items():
- dataproc_config.runtime_config.properties[k] = v
+ # Use the property to ensure we always have a config
+ dataproc_config = self.dataproc_config
+ for k, v in self._options.items():
+ dataproc_config.runtime_config.properties[k] = v
  dataproc_config.spark_connect_session = (
  sessions.SparkConnectConfig()
  )
@@ -413,20 +645,38 @@ class DataprocSparkSession(SparkSession):
  dataproc_config.runtime_config.version = (
  DataprocSparkSession._DEFAULT_RUNTIME_VERSION
  )
+
+ # Check for Python version mismatch with runtime for UDF compatibility
+ self._check_python_version_compatibility(
+ dataproc_config.runtime_config.version
+ )
+
+ # Use local variable to improve readability of deeply nested attribute access
+ exec_config = dataproc_config.environment_config.execution_config
+
+ # Set service account from environment if not already set
  if (
- not dataproc_config.environment_config.execution_config.authentication_config.user_workload_authentication_type
- and "DATAPROC_SPARK_CONNECT_AUTH_TYPE" in os.environ
- ):
- dataproc_config.environment_config.execution_config.authentication_config.user_workload_authentication_type = AuthenticationConfig.AuthenticationType[
- os.getenv("DATAPROC_SPARK_CONNECT_AUTH_TYPE")
- ]
- if (
- not dataproc_config.environment_config.execution_config.service_account
+ not exec_config.service_account
  and "DATAPROC_SPARK_CONNECT_SERVICE_ACCOUNT" in os.environ
  ):
- dataproc_config.environment_config.execution_config.service_account = os.getenv(
+ exec_config.service_account = os.getenv(
  "DATAPROC_SPARK_CONNECT_SERVICE_ACCOUNT"
  )
+
+ # Auto-set authentication type to SERVICE_ACCOUNT when service account is provided
+ if exec_config.service_account:
+ # When service account is provided, explicitly set auth type to SERVICE_ACCOUNT
+ exec_config.authentication_config.user_workload_authentication_type = (
+ AuthenticationConfig.AuthenticationType.SERVICE_ACCOUNT
+ )
+ elif (
+ not exec_config.authentication_config.user_workload_authentication_type
+ and "DATAPROC_SPARK_CONNECT_AUTH_TYPE" in os.environ
+ ):
+ # Only set auth type from environment if no service account is present
+ exec_config.authentication_config.user_workload_authentication_type = AuthenticationConfig.AuthenticationType[
+ os.getenv("DATAPROC_SPARK_CONNECT_AUTH_TYPE")
+ ]
  if (
  not dataproc_config.environment_config.execution_config.subnetwork_uri
  and "DATAPROC_SPARK_CONNECT_SUBNET" in os.environ
@@ -452,6 +702,10 @@ class DataprocSparkSession(SparkSession):
  os.getenv("DATAPROC_SPARK_CONNECT_IDLE_TTL_SECONDS")
  )
  }
+ client_environment = environment.get_client_environment_label()
+ dataproc_config.labels["dataproc-session-client"] = (
+ client_environment
+ )
  if "COLAB_NOTEBOOK_ID" in os.environ:
  colab_notebook_name = os.environ["COLAB_NOTEBOOK_ID"]
  # Extract the last part of the path, which is the ID
@@ -466,37 +720,102 @@ class DataprocSparkSession(SparkSession):
  f"Only lowercase letters, numbers, and dashes are allowed. "
  f"The value must start with lowercase letter or number and end with a lowercase letter or number. "
  f"Maximum length is 63 characters. "
- f"Skipping notebook ID label."
+ f"Ignoring notebook ID label."
  )
  default_datasource = os.getenv(
  "DATAPROC_SPARK_CONNECT_DEFAULT_DATASOURCE"
  )
- if (
- default_datasource
- and dataproc_config.runtime_config.version == "2.3"
- ):
- if default_datasource == "bigquery":
- bq_datasource_properties = {
- "spark.datasource.bigquery.viewsEnabled": "true",
- "spark.datasource.bigquery.writeMethod": "direct",
+ match default_datasource:
+ case "bigquery":
+ # Merge default configs with existing properties,
+ # user configs take precedence
+ for k, v in {
  "spark.sql.catalog.spark_catalog": "com.google.cloud.spark.bigquery.BigQuerySparkSessionCatalog",
- "spark.sql.legacy.createHiveTableByDefault": "false",
  "spark.sql.sources.default": "bigquery",
- }
- # Merge default configs with existing properties, user configs take precedence
- for k, v in bq_datasource_properties.items():
+ }.items():
  if k not in dataproc_config.runtime_config.properties:
  dataproc_config.runtime_config.properties[k] = v
- else:
- logger.warning(
- f"DATAPROC_SPARK_CONNECT_DEFAULT_DATASOURCE is set to an invalid value:"
- f" {default_datasource}. Supported value is 'bigquery'."
- )
+ case _:
+ if default_datasource:
+ logger.warning(
+ f"DATAPROC_SPARK_CONNECT_DEFAULT_DATASOURCE is set to an invalid value:"
+ f" {default_datasource}. Supported value is 'bigquery'."
+ )
+
  return dataproc_config

+ def _check_python_version_compatibility(self, runtime_version):
+ """Check if client Python version matches server Python version for UDF compatibility."""
+ import sys
+ import warnings
+
+ # Runtime version to server Python version mapping
+ RUNTIME_PYTHON_MAP = {
+ "3.0": (3, 12),
+ }
+
+ client_python = sys.version_info[:2] # (major, minor)
+
+ if runtime_version in RUNTIME_PYTHON_MAP:
+ server_python = RUNTIME_PYTHON_MAP[runtime_version]
+
+ if client_python != server_python:
+ warnings.warn(
+ f"Python version mismatch detected: Client is using Python {client_python[0]}.{client_python[1]}, "
+ f"but Dataproc runtime {runtime_version} uses Python {server_python[0]}.{server_python[1]}. "
+ f"This mismatch may cause issues with Python UDF (User Defined Function) compatibility. "
+ f"Consider using Python {server_python[0]}.{server_python[1]} for optimal UDF execution.",
+ stacklevel=3,
+ )
+
+ def _check_runtime_compatibility(self, dataproc_config):
+ """Check if runtime version 3.0 client is compatible with older runtime versions.
+
+ Runtime version 3.0 clients do not support older runtime versions (pre-3.0).
+ There is no backward or forward compatibility between different runtime versions.
+
+ Args:
+ dataproc_config: The Session configuration containing runtime version
+
+ Raises:
+ DataprocSparkConnectException: If server is using pre-3.0 runtime version
+ """
+ runtime_version = dataproc_config.runtime_config.version
+
+ if not runtime_version:
+ return
+
+ logger.debug(f"Detected server runtime version: {runtime_version}")
+
+ # Parse runtime version to check if it's below minimum supported version
+ try:
+ server_version = version.parse(runtime_version)
+ min_version = version.parse(
+ DataprocSparkSession._MIN_RUNTIME_VERSION
+ )
+
+ if server_version < min_version:
+ raise DataprocSparkConnectException(
+ f"Specified {runtime_version} Dataproc Runtime version is not supported, "
+ f"use {DataprocSparkSession._MIN_RUNTIME_VERSION} version or higher."
+ )
+ except version.InvalidVersion:
+ # If we can't parse the version, log a warning but continue
+ logger.warning(
+ f"Could not parse runtime version: {runtime_version}"
+ )
+
  def _display_view_session_details_button(self, session_id):
+ # Display button is only supported in colab enterprise
+ if not environment.is_colab_enterprise():
+ return
+
+ # Skip button display for colab enterprise IPython terminals
+ if environment.is_interactive_terminal():
+ return
+
  try:
- session_url = f"https://console.cloud.google.com/dataproc/interactive/sessions/{session_id}/locations/{self._region}?project={self._project_id}"
+ session_url = f"{_DATAPROC_SESSIONS_BASE_URL}/{self._region}/{session_id}?project={self._project_id}"
  from IPython.core.interactiveshell import InteractiveShell

  if not InteractiveShell.initialized():
@@ -510,6 +829,90 @@ class DataprocSparkSession(SparkSession):
  except ImportError as e:
  logger.debug(f"Import error: {e}")

+ def _get_session_by_id(self, session_id: str) -> Optional[Session]:
+ """
+ Get existing session by ID.
+
+ Returns:
+ Session if ACTIVE/CREATING, None if not found or not usable
+ """
+ session_name = f"projects/{self._project_id}/locations/{self._region}/sessions/{session_id}"
+
+ try:
+ get_request = GetSessionRequest(name=session_name)
+ session = self.session_controller_client.get_session(
+ get_request
+ )
+
+ logger.debug(
+ f"Found existing session {session_id} in state: {session.state}"
+ )
+
+ if session.state in [
+ Session.State.ACTIVE,
+ Session.State.CREATING,
+ ]:
+ # Reuse the active session
+ logger.info(f"Reusing existing session: {session_id}")
+ return session
+ else:
+ # Session exists but is not usable (terminated/failed/terminating)
+ logger.info(
+ f"Session {session_id} in {session.state.name} state, cannot reuse"
+ )
+ return None
+
+ except NotFound:
+ # Session doesn't exist, can create new one
+ logger.debug(
+ f"Session {session_id} not found, can create new one"
+ )
+ return None
+ except Exception as e:
+ logger.error(f"Error checking session {session_id}: {e}")
+ return None
+
+ def _delete_session(self, session_name: str):
+ """Delete a session to free up the session ID for reuse."""
+ try:
+ delete_request = DeleteSessionRequest(name=session_name)
+ self.session_controller_client.delete_session(delete_request)
+ logger.debug(f"Deleted session: {session_name}")
+ except NotFound:
+ logger.debug(f"Session already deleted: {session_name}")
+
+ def _wait_for_termination(self, session_name: str, timeout: int = 180):
+ """Wait for a session to finish terminating."""
+ start_time = time.time()
+
+ while time.time() - start_time < timeout:
+ try:
+ get_request = GetSessionRequest(name=session_name)
+ session = self.session_controller_client.get_session(
+ get_request
+ )
+
+ if session.state in [
+ Session.State.TERMINATED,
+ Session.State.FAILED,
+ ]:
+ return
+ elif session.state != Session.State.TERMINATING:
+ # Session is in unexpected state
+ logger.warning(
+ f"Session {session_name} in unexpected state while waiting for termination: {session.state}"
+ )
+ return
+
+ time.sleep(2)
+ except NotFound:
+ # Session was deleted
+ return
+
+ logger.warning(
+ f"Timeout waiting for session {session_name} to terminate"
+ )
+
  @staticmethod
  def generate_dataproc_session_id():
  timestamp = datetime.datetime.now().strftime("%Y%m%d-%H%M%S")
@@ -583,16 +986,111 @@ class DataprocSparkSession(SparkSession):
  execute_and_fetch_as_iterator_wrapped_method, self.client
  )

+ # Patching clearProgressHandlers method to not remove Dataproc Progress Handler
+ clearProgressHandlers_base_method = self.clearProgressHandlers
+
+ def clearProgressHandlers_wrapper_method(_, *args, **kwargs):
+ clearProgressHandlers_base_method(*args, **kwargs)
+
+ self._register_progress_execution_handler()
+
+ self.clearProgressHandlers = MethodType(
+ clearProgressHandlers_wrapper_method, self
+ )
+
+ @staticmethod
+ @functools.lru_cache(maxsize=1)
+ def get_tqdm_bar():
+ """
+ Return a tqdm implementation that works in the current environment.
+
+ - Uses CLI tqdm for interactive terminals.
+ - Uses the notebook tqdm if available, otherwise falls back to CLI tqdm.
+ """
+ from tqdm import tqdm as cli_tqdm
+
+ if environment.is_interactive_terminal():
+ return cli_tqdm
+
+ try:
+ import ipywidgets
+ from tqdm.notebook import tqdm as notebook_tqdm
+
+ return notebook_tqdm
+ except ImportError:
+ return cli_tqdm
+
+ def _register_progress_execution_handler(self):
+ from pyspark.sql.connect.shell.progress import StageInfo
+
+ def handler(
+ stages: Optional[Iterable[StageInfo]],
+ inflight_tasks: int,
+ operation_id: Optional[str],
+ done: bool,
+ ):
+ if operation_id is None:
+ return
+
+ # Don't build / render progress bar for non-interactive (despite
+ # Ipython or non-IPython)
+ if not environment.is_interactive():
+ return
+
+ total_tasks = 0
+ completed_tasks = 0
+
+ for stage in stages or []:
+ total_tasks += stage.num_tasks
+ completed_tasks += stage.num_completed_tasks
+
+ # Don't show progress bar till we receive some tasks
+ if total_tasks == 0:
+ return
+
+ # Get correct tqdm (notebook or CLI)
+ tqdm_pbar = self.get_tqdm_bar()
+
+ # Use a lock to ensure only one thread can access and modify
+ # the shared dictionaries at a time.
+ with self._lock:
+ if operation_id in self._execution_progress_bar:
+ pbar = self._execution_progress_bar[operation_id]
+ if pbar.total != total_tasks:
+ pbar.reset(
+ total=total_tasks
+ ) # This force resets the progress bar % too on next refresh
+ else:
+ pbar = tqdm_pbar(
+ total=total_tasks,
+ leave=True,
+ dynamic_ncols=True,
+ bar_format="{l_bar}{bar} {n_fmt}/{total_fmt} Tasks",
+ )
+ self._execution_progress_bar[operation_id] = pbar
+
+ # To handle skipped or failed tasks.
+ # StageInfo proto doesn't have skipped and failed tasks information to process.
+ if done and completed_tasks < total_tasks:
+ completed_tasks = total_tasks
+
+ pbar.n = completed_tasks
+ pbar.refresh()
+
+ if done:
+ pbar.close()
+ self._execution_progress_bar.pop(operation_id, None)
+
+ self.registerProgressHandler(handler)
+
  @staticmethod
  def _sql_lazy_transformation(req):
  # Select SQL command
- if req.plan and req.plan.command and req.plan.command.sql_command:
- return (
- "select"
- in req.plan.command.sql_command.sql.strip().lower().split()
- )
-
- return False
+ try:
+ query = req.plan.command.sql_command.input.sql.query
+ return "select" in query.strip().lower().split()
+ except AttributeError:
+ return False

  def _repr_html_(self) -> str:
  if not self._active_s8s_session_id:
@@ -600,7 +1098,7 @@ class DataprocSparkSession(SparkSession):
  <div>No Active Dataproc Session</div>
  """

- s8s_session = f"https://console.cloud.google.com/dataproc/interactive/{self._region}/{self._active_s8s_session_id}"
+ s8s_session = f"{_DATAPROC_SESSIONS_BASE_URL}/{self._region}/{self._active_s8s_session_id}"
  ui = f"{s8s_session}/sparkApplications/applications"
  return f"""
  <div>
@@ -612,6 +1110,11 @@ class DataprocSparkSession(SparkSession):
  """

  def _display_operation_link(self, operation_id: str):
+ # Don't print per-operation Spark UI link for non-interactive (despite
+ # Ipython or non-IPython)
+ if not environment.is_interactive():
+ return
+
  assert all(
  [
  operation_id is not None,
@@ -622,17 +1125,18 @@ class DataprocSparkSession(SparkSession):
  )

  url = (
- f"https://console.cloud.google.com/dataproc/interactive/{self._region}/"
+ f"{_DATAPROC_SESSIONS_BASE_URL}/{self._region}/"
  f"{self._active_s8s_session_id}/sparkApplications/application;"
  f"associatedSqlOperationId={operation_id}?project={self._project_id}"
  )

+ if environment.is_interactive_terminal():
+ print(f"Spark Query: {url}")
+ return
+
  try:
  from IPython.display import display, HTML
- from IPython.core.interactiveshell import InteractiveShell

- if not InteractiveShell.initialized():
- return
  html_element = f"""
  <div>
  <p><a href="{url}">Spark Query</a> (Operation: {operation_id})</p>
@@ -690,7 +1194,7 @@ class DataprocSparkSession(SparkSession):
  This is an API dedicated to Spark Connect client only. With regular Spark Session, it throws
  an exception.
  Regarding pypi: Popular packages are already pre-installed in s8s runtime.
- https://cloud.google.com/dataproc-serverless/docs/concepts/versions/spark-runtime-2.2#python_libraries
+ https://cloud.google.com/dataproc-serverless/docs/concepts/versions/spark-runtime-2.3#python_libraries
  If there are conflicts/package doesn't exist, it throws an exception.
  """
  if sum([pypi, file, pyfile, archive]) > 1:
@@ -713,19 +1217,83 @@ class DataprocSparkSession(SparkSession):
  def _get_active_session_file_path():
  return os.getenv("DATAPROC_SPARK_CONNECT_ACTIVE_SESSION_FILE_PATH")

- def stop(self) -> None:
+ def stop(self, terminate: Optional[bool] = None) -> None:
+ """
+ Stop the Spark session and optionally terminate the server-side session.
+
+ Parameters
+ ----------
+ terminate : bool, optional
+ Control server-side termination behavior.
+
+ - None (default): Auto-detect based on session type
+
+ - Managed sessions (auto-generated ID): terminate server
+ - Named sessions (custom ID): client-side cleanup only
+
+ - True: Always terminate the server-side session
+ - False: Never terminate the server-side session (client cleanup only)
+
+ Examples
+ --------
+ Auto-detect termination behavior (existing behavior):
+
+ >>> spark.stop()
+
+ Force terminate a named session:
+
+ >>> spark.stop(terminate=True)
+
+ Prevent termination of a managed session:
+
+ >>> spark.stop(terminate=False)
+ """
  with DataprocSparkSession._lock:
  if DataprocSparkSession._active_s8s_session_id is not None:
- terminate_s8s_session(
- DataprocSparkSession._project_id,
- DataprocSparkSession._region,
- DataprocSparkSession._active_s8s_session_id,
- self._client_options,
- )
+ # Determine if we should terminate the server-side session
+ if terminate is None:
+ # Auto-detect: managed sessions terminate, named sessions don't
+ should_terminate = (
+ not DataprocSparkSession._active_session_uses_custom_id
+ )
+ else:
+ should_terminate = terminate
+
+ if should_terminate:
+ # Terminate the server-side session
+ logger.debug(
+ f"Terminating session {DataprocSparkSession._active_s8s_session_id}"
+ )
+ terminate_s8s_session(
+ DataprocSparkSession._project_id,
+ DataprocSparkSession._region,
+ DataprocSparkSession._active_s8s_session_id,
+ self._client_options,
+ )
+ else:
+ # Client-side cleanup only
+ logger.debug(
+ f"Stopping session {DataprocSparkSession._active_s8s_session_id} without termination"
+ )

  self._remove_stopped_session_from_file()
+
+ # Clean up SparkSession._instantiatedSession if it points to this session
+ try:
+ from pyspark.sql import SparkSession as PySparkSQLSession
+
+ if PySparkSQLSession._instantiatedSession is self:
+ PySparkSQLSession._instantiatedSession = None
+ logger.debug(
+ "Cleared SparkSession._instantiatedSession reference"
+ )
+ except (ImportError, AttributeError):
+ # PySpark not available or _instantiatedSession doesn't exist
+ pass
+
  DataprocSparkSession._active_s8s_session_uuid = None
  DataprocSparkSession._active_s8s_session_id = None
+ DataprocSparkSession._active_session_uses_custom_id = False
  DataprocSparkSession._project_id = None
  DataprocSparkSession._region = None
  DataprocSparkSession._client_options = None
  DataprocSparkSession._client_options = None