dataproc-spark-connect 0.2.0__py2.py3-none-any.whl → 0.6.0__py2.py3-none-any.whl

This diff compares the contents of two publicly released versions of the package as published to a supported public registry. It is provided for informational purposes only.
@@ -1,18 +1,16 @@
  Metadata-Version: 2.1
  Name: dataproc-spark-connect
- Version: 0.2.0
+ Version: 0.6.0
  Summary: Dataproc client library for Spark Connect
  Home-page: https://github.com/GoogleCloudDataproc/dataproc-spark-connect-python
  Author: Google LLC
  License: Apache 2.0
  License-File: LICENSE
  Requires-Dist: google-api-core>=2.19.1
- Requires-Dist: google-cloud-dataproc>=5.15.1
- Requires-Dist: wheel
+ Requires-Dist: google-cloud-dataproc>=5.18.0
  Requires-Dist: websockets
- Requires-Dist: pyspark>=3.5
- Requires-Dist: pandas
- Requires-Dist: pyarrow
+ Requires-Dist: pyspark[connect]>=3.5
+ Requires-Dist: packaging>=20.0

  # Dataproc Spark Connect Client

@@ -22,16 +20,15 @@ Spark cluster using the Spark Connect protocol without requiring additional step

  ## Install

- .. code-block:: console
-
- pip install dataproc_spark_connect
+ ```console
+ pip install dataproc_spark_connect
+ ```

  ## Uninstall

- .. code-block:: console
-
- pip uninstall dataproc_spark_connect
-
+ ```console
+ pip uninstall dataproc_spark_connect
+ ```

  ## Setup
  This client requires permissions to manage [Dataproc sessions and session templates](https://cloud.google.com/dataproc-serverless/docs/concepts/iam).
@@ -46,36 +43,36 @@ If you are running the client outside of Google Cloud, you must set following en

  1. Install the latest version of Dataproc Python client and Dataproc Spark Connect modules:

- .. code-block:: console
-
- pip install google_cloud_dataproc --force-reinstall
- pip install dataproc_spark_connect --force-reinstall
+ ```console
+ pip install google_cloud_dataproc --force-reinstall
+ pip install dataproc_spark_connect --force-reinstall
+ ```

  2. Add the required import into your PySpark application or notebook:

- .. code-block:: python
-
- from google.cloud.dataproc_spark_connect import DataprocSparkSession
+ ```python
+ from google.cloud.dataproc_spark_connect import DataprocSparkSession
+ ```

  3. There are two ways to create a spark session,

  1. Start a Spark session using properties defined in `DATAPROC_SPARK_CONNECT_SESSION_DEFAULT_CONFIG`:

- .. code-block:: python
-
- spark = DataprocSparkSession.builder.getOrCreate()
+ ```python
+ spark = DataprocSparkSession.builder.getOrCreate()
+ ```

  2. Start a Spark session with the following code instead of using a config file:

- .. code-block:: python
-
- from google.cloud.dataproc_v1 import SparkConnectConfig
- from google.cloud.dataproc_v1 import Session
- dataproc_config = Session()
- dataproc_config.spark_connect_session = SparkConnectConfig()
- dataproc_config.environment_config.execution_config.subnetwork_uri = "<subnet>"
- dataproc_config.runtime_config.version = '3.0'
- spark = DataprocSparkSession.builder.dataprocConfig(dataproc_config).getOrCreate()
+ ```python
+ from google.cloud.dataproc_v1 import SparkConnectConfig
+ from google.cloud.dataproc_v1 import Session
+ dataproc_session_config = Session()
+ dataproc_session_config.spark_connect_session = SparkConnectConfig()
+ dataproc_session_config.environment_config.execution_config.subnetwork_uri = "<subnet>"
+ dataproc_session_config.runtime_config.version = '3.0'
+ spark = DataprocSparkSession.builder.dataprocSessionConfig(dataproc_session_config).getOrCreate()
+ ```

  ## Billing
  As this client runs the spark workload on Dataproc, your project will be billed as per [Dataproc Serverless Pricing](https://cloud.google.com/dataproc-serverless/pricing).
@@ -86,29 +83,29 @@ This will happen even if you are running the client from a non-GCE instance.

  1. Install the requirements in virtual environment.

- .. code-block:: console
-
- pip install -r requirements.txt
+ ```console
+ pip install -r requirements-dev.txt
+ ```

  2. Build the code.

- .. code-block:: console
-
- python setup.py sdist bdist_wheel
-
+ ```console
+ python setup.py sdist bdist_wheel
+ ```

  3. Copy the generated `.whl` file to Cloud Storage. Use the version specified in the `setup.py` file.

- .. code-block:: console
-
- VERSION=<version> gsutil cp dist/dataproc_spark_connect-${VERSION}-py2.py3-none-any.whl gs://<your_bucket_name>
+ ```sh
+ VERSION=<version>
+ gsutil cp dist/dataproc_spark_connect-${VERSION}-py2.py3-none-any.whl gs://<your_bucket_name>
+ ```

  4. Download the new SDK on Vertex, then uninstall the old version and install the new one.

- .. code-block:: console
-
- %%bash
- export VERSION=<version>
- gsutil cp gs://<your_bucket_name>/dataproc_spark_connect-${VERSION}-py2.py3-none-any.whl .
- yes | pip uninstall dataproc_spark_connect
- pip install dataproc_spark_connect-${VERSION}-py2.py3-none-any.whl
+ ```sh
+ %%bash
+ export VERSION=<version>
+ gsutil cp gs://<your_bucket_name>/dataproc_spark_connect-${VERSION}-py2.py3-none-any.whl .
+ yes | pip uninstall dataproc_spark_connect
+ pip install dataproc_spark_connect-${VERSION}-py2.py3-none-any.whl
+ ```
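
Putting the README changes above together with the builder renames that appear later in this diff (`dataprocConfig` → `dataprocSessionConfig`, `region` → `location`), a minimal end-to-end sketch of creating a session against the 0.6.0 API might look as follows. This is an illustration, not text from either package: `<subnet>` and `<region>` are placeholders, and project selection and credentials are assumed to come from the environment.

```python
# Minimal sketch against the 0.6.0 API shown in this diff; placeholder values,
# with Google Cloud credentials and project assumed to come from the environment.
from google.cloud.dataproc_spark_connect import DataprocSparkSession
from google.cloud.dataproc_v1 import Session, SparkConnectConfig

dataproc_session_config = Session()
dataproc_session_config.spark_connect_session = SparkConnectConfig()
dataproc_session_config.environment_config.execution_config.subnetwork_uri = "<subnet>"
dataproc_session_config.runtime_config.version = "3.0"

spark = (
    DataprocSparkSession.builder
    .location("<region>")  # replaces the old region() builder method
    .dataprocSessionConfig(dataproc_session_config)
    .getOrCreate()
)
```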
@@ -0,0 +1,12 @@
+ google/cloud/dataproc_spark_connect/__init__.py,sha256=dIqHNWVWWrSuRf26x11kX5e9yMKSHCtmI_GBj1-FDdE,1101
+ google/cloud/dataproc_spark_connect/exceptions.py,sha256=ilGyHD5M_yBQ3IC58-Y5miRGIQVJsLaNKvEGcHuk_BE,969
+ google/cloud/dataproc_spark_connect/pypi_artifacts.py,sha256=gd-VMwiVP-EJuPp9Vf9Shx8pqps3oSKp0hBcSSZQS-A,1575
+ google/cloud/dataproc_spark_connect/session.py,sha256=gKPtWDzlz5WA5lPGLMOhNdtKskMDjbLG8KcTmv0PrWA,26189
+ google/cloud/dataproc_spark_connect/client/__init__.py,sha256=6hCNSsgYlie6GuVpc5gjFsPnyeMTScTpXSPYqp1fplY,615
+ google/cloud/dataproc_spark_connect/client/core.py,sha256=m3oXTKBm3sBy6jhDu9GRecrxLb5CdEM53SgMlnJb6ag,4616
+ google/cloud/dataproc_spark_connect/client/proxy.py,sha256=GNy561Fo8A2ehqLrDMkVWOUYV62YCO2tuN77it3H098,8954
+ dataproc_spark_connect-0.6.0.dist-info/LICENSE,sha256=xx0jnfkXJvxRnG63LTGOxlggYnIysveWIZ6H3PNdCrQ,11357
+ dataproc_spark_connect-0.6.0.dist-info/METADATA,sha256=m8PZHKk353AcATjML-Fgw_6yrtHmgVQLxEDI_90h2_0,4020
+ dataproc_spark_connect-0.6.0.dist-info/WHEEL,sha256=OpXWERl2xLPRHTvd2ZXo_iluPEQd8uSbYkJ53NAER_Y,109
+ dataproc_spark_connect-0.6.0.dist-info/top_level.txt,sha256=_1QvSJIhFAGfxb79D6DhB7SUw2X6T4rwnz_LLrbcD3c,7
+ dataproc_spark_connect-0.6.0.dist-info/RECORD,,
@@ -11,4 +11,19 @@
  # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  # See the License for the specific language governing permissions and
  # limitations under the License.
+ import importlib.metadata
+ import warnings
+
  from .session import DataprocSparkSession
+
+ old_package_name = "google-spark-connect"
+ current_package_name = "dataproc-spark-connect"
+ try:
+ importlib.metadata.distribution(old_package_name)
+ warnings.warn(
+ f"Package '{old_package_name}' is already installed in your environment. "
+ f"This might cause conflicts with '{current_package_name}'. "
+ f"Consider uninstalling '{old_package_name}' and only install '{current_package_name}'."
+ )
+ except:
+ pass
@@ -11,12 +11,16 @@
  # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  # See the License for the specific language governing permissions and
  # limitations under the License.
+ import logging
+
  import google
  import grpc
  from pyspark.sql.connect.client import ChannelBuilder

  from . import proxy

+ logger = logging.getLogger(__name__)
+

  class DataprocChannelBuilder(ChannelBuilder):
  """
@@ -36,6 +40,10 @@ class DataprocChannelBuilder(ChannelBuilder):
  True
  """

+ def __init__(self, url, is_active_callback=None):
+ self._is_active_callback = is_active_callback
+ super().__init__(url)
+
  def toChannel(self) -> grpc.Channel:
  """
  Applies the parameters of the connection string and creates a new
@@ -51,7 +59,7 @@ class DataprocChannelBuilder(ChannelBuilder):
  return self._proxied_channel()

  def _proxied_channel(self) -> grpc.Channel:
- return ProxiedChannel(self.host)
+ return ProxiedChannel(self.host, self._is_active_callback)

  def _direct_channel(self) -> grpc.Channel:
  destination = f"{self.host}:{self.port}"
@@ -75,7 +83,8 @@ class DataprocChannelBuilder(ChannelBuilder):

  class ProxiedChannel(grpc.Channel):

- def __init__(self, target_host):
+ def __init__(self, target_host, is_active_callback):
+ self._is_active_callback = is_active_callback
  self._proxy = proxy.DataprocSessionProxy(0, target_host)
  self._proxy.start()
  self._proxied_connect_url = f"sc://localhost:{self._proxy.port}"
@@ -94,20 +103,37 @@ class ProxiedChannel(grpc.Channel):
  self._proxy.stop()
  return ret

+ def _wrap_method(self, wrapped_method):
+ if self._is_active_callback is None:
+ return wrapped_method
+
+ def checked_method(*margs, **mkwargs):
+ if (
+ self._is_active_callback is not None
+ and not self._is_active_callback()
+ ):
+ logger.warning(f"Session is no longer active")
+ raise RuntimeError(
+ "Session not active. Please create a new session"
+ )
+ return wrapped_method(*margs, **mkwargs)
+
+ return checked_method
+
  def stream_stream(self, *args, **kwargs):
- return self._wrapped.stream_stream(*args, **kwargs)
+ return self._wrap_method(self._wrapped.stream_stream(*args, **kwargs))

  def stream_unary(self, *args, **kwargs):
- return self._wrapped.stream_unary(*args, **kwargs)
+ return self._wrap_method(self._wrapped.stream_unary(*args, **kwargs))

  def subscribe(self, *args, **kwargs):
- return self._wrapped.subscribe(*args, **kwargs)
+ return self._wrap_method(self._wrapped.subscribe(*args, **kwargs))

  def unary_stream(self, *args, **kwargs):
- return self._wrapped.unary_stream(*args, **kwargs)
+ return self._wrap_method(self._wrapped.unary_stream(*args, **kwargs))

  def unary_unary(self, *args, **kwargs):
- return self._wrapped.unary_unary(*args, **kwargs)
+ return self._wrap_method(self._wrapped.unary_unary(*args, **kwargs))

  def unsubscribe(self, *args, **kwargs):
- return self._wrapped.unsubscribe(*args, **kwargs)
+ return self._wrap_method(self._wrapped.unsubscribe(*args, **kwargs))
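
The changes above give `DataprocChannelBuilder` an optional `is_active_callback` and make `ProxiedChannel` wrap each gRPC callable it hands out, so calls fail fast with a clear error once the backing session is gone instead of hanging on a dead channel. The following standalone sketch illustrates just that wrap-the-callable pattern; the names are illustrative and not part of the package.

```python
# Illustrative sketch of the "check liveness before delegating" wrapper used
# in the changes above; names here are examples, not the package's API.
from typing import Callable


def wrap_with_liveness_check(method: Callable, is_active: Callable[[], bool]) -> Callable:
    """Return a callable that verifies the session is still active before
    delegating to the wrapped method."""

    def checked(*args, **kwargs):
        if not is_active():
            raise RuntimeError("Session not active. Please create a new session")
        return method(*args, **kwargs)

    return checked


# Usage: pretend send_rpc is a channel method and the session has expired.
send_rpc = lambda payload: f"sent {payload}"
guarded = wrap_with_liveness_check(send_rpc, is_active=lambda: False)
try:
    guarded("plan")
except RuntimeError as e:
    print(e)  # -> Session not active. Please create a new session
```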
@@ -81,6 +81,7 @@ def connect_tcp_bridge(hostname):
  return websocketclient.connect(
  f"wss://{hostname}/{path}",
  additional_headers={"Authorization": f"Bearer {creds.token}"},
+ open_timeout=30,
  )


@@ -101,8 +102,11 @@ def forward_bytes(name, from_sock, to_sock):
  try:
  bs = from_sock.recv(1024)
  if not bs:
+ to_sock.close()
  return
- while bs:
+ attempt = 0
+ while bs and (attempt < 10):
+ attempt += 1
  try:
  to_sock.send(bs)
  bs = None
@@ -110,6 +114,8 @@ def forward_bytes(name, from_sock, to_sock):
  # On timeouts during a send, we retry just the send
  # to make sure we don't lose any bytes.
  pass
+ if bs:
+ raise Exception(f"Failed to forward bytes for {name}")
  except TimeoutError:
  # On timeouts during a receive, we retry the entire flow.
  pass
@@ -163,6 +169,11 @@ def forward_connection(conn_number, conn, addr, target_host):
  with conn:
  with connect_tcp_bridge(target_host) as websocket_conn:
  backend_socket = bridged_socket(websocket_conn)
+ # Set a timeout on how long we will allow send/recv calls to block
+ #
+ # The code that reads and writes to this connection will retry
+ # on timeouts, so this is a safe change.
+ conn.settimeout(10)
  connect_sockets(conn_number, conn, backend_socket)


@@ -210,14 +221,6 @@ class DataprocSessionProxy(object):
  s.release()
  while not self._killed:
  conn, addr = frontend_socket.accept()
- # Set a timeout on how long we will allow send/recv calls to block
- #
- # The code that reads and writes to this connection will retry
- # on timeouts, so this is a safe change.
- #
- # The chosen timeout is a very short one because it allows us
- # to more quickly detect when a connection has been closed.
- conn.settimeout(1)
  logger.debug(f"Accepted a connection from {addr}...")
  self._conn_number += 1
  threading.Thread(
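
The proxy changes above add a 30-second WebSocket `open_timeout`, close the peer socket when a read returns EOF, cap the send retry loop at 10 attempts, and move the `conn.settimeout(...)` call (now 10 seconds) next to the bridged connection it protects. A standalone sketch of the bounded retry-on-timeout idea, with illustrative names only:

```python
# Illustrative sketch of the bounded send-retry loop introduced above: retry a
# send that times out, but give up after a fixed number of attempts.
import socket


def forward_chunk(to_sock: socket.socket, chunk: bytes, max_attempts: int = 10) -> None:
    """Retry a timed-out send a bounded number of times, then fail loudly."""
    sent = False
    for _ in range(max_attempts):
        try:
            to_sock.send(chunk)
            sent = True
            break
        except TimeoutError:
            # Mirror the loop in forward_bytes above: retry just the send
            # so no bytes are lost.
            continue
    if not sent:
        raise RuntimeError(f"failed to forward {len(chunk)} bytes after {max_attempts} attempts")
```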
@@ -0,0 +1,27 @@
+ # Copyright 2025 Google LLC
+ #
+ # Licensed under the Apache License, Version 2.0 (the "License");
+ # you may not use this file except in compliance with the License.
+ # You may obtain a copy of the License at
+ #
+ # http://www.apache.org/licenses/LICENSE-2.0
+ #
+ # Unless required by applicable law or agreed to in writing, software
+ # distributed under the License is distributed on an "AS IS" BASIS,
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ # See the License for the specific language governing permissions and
+ # limitations under the License.
+
+
+ class DataprocSparkConnectException(Exception):
+ """A custom exception class to only print the error messages.
+ This would be used for exceptions where the stack trace
+ doesn't provide any additional information.h
+ """
+
+ def __init__(self, message):
+ self.message = message
+ super().__init__(message)
+
+ def _render_traceback_(self):
+ return self.message
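
`DataprocSparkConnectException` above carries only a human-readable message and defines `_render_traceback_`, the hook IPython looks for when displaying an exception, so notebook users see the message rather than a long stack trace. A small hedged usage sketch; the import path is the module path listed in the 0.6.0 RECORD above:

```python
# Hedged sketch: raising and handling the new exception type; the error text
# here is a placeholder.
from google.cloud.dataproc_spark_connect.exceptions import (
    DataprocSparkConnectException,
)

try:
    raise DataprocSparkConnectException("Error while creating serverless session: ...")
except DataprocSparkConnectException as e:
    # Outside IPython, the message is still available directly.
    print(e.message)
```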
@@ -0,0 +1,48 @@
+ import json
+ import logging
+ import os
+ import tempfile
+
+ from packaging.requirements import Requirement
+
+ logger = logging.getLogger(__name__)
+
+
+ class PyPiArtifacts:
+ """
+ This is a helper class to serialize the PYPI package installation request with a "magic" file name
+ that Spark Connect server understands
+ """
+
+ @staticmethod
+ def __try_parsing_package(packages: set[str]) -> list[Requirement]:
+ reqs = [Requirement(p) for p in packages]
+ if 0 in [len(req.specifier) for req in reqs]:
+ logger.info("It is recommended to pin the version of the package")
+ return reqs
+
+ def __init__(self, packages: set[str]):
+ self.requirements = PyPiArtifacts.__try_parsing_package(packages)
+
+ def write_packages_config(self, s8s_session_uuid: str) -> str:
+ """
+ Can't use the same file-name as Spark throws exception that file already exists
+ Keep the filename/format in sync with server
+ """
+ dependencies = {
+ "version": "0.5",
+ "packageType": "PYPI",
+ "packages": [str(req) for req in self.requirements],
+ }
+
+ file_path = os.path.join(
+ tempfile.gettempdir(),
+ s8s_session_uuid,
+ "add-artifacts-1729-" + self.__str__() + ".json",
+ )
+
+ os.makedirs(os.path.dirname(file_path), exist_ok=True)
+ with open(file_path, "w") as json_file:
+ json.dump(dependencies, json_file, indent=4)
+ logger.debug("Dumping dependencies request in file: " + file_path)
+ return file_path
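
`PyPiArtifacts` above parses requirement strings with `packaging`, logs a hint when a version is not pinned, and serializes the request to a JSON file under the system temp directory whose name the Spark Connect server recognizes as a package-install request. A hedged sketch of exercising it directly; the session UUID is a made-up placeholder, and the printed shape follows `write_packages_config` above (package order may vary):

```python
# Hedged sketch: writing a package-install request with the new PyPiArtifacts
# helper. "example-session-uuid" is a placeholder, not a real session.
import json

from google.cloud.dataproc_spark_connect.pypi_artifacts import PyPiArtifacts

artifacts = PyPiArtifacts({"spacy==3.8.4", "torch"})
config_path = artifacts.write_packages_config("example-session-uuid")

with open(config_path) as f:
    print(json.dumps(json.load(f), indent=2))
# Expected shape, per write_packages_config above:
# {"version": "0.5", "packageType": "PYPI", "packages": ["spacy==3.8.4", "torch"]}
```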
@@ -11,6 +11,7 @@
  # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  # See the License for the specific language governing permissions and
  # limitations under the License.
+ import atexit
  import json
  import logging
  import os
@@ -24,9 +25,10 @@ from typing import Any, cast, ClassVar, Dict, Optional
  from google.api_core import retry
  from google.api_core.future.polling import POLLING_PREDICATE
  from google.api_core.client_options import ClientOptions
- from google.api_core.exceptions import FailedPrecondition, InvalidArgument, NotFound
+ from google.api_core.exceptions import Aborted, FailedPrecondition, InvalidArgument, NotFound, PermissionDenied
  from google.cloud.dataproc_v1.types import sessions

+ from google.cloud.dataproc_spark_connect.pypi_artifacts import PyPiArtifacts
  from google.cloud.dataproc_spark_connect.client import DataprocChannelBuilder
  from google.cloud.dataproc_v1 import (
  CreateSessionRequest,
@@ -41,6 +43,7 @@ from google.protobuf.text_format import ParseError
  from pyspark.sql.connect.session import SparkSession
  from pyspark.sql.utils import to_str

+ from google.cloud.dataproc_spark_connect.exceptions import DataprocSparkConnectException

  # Set up logging
  logging.basicConfig(level=logging.INFO)
@@ -61,7 +64,7 @@ class DataprocSparkSession(SparkSession):
  >>> spark = (
  ... DataprocSparkSession.builder
  ... .appName("Word Count")
- ... .dataprocConfig(Session())
+ ... .dataprocSessionConfig(Session())
  ... .getOrCreate()
  ... ) # doctest: +SKIP
  """
@@ -112,15 +115,15 @@ class DataprocSparkSession(SparkSession):
  self._project_id = project_id
  return self

- def region(self, region):
- self._region = region
+ def location(self, location):
+ self._region = location
  self._client_options.api_endpoint = os.environ.get(
  "GOOGLE_CLOUD_DATAPROC_API_ENDPOINT",
  f"{self._region}-dataproc.googleapis.com",
  )
  return self

- def dataprocConfig(self, dataproc_config: Session):
+ def dataprocSessionConfig(self, dataproc_config: Session):
  with self._lock:
  self._dataproc_config = dataproc_config
  for k, v in dataproc_config.runtime_config.properties.items():
@@ -135,14 +138,14 @@ class DataprocSparkSession(SparkSession):
  else:
  return self

- def create(self) -> "SparkSession":
+ def create(self) -> "DataprocSparkSession":
  raise NotImplemented(
  "DataprocSparkSession allows session creation only through getOrCreate"
  )

  def __create_spark_connect_session_from_s8s(
- self, session_response
- ) -> "SparkSession":
+ self, session_response, session_name
+ ) -> "DataprocSparkSession":
  DataprocSparkSession._active_s8s_session_uuid = (
  session_response.uuid
  )
@@ -153,9 +156,16 @@ class DataprocSparkSession(SparkSession):
  "Spark Connect Server"
  )
  spark_connect_url = spark_connect_url.replace("https", "sc")
+ if not spark_connect_url.endswith("/"):
+ spark_connect_url += "/"
  url = f"{spark_connect_url.replace('.com/', '.com:443/')};session_id={session_response.uuid};use_ssl=true"
  logger.debug(f"Spark Connect URL: {url}")
- self._channel_builder = DataprocChannelBuilder(url)
+ self._channel_builder = DataprocChannelBuilder(
+ url,
+ is_active_callback=lambda: is_s8s_session_active(
+ session_name, self._client_options
+ ),
+ )

  assert self._channel_builder is not None
  session = DataprocSparkSession(connection=self._channel_builder)
@@ -164,7 +174,7 @@ class DataprocSparkSession(SparkSession):
  self.__apply_options(session)
  return session

- def __create(self) -> "SparkSession":
+ def __create(self) -> "DataprocSparkSession":
  with self._lock:

  if self._options.get("spark.remote", False):
@@ -185,12 +195,8 @@ class DataprocSparkSession(SparkSession):
  dataproc_config, session_template
  )

- spark = self._get_spark(dataproc_config, session_template)
-
  if not spark_connect_session:
  dataproc_config.spark_connect_session = {}
- if not spark:
- dataproc_config.spark = {}
  os.environ["SPARK_CONNECT_MODE_ENABLED"] = "1"
  session_request = CreateSessionRequest()
  session_id = self.generate_dataproc_session_id()
@@ -216,14 +222,28 @@ class DataprocSparkSession(SparkSession):
  multiplier=1.0,
  timeout=600, # seconds
  )
- logger.info(
- "Creating Spark session. It may take few minutes."
- )
+ print("Creating Spark session. It may take a few minutes.")
+ if (
+ "dataproc_spark_connect_SESSION_TERMINATE_AT_EXIT"
+ in os.environ
+ and os.getenv(
+ "dataproc_spark_connect_SESSION_TERMINATE_AT_EXIT"
+ ).lower()
+ == "true"
+ ):
+ atexit.register(
+ lambda: terminate_s8s_session(
+ self._project_id,
+ self._region,
+ session_id,
+ self._client_options,
+ )
+ )
  operation = SessionControllerClient(
  client_options=self._client_options
  ).create_session(session_request)
  print(
- f"Interactive Session Detail View: https://console.cloud.google.com/dataproc/interactive/{self._region}/{session_id}"
+ f"Interactive Session Detail View: https://console.cloud.google.com/dataproc/interactive/{self._region}/{session_id}?project={self._project_id}"
  )
  session_response: Session = operation.result(
  polling=session_polling
@@ -249,49 +269,32 @@ class DataprocSparkSession(SparkSession):
  logger.error(
  f"Exception while writing active session to file {file_path} , {e}"
  )
- except InvalidArgument as e:
+ except (InvalidArgument, PermissionDenied) as e:
  DataprocSparkSession._active_s8s_session_id = None
- raise RuntimeError(
- f"Error while creating serverless session: {e}"
- ) from None
+ raise DataprocSparkConnectException(
+ f"Error while creating serverless session: {e.message}"
+ )
  except Exception as e:
  DataprocSparkSession._active_s8s_session_id = None
  raise RuntimeError(
- f"Error while creating serverless session https://console.cloud.google.com/dataproc/interactive/{self._region}/{session_id} : {e}"
- ) from None
+ f"Error while creating serverless session"
+ ) from e

  logger.debug(
  f"Serverless session created: {session_id}, creation time taken: {int(time.time() - s8s_creation_start_time)} seconds"
  )
  return self.__create_spark_connect_session_from_s8s(
- session_response
+ session_response, dataproc_config.name
  )

- def _is_s8s_session_active(
- self, s8s_session_id: str
- ) -> Optional[sessions.Session]:
- session_name = f"projects/{self._project_id}/locations/{self._region}/sessions/{s8s_session_id}"
- get_session_request = GetSessionRequest()
- get_session_request.name = session_name
- state = None
- try:
- get_session_response = SessionControllerClient(
- client_options=self._client_options
- ).get_session(get_session_request)
- state = get_session_response.state
- except Exception as e:
- logger.debug(f"{s8s_session_id} deleted: {e}")
- return None
-
- if state is not None and (
- state == Session.State.ACTIVE or state == Session.State.CREATING
- ):
- return get_session_response
- return None
-
- def _get_exiting_active_session(self) -> Optional["SparkSession"]:
+ def _get_exiting_active_session(
+ self,
+ ) -> Optional["DataprocSparkSession"]:
  s8s_session_id = DataprocSparkSession._active_s8s_session_id
- session_response = self._is_s8s_session_active(s8s_session_id)
+ session_name = f"projects/{self._project_id}/locations/{self._region}/sessions/{s8s_session_id}"
+ session_response = get_active_s8s_session_response(
+ session_name, self._client_options
+ )

  session = DataprocSparkSession.getActiveSession()
  if session is None:
@@ -299,11 +302,11 @@ class DataprocSparkSession(SparkSession):

  if session_response is not None:
  print(
- f"Using existing session: https://console.cloud.google.com/dataproc/interactive/{self._region}/{s8s_session_id}, configuration changes may not be applied."
+ f"Using existing session: https://console.cloud.google.com/dataproc/interactive/{self._region}/{s8s_session_id}?project={self._project_id}, configuration changes may not be applied."
  )
  if session is None:
  session = self.__create_spark_connect_session_from_s8s(
- session_response
+ session_response, session_name
  )
  return session
  else:
@@ -315,7 +318,7 @@ class DataprocSparkSession(SparkSession):

  return None

- def getOrCreate(self) -> "SparkSession":
+ def getOrCreate(self) -> "DataprocSparkSession":
  with DataprocSparkSession._lock:
  session = self._get_exiting_active_session()
  if session is None:
@@ -413,11 +416,11 @@ class DataprocSparkSession(SparkSession):
  ]
  import importlib.metadata

- dataproc_connect_version = importlib.metadata.version(
+ google_connect_version = importlib.metadata.version(
  "dataproc-spark-connect"
  )
  client_version = importlib.metadata.version("pyspark")
- version_message = f"Dataproc Spark Connect: {dataproc_connect_version} (PySpark: {client_version}) Dataproc Session Runtime: {version} (Spark: {server_version})"
+ version_message = f"Spark Connect: {google_connect_version} (PySpark: {client_version}) Session Runtime: {version} (Spark: {server_version})"
  logger.info(version_message)
  if trimmed_version(client_version) != trimmed_version(
  server_version
@@ -435,14 +438,6 @@ class DataprocSparkSession(SparkSession):
  spark_connect_session = session_template.spark_connect_session
  return spark_connect_session

- def _get_spark(self, dataproc_config, session_template):
- spark = None
- if dataproc_config and dataproc_config.spark:
- spark = dataproc_config.spark
- elif session_template and session_template.spark:
- spark = session_template.spark
- return spark
-
  def generate_dataproc_session_id(self):
  timestamp = datetime.datetime.now().strftime("%Y%m%d-%H%M%S")
  suffix_length = 6
@@ -466,8 +461,8 @@ class DataprocSparkSession(SparkSession):
  <div>
  <p><b>Spark Connect</b></p>

- <p><a href="{s8s_session}">Dataproc Session</a></p>
- <p><a href="{ui}">Spark UI</a></p>
+ <p><a href="{s8s_session}?project={self._project_id}">Serverless Session</a></p>
+ <p><a href="{ui}?project={self._project_id}">Spark UI</a></p>
  </div>
  """

@@ -484,45 +479,70 @@ class DataprocSparkSession(SparkSession):
  f"Exception while removing active session in file {file_path} , {e}"
  )

+ def addArtifacts(
+ self,
+ *artifact: str,
+ pyfile: bool = False,
+ archive: bool = False,
+ file: bool = False,
+ pypi: bool = False,
+ ) -> None:
+ """
+ Add artifact(s) to the client session. Currently only local files & pypi installations are supported.
+
+ .. versionadded:: 3.5.0
+
+ Parameters
+ ----------
+ *path : tuple of str
+ Artifact's URIs to add.
+ pyfile : bool
+ Whether to add them as Python dependencies such as .py, .egg, .zip or .jar files.
+ The pyfiles are directly inserted into the path when executing Python functions
+ in executors.
+ archive : bool
+ Whether to add them as archives such as .zip, .jar, .tar.gz, .tgz, or .tar files.
+ The archives are unpacked on the executor side automatically.
+ file : bool
+ Add a file to be downloaded with this Spark job on every node.
+ The ``path`` passed can only be a local file for now.
+ pypi : bool
+ This option is only available with DataprocSparkSession. eg. `spark.addArtifacts("spacy==3.8.4", "torch", pypi=True)`
+ Installs PyPi package (with its dependencies) in the active Spark session on the driver and executors.
+
+ Notes
+ -----
+ This is an API dedicated to Spark Connect client only. With regular Spark Session, it throws
+ an exception.
+ Regarding pypi: Popular packages are already pre-installed in s8s runtime.
+ https://cloud.google.com/dataproc-serverless/docs/concepts/versions/spark-runtime-2.2#python_libraries
+ If there are conflicts/package doesn't exist, it throws an exception.
+ """
+ if sum([pypi, file, pyfile, archive]) > 1:
+ raise ValueError(
+ "'pyfile', 'archive', 'file' and/or 'pypi' cannot be True together."
+ )
+ if pypi:
+ artifacts = PyPiArtifacts(set(artifact))
+ logger.debug("Making addArtifact call to install packages")
+ self.addArtifact(
+ artifacts.write_packages_config(self._active_s8s_session_uuid),
+ file=True,
+ )
+ else:
+ super().addArtifacts(
+ *artifact, pyfile=pyfile, archive=archive, file=file
+ )
+
  def stop(self) -> None:
  with DataprocSparkSession._lock:
  if DataprocSparkSession._active_s8s_session_id is not None:
- from google.cloud.dataproc_v1 import SessionControllerClient
-
- logger.debug(
- f"Terminating serverless session: {DataprocSparkSession._active_s8s_session_id}"
+ terminate_s8s_session(
+ DataprocSparkSession._project_id,
+ DataprocSparkSession._region,
+ DataprocSparkSession._active_s8s_session_id,
+ self._client_options,
  )
- terminate_session_request = TerminateSessionRequest()
- session_name = f"projects/{DataprocSparkSession._project_id}/locations/{DataprocSparkSession._region}/sessions/{DataprocSparkSession._active_s8s_session_id}"
- terminate_session_request.name = session_name
- state = None
- try:
- SessionControllerClient(
- client_options=self._client_options
- ).terminate_session(terminate_session_request)
- get_session_request = GetSessionRequest()
- get_session_request.name = session_name
- state = Session.State.ACTIVE
- while (
- state != Session.State.TERMINATING
- and state != Session.State.TERMINATED
- and state != Session.State.FAILED
- ):
- session = SessionControllerClient(
- client_options=self._client_options
- ).get_session(get_session_request)
- state = session.state
- sleep(1)
- except NotFound:
- logger.debug(
- f"Session {DataprocSparkSession._active_s8s_session_id} already deleted"
- )
- except FailedPrecondition:
- logger.debug(
- f"Session {DataprocSparkSession._active_s8s_session_id} already terminated manually or terminated automatically through session ttl limits"
- )
- if state is not None and state == Session.State.FAILED:
- raise RuntimeError("Serverless session termination failed")

  self._remove_stoped_session_from_file()
  DataprocSparkSession._active_s8s_session_uuid = None
@@ -538,3 +558,66 @@
  DataprocSparkSession._active_session, "session", None
  ):
  DataprocSparkSession._active_session.session = None
+
+
+ def terminate_s8s_session(
+ project_id, region, active_s8s_session_id, client_options=None
+ ):
+ from google.cloud.dataproc_v1 import SessionControllerClient
+
+ logger.debug(f"Terminating serverless session: {active_s8s_session_id}")
+ terminate_session_request = TerminateSessionRequest()
+ session_name = f"projects/{project_id}/locations/{region}/sessions/{active_s8s_session_id}"
+ terminate_session_request.name = session_name
+ state = None
+ try:
+ session_client = SessionControllerClient(client_options=client_options)
+ session_client.terminate_session(terminate_session_request)
+ get_session_request = GetSessionRequest()
+ get_session_request.name = session_name
+ state = Session.State.ACTIVE
+ while (
+ state != Session.State.TERMINATING
+ and state != Session.State.TERMINATED
+ and state != Session.State.FAILED
+ ):
+ session = session_client.get_session(get_session_request)
+ state = session.state
+ sleep(1)
+ except NotFound:
+ logger.debug(f"Session {active_s8s_session_id} already deleted")
+ # Client will get 'Aborted' error if session creation is still in progress and
+ # 'FailedPrecondition' if another termination is still in progress.
+ # Both are retryable but we catch it and let TTL take care of cleanups.
+ except (FailedPrecondition, Aborted):
+ logger.debug(
+ f"Session {active_s8s_session_id} already terminated manually or terminated automatically through session ttl limits"
+ )
+ if state is not None and state == Session.State.FAILED:
+ raise RuntimeError("Serverless session termination failed")
+
+
+ def get_active_s8s_session_response(
+ session_name, client_options
+ ) -> Optional[sessions.Session]:
+ get_session_request = GetSessionRequest()
+ get_session_request.name = session_name
+ try:
+ get_session_response = SessionControllerClient(
+ client_options=client_options
+ ).get_session(get_session_request)
+ state = get_session_response.state
+ except Exception as e:
+ logger.info(f"{session_name} deleted: {e}")
+ return None
+ if state is not None and (
+ state == Session.State.ACTIVE or state == Session.State.CREATING
+ ):
+ return get_session_response
+ return None
+
+
+ def is_s8s_session_active(session_name, client_options) -> bool:
+ if get_active_s8s_session_response(session_name, client_options) is None:
+ return False
+ return True
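
Taken together, the `DataprocSparkSession` changes above add a `pypi=True` mode to `addArtifacts`, an opt-in `atexit` cleanup controlled by the `dataproc_spark_connect_SESSION_TERMINATE_AT_EXIT` environment variable, and module-level helpers (`terminate_s8s_session`, `get_active_s8s_session_response`, `is_s8s_session_active`) that `stop()` and the channel liveness callback now use. A hedged end-to-end sketch of the new knobs; it assumes valid Google Cloud credentials and a default session configuration:

```python
# Hedged sketch of the 0.6.0 additions shown above: opt in to terminating the
# serverless session at interpreter exit, then install PyPI packages into the
# running session. Requires valid Google Cloud credentials and configuration.
import os

# Must be set before getOrCreate() so the atexit hook is registered.
os.environ["dataproc_spark_connect_SESSION_TERMINATE_AT_EXIT"] = "true"

from google.cloud.dataproc_spark_connect import DataprocSparkSession

spark = DataprocSparkSession.builder.getOrCreate()

# New in this diff: install PyPI packages (and their dependencies) on the
# driver and executors of the active session.
spark.addArtifacts("spacy==3.8.4", "torch", pypi=True)

# stop() now delegates to the module-level terminate_s8s_session() helper.
spark.stop()
```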
@@ -1,10 +0,0 @@
- google/cloud/dataproc_spark_connect/__init__.py,sha256=pybAofW6rmWI-4C8VYm1q0NOZD_sBvFQz43jUBSQW30,616
- google/cloud/dataproc_spark_connect/session.py,sha256=A42Wo87VSunG0D3sB-biWyNvU33WhI92mmrJbXI1oNo,23017
- google/cloud/dataproc_spark_connect/client/__init__.py,sha256=6hCNSsgYlie6GuVpc5gjFsPnyeMTScTpXSPYqp1fplY,615
- google/cloud/dataproc_spark_connect/client/core.py,sha256=7Wy6QwkcWxlHBdo4NsktJEknggPpGkx9F5CS5IpQ7iM,3630
- google/cloud/dataproc_spark_connect/client/proxy.py,sha256=ScrbaGsEvqi8wp4ngfD-T9K9mFHXBkVMZkTSr7mdNBs,8926
- dataproc_spark_connect-0.2.0.dist-info/LICENSE,sha256=xx0jnfkXJvxRnG63LTGOxlggYnIysveWIZ6H3PNdCrQ,11357
- dataproc_spark_connect-0.2.0.dist-info/METADATA,sha256=UivMTIfzkp6fzGHG4hiXPUAsRP9P7VBQMKJdEcjmowk,4200
- dataproc_spark_connect-0.2.0.dist-info/WHEEL,sha256=OpXWERl2xLPRHTvd2ZXo_iluPEQd8uSbYkJ53NAER_Y,109
- dataproc_spark_connect-0.2.0.dist-info/top_level.txt,sha256=_1QvSJIhFAGfxb79D6DhB7SUw2X6T4rwnz_LLrbcD3c,7
- dataproc_spark_connect-0.2.0.dist-info/RECORD,,