nurion-raydp 1.7.0__py3-none-any.whl

This diff shows the contents of publicly available package versions that have been released to one of the supported registries. It is provided for informational purposes only and reflects the changes between package versions as they appear in their respective public registries.
raydp/spark/dataset.py ADDED
@@ -0,0 +1,232 @@
1
+ #
2
+ # Licensed to the Apache Software Foundation (ASF) under one or more
3
+ # contributor license agreements. See the NOTICE file distributed with
4
+ # this work for additional information regarding copyright ownership.
5
+ # The ASF licenses this file to You under the Apache License, Version 2.0
6
+ # (the "License"); you may not use this file except in compliance with
7
+ # the License. You may obtain a copy of the License at
8
+ #
9
+ # http://www.apache.org/licenses/LICENSE-2.0
10
+ #
11
+ # Unless required by applicable law or agreed to in writing, software
12
+ # distributed under the License is distributed on an "AS IS" BASIS,
13
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14
+ # See the License for the specific language governing permissions and
15
+ # limitations under the License.
16
+ import logging
17
+ import uuid
18
+ from typing import Callable, List, Optional, Union
19
+ from dataclasses import dataclass
20
+
21
+ import pandas as pd
22
+ import pyarrow as pa
23
+ import pyspark.sql as sql
24
+ from pyspark.sql import SparkSession
25
+ from pyspark.sql.dataframe import DataFrame
26
+ from pyspark.sql.types import StructType
27
+ from pyspark.sql.pandas.types import from_arrow_type
28
+ from pyspark.storagelevel import StorageLevel
29
+ import ray
30
+ import ray.cross_language
31
+ from ray.data import Dataset, from_arrow_refs
32
+ from ray.types import ObjectRef
33
+ from ray._private.client_mode_hook import client_mode_wrap
34
+
35
+ from raydp.spark.ray_cluster_master import RAYDP_SPARK_MASTER_SUFFIX
36
+
37
+
38
+ logger = logging.getLogger(__name__)
39
+
40
+
41
+ @dataclass
42
+ class PartitionObjectsOwner:
43
+ # Actor owner name
44
+ actor_name: str
45
+ # Function that stores the serialized partition ObjectRefs in the owner
46
+ # actor's state and returns the ObjectRef from the .remote() call
47
+ set_reference_as_state: Callable[[ray.actor.ActorHandle, List[ObjectRef]], ObjectRef]
48
+
49
+
50
+ def get_raydp_master_owner(
51
+ spark: Optional[SparkSession] = None,
52
+ ) -> PartitionObjectsOwner:
53
+ if spark is None:
54
+ spark = SparkSession.getActiveSession()
55
+ obj_holder_name = spark.sparkContext.appName + RAYDP_SPARK_MASTER_SUFFIX
56
+
57
+ def raydp_master_set_reference_as_state(
58
+ raydp_master_actor: ray.actor.ActorHandle, objects: List[ObjectRef]
59
+ ) -> ObjectRef:
60
+ return raydp_master_actor.add_objects.remote(uuid.uuid4(), objects)
61
+
62
+ return PartitionObjectsOwner(obj_holder_name, raydp_master_set_reference_as_state)
63
+
64
+
65
+ @client_mode_wrap
66
+ def _register_objects(records):
67
+ worker = ray.worker.global_worker
68
+ blocks: List[ray.ObjectRef] = []
69
+ block_sizes: List[int] = []
70
+ for obj_id, owner, num_record in records:
71
+ object_ref = ray.ObjectRef(obj_id)
72
+ # Register the ownership of the ObjectRef
73
+ worker.core_worker.deserialize_and_register_object_ref(
74
+ object_ref.binary(), ray.ObjectRef.nil(), owner, ""
75
+ )
76
+ blocks.append(object_ref)
77
+ block_sizes.append(num_record)
78
+ return blocks, block_sizes
79
+
80
+
81
+ def _save_spark_df_to_object_store(
82
+ df: sql.DataFrame,
83
+ use_batch: bool = True,
84
+ owner: Union[PartitionObjectsOwner, None] = None,
85
+ ):
86
+ # call java function from python
87
+ jvm = df.sql_ctx.sparkSession.sparkContext._jvm
88
+ jdf = df._jdf
89
+ object_store_writer = jvm.org.apache.spark.sql.raydp.ObjectStoreWriter(jdf)
90
+ actor_owner_name = ""
91
+ if owner is not None:
92
+ actor_owner_name = owner.actor_name
93
+ records = object_store_writer.save(use_batch, actor_owner_name)
94
+
95
+ record_tuples = [
96
+ (record.objectId(), record.ownerAddress(), record.numRecords()) for record in records
97
+ ]
98
+ blocks, block_sizes = _register_objects(record_tuples)
99
+ logger.info(
100
+ f"after _register_objects, len(blocks): {len(blocks)}, len(block_sizes): {len(block_sizes)}"
101
+ )
102
+
103
+ if owner is not None:
104
+ actor_owner = ray.get_actor(actor_owner_name)
105
+ ray.get(owner.set_reference_as_state(actor_owner, blocks))
106
+
107
+ return blocks, block_sizes
108
+
109
+
110
+ def spark_dataframe_to_ray_dataset(
111
+ df: sql.DataFrame,
112
+ parallelism: Optional[int] = None,
113
+ owner: Union[PartitionObjectsOwner, None] = None,
114
+ ):
115
+ num_part = df.rdd.getNumPartitions()
116
+ if parallelism is not None:
117
+ if parallelism != num_part:
118
+ df = df.repartition(parallelism)
119
+ blocks, _ = _save_spark_df_to_object_store(df, False, owner)
120
+ return from_arrow_refs(blocks)
121
+
122
+
123
+ # This is an experimental API for now.
124
+ # If you run into any issues using it, please report them on our GitHub.
125
+ # This function WILL cache/persist the dataframe!
126
+ def from_spark_recoverable(
127
+ df: sql.DataFrame,
128
+ storage_level: StorageLevel = StorageLevel.MEMORY_AND_DISK,
129
+ parallelism: Optional[int] = None,
130
+ ):
131
+ num_part = df.rdd.getNumPartitions()
132
+ if parallelism is not None:
133
+ if parallelism != num_part:
134
+ df = df.repartition(parallelism)
135
+ sc = df.sql_ctx.sparkSession.sparkContext
136
+ storage_level = sc._getJavaStorageLevel(storage_level)
137
+ object_store_writer = sc._jvm.org.apache.spark.sql.raydp.ObjectStoreWriter
138
+ object_ids = object_store_writer.fromSparkRDD(df._jdf, storage_level)
139
+ owner = object_store_writer.getAddress()
140
+ worker = ray.worker.global_worker
141
+ blocks = []
142
+ for object_id in object_ids:
143
+ object_ref = ray.ObjectRef(object_id)
144
+ # Register the ownership of the ObjectRef
145
+ worker.core_worker.deserialize_and_register_object_ref(
146
+ object_ref.binary(), ray.ObjectRef.nil(), owner, ""
147
+ )
148
+ blocks.append(object_ref)
149
+ return from_arrow_refs(blocks)
150
+
151
+
152
+ def _convert_by_udf(
153
+ spark: sql.SparkSession,
154
+ blocks: List[ObjectRef],
155
+ locations: List[bytes],
156
+ schema: StructType,
157
+ ) -> DataFrame:
158
+ holder_name = spark.sparkContext.appName + RAYDP_SPARK_MASTER_SUFFIX
159
+ holder = ray.get_actor(holder_name)
160
+ df_id = uuid.uuid4()
161
+ ray.get(holder.add_objects.remote(df_id, blocks))
162
+ jvm = spark.sparkContext._jvm
163
+ object_store_reader = jvm.org.apache.spark.sql.raydp.ObjectStoreReader
164
+ # create the rdd then dataframe to utilize locality
165
+ jdf = object_store_reader.createRayObjectRefDF(spark._jsparkSession, locations)
166
+ current_namespace = ray.get_runtime_context().namespace
167
+ ray_address = ray.get(holder.get_ray_address.remote())
168
+ blocks_df = DataFrame(jdf, spark._wrapped if hasattr(spark, "_wrapped") else spark)
169
+
170
+ def _convert_blocks_to_dataframe(blocks):
171
+ # connect to ray
172
+ if not ray.is_initialized():
173
+ ray.init(
174
+ address=ray_address,
175
+ namespace=current_namespace,
176
+ logging_level=logging.WARN,
177
+ )
178
+ obj_holder = ray.get_actor(holder_name)
179
+ for block in blocks:
180
+ dfs = []
181
+ for idx in block["idx"]:
182
+ ref = ray.get(obj_holder.get_object.remote(df_id, idx))
183
+ data = ray.get(ref)
184
+ dfs.append(data.to_pandas())
185
+ yield pd.concat(dfs)
186
+
187
+ df = blocks_df.mapInPandas(_convert_blocks_to_dataframe, schema)
188
+ return df
189
+
190
+
191
+ def _convert_by_rdd(
192
+ spark: sql.SparkSession, blocks: Dataset, locations: List[bytes], schema: StructType
193
+ ) -> DataFrame:
194
+ object_ids = [block.binary() for block in blocks]
195
+ schema_str = schema.json()
196
+ jvm = spark.sparkContext._jvm
197
+ # create rdd in java
198
+ rdd = jvm.org.apache.spark.rdd.RayDatasetRDD(spark._jsc, object_ids, locations)
199
+ # convert the rdd to dataframe
200
+ object_store_reader = jvm.org.apache.spark.sql.raydp.ObjectStoreReader
201
+ jdf = object_store_reader.RayDatasetToDataFrame(spark._jsparkSession, rdd, schema_str)
202
+ return DataFrame(jdf, spark._wrapped if hasattr(spark, "_wrapped") else spark)
203
+
204
+
205
+ @client_mode_wrap
206
+ def get_locations(blocks):
207
+ core_worker = ray.worker.global_worker.core_worker
208
+ return [core_worker.get_owner_address(block) for block in blocks]
209
+
210
+
211
+ def ray_dataset_to_spark_dataframe(
212
+ spark: sql.SparkSession, arrow_schema, blocks: List[ObjectRef], locations=None
213
+ ) -> DataFrame:
214
+ if locations is None:
+ locations = get_locations(blocks)
215
+ if hasattr(arrow_schema, "base_schema"):
216
+ arrow_schema = arrow_schema.base_schema
217
+ if not isinstance(arrow_schema, pa.lib.Schema):
218
+ raise RuntimeError(
219
+ f"Schema is {type(arrow_schema)}, required pyarrow.lib.Schema. \n"
220
+ f"to_spark does not support converting non-arrow ray datasets."
221
+ )
222
+ schema = StructType()
223
+ for field in arrow_schema:
224
+ schema.add(field.name, from_arrow_type(field.type), nullable=field.nullable)
225
+ # TODO how to branch on type of block?
226
+ sample = ray.get(blocks[0])
227
+ if isinstance(sample, bytes):
228
+ return _convert_by_rdd(spark, blocks, locations, schema)
229
+ elif isinstance(sample, pa.Table):
230
+ return _convert_by_udf(spark, blocks, locations, schema)
231
+ else:
232
+ raise RuntimeError("ray.to_spark only supports arrow type blocks")
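For orientation, here is a minimal usage sketch of the helpers in this file. It assumes Ray is running and that a RayDP-managed SparkSession exists (started below via raydp.init_spark, which is assumed from upstream RayDP and is not part of this diff); the sample data, parallelism, and the Ray Dataset block accessors are illustrative assumptions and vary across Ray versions.

    import ray
    import raydp  # assumed entry point; init_spark is not shown in this diff

    from raydp.spark.dataset import (
        get_raydp_master_owner,
        ray_dataset_to_spark_dataframe,
        spark_dataframe_to_ray_dataset,
    )

    ray.init()
    # A RayDP-backed SparkSession is assumed; the init_spark arguments are illustrative.
    spark = raydp.init_spark(app_name="demo", num_executors=2,
                             executor_cores=1, executor_memory="1G")

    df = spark.createDataFrame([(1, "a"), (2, "b")], ["id", "value"])

    # Hand ownership of the converted blocks to the RayDP Spark master actor so
    # they survive the Spark executors that produced them.
    owner = get_raydp_master_owner(spark)
    ds = spark_dataframe_to_ray_dataset(df, parallelism=2, owner=owner)

    # Converting back needs the Arrow schema and the block ObjectRefs; these
    # accessors are assumptions and may be deprecated in newer Ray releases.
    blocks = ds.get_internal_block_refs()
    spark_df = ray_dataset_to_spark_dataframe(spark, ds.schema(), blocks)
    spark_df.show()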
raydp/spark/ray_cluster.py ADDED
@@ -0,0 +1,162 @@
1
+ #
2
+ # Licensed to the Apache Software Foundation (ASF) under one or more
3
+ # contributor license agreements. See the NOTICE file distributed with
4
+ # this work for additional information regarding copyright ownership.
5
+ # The ASF licenses this file to You under the Apache License, Version 2.0
6
+ # (the "License"); you may not use this file except in compliance with
7
+ # the License. You may obtain a copy of the License at
8
+ #
9
+ # http://www.apache.org/licenses/LICENSE-2.0
10
+ #
11
+ # Unless required by applicable law or agreed to in writing, software
12
+ # distributed under the License is distributed on an "AS IS" BASIS,
13
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14
+ # See the License for the specific language governing permissions and
15
+ # limitations under the License.
16
+ #
17
+
18
+ import glob
19
+ import os
20
+ import platform
21
+ import pyspark
22
+ from typing import Dict
23
+
24
+ import ray
25
+ from pyspark.sql.session import SparkSession
26
+
27
+ from .ray_cluster_master import RAYDP_SPARK_MASTER_SUFFIX, RayDPSparkMaster
28
+
29
+ DRIVER_CP_KEY = "spark.driver.extraClassPath"
30
+ DRIVER_JAVA_OPTIONS_KEY = "spark.driver.extraJavaOptions"
31
+
32
+
33
+ def _get_ray_job_id() -> str:
34
+ """Get the Ray job ID from environment or runtime context."""
35
+ # Try environment variable first
36
+ job_id = os.environ.get("RAY_JOB_ID")
37
+ if job_id:
38
+ return job_id
39
+ # Try runtime context
40
+ try:
41
+ ctx = ray.get_runtime_context()
42
+ return ctx.get_job_id()
43
+ except Exception:
44
+ return "local-job-id"
45
+
46
+
47
+ class SparkCluster:
48
+ def __init__(
49
+ self,
50
+ app_name,
51
+ configs,
52
+ logging_level,
53
+ ):
54
+ self._app_name = app_name
55
+ self._configs = configs
56
+ self._logging_level = logging_level
57
+ # self._logger = logging.getLogger(__file__)
58
+ self._prepare_spark_configs()
59
+ self._setup_master(self._get_master_resources(self._configs))
60
+ self._spark_session: SparkSession = None
61
+
62
+ def _setup_master(self, resources: Dict[str, float]):
63
+ spark_master_name = self._app_name + RAYDP_SPARK_MASTER_SUFFIX
64
+
65
+ if resources:
66
+ num_cpu = 1
67
+ if "CPU" in resources:
68
+ num_cpu = resources["CPU"]
69
+ resources.pop("CPU", None)
70
+ self._spark_master_handle = RayDPSparkMaster.options(
71
+ name=spark_master_name,
72
+ num_cpus=num_cpu,
73
+ resources=resources,
74
+ ).remote(self._app_name, self._configs, logging_level=self._logging_level)
75
+ else:
76
+ self._spark_master_handle = RayDPSparkMaster.options(
77
+ name=spark_master_name,
78
+ ).remote(self._app_name, self._configs, logging_level=self._logging_level)
79
+
80
+ ray.get(self._spark_master_handle.start_up.remote(resources))
81
+
82
+ def _get_master_resources(self, configs: Dict[str, str]) -> Dict[str, float]:
83
+ resources = {}
84
+ spark_master_actor_resource_prefix = "spark.ray.master.actor.resource."
85
+
86
+ def get_master_actor_resource(
87
+ key_prefix: str, resource: Dict[str, float]
88
+ ) -> Dict[str, float]:
89
+ for key in configs:
90
+ if key.startswith(key_prefix):
91
+ resource_name = key[len(key_prefix) :]
92
+ resource[resource_name] = float(configs[key])
93
+ return resource
94
+
95
+ resources = get_master_actor_resource(spark_master_actor_resource_prefix, resources)
96
+
97
+ return resources
98
+
99
+ def get_cluster_url(self) -> str:
100
+ return ray.get(self._spark_master_handle.get_master_url.remote())
101
+
102
+ def _prepare_spark_configs(self):
103
+ if self._configs is None:
104
+ self._configs = {}
105
+ if platform.system() != "Darwin":
106
+ driver_node_ip = ray.util.get_node_ip_address()
107
+ if "spark.driver.host" not in self._configs:
108
+ self._configs["spark.driver.host"] = str(driver_node_ip)
109
+ self._configs["spark.driver.bindAddress"] = str(driver_node_ip)
110
+
111
+ raydp_cp = os.path.abspath(os.path.join(os.path.abspath(__file__), "../../jars/*"))
112
+ ray_cp = os.path.abspath(os.path.join(os.path.dirname(ray.__file__), "jars/*"))
113
+ spark_home = os.environ.get("SPARK_HOME", os.path.dirname(pyspark.__file__))
114
+ spark_jars_dir = os.path.abspath(os.path.join(spark_home, "jars/*"))
115
+
116
+ raydp_jars = glob.glob(raydp_cp)
117
+ driver_cp = ":".join(raydp_jars + [spark_jars_dir] + glob.glob(ray_cp))
118
+ if DRIVER_CP_KEY in self._configs:
119
+ self._configs[DRIVER_CP_KEY] += ":" + driver_cp
120
+ else:
121
+ self._configs[DRIVER_CP_KEY] = driver_cp
122
+
123
+ extra_driver_options = f"-Dray.job.id={_get_ray_job_id()}"
124
+ if DRIVER_JAVA_OPTIONS_KEY in self._configs:
125
+ self._configs[DRIVER_JAVA_OPTIONS_KEY] += " " + extra_driver_options
126
+ else:
127
+ self._configs[DRIVER_JAVA_OPTIONS_KEY] = extra_driver_options
128
+
129
+ python_path_candidates = self._configs.get("spark.executorEnv.PYTHONPATH", "").split(":")
130
+ for k, v in os.environ.items():
131
+ if k == "PYTHONPATH":
132
+ python_path_candidates.append(v)
133
+ if k == "VIRTUAL_ENV":
134
+ python_path_candidates += glob.glob(f"{v}/lib/python*/site-packages")
135
+ self._configs["spark.pyspark.python"] = f"{v}/bin/python"
136
+ self._configs["spark.executorEnv.PYTHONPATH"] = ":".join(
137
+ [x for x in python_path_candidates if len(x) > 0]
138
+ )
139
+
140
+ def get_spark_session(self) -> SparkSession:
141
+ if self._spark_session is not None:
142
+ return self._spark_session
143
+ spark_builder = SparkSession.builder
144
+ for k, v in self._configs.items():
145
+ spark_builder.config(k, v)
146
+ spark_builder.enableHiveSupport()
147
+ self._spark_session = (
148
+ spark_builder.appName(self._app_name).master(self.get_cluster_url()).getOrCreate()
149
+ )
150
+
151
+ print(f"Spark UI: {self._spark_session.sparkContext.uiWebUrl}")
152
+ self._spark_session.sparkContext.setLogLevel(self._logging_level)
153
+ return self._spark_session
154
+
155
+ def stop(self, cleanup_data):
156
+ if self._spark_session is not None:
157
+ self._spark_session.stop()
158
+ self._spark_session = None
159
+ if self._spark_master_handle is not None:
160
+ self._spark_master_handle.stop.remote(cleanup_data)
161
+ if cleanup_data:
162
+ self._spark_master_handle = None
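As a sketch of how the class above is wired together: any config key beginning with spark.ray.master.actor.resource. becomes a Ray resource reserved for the RayDPSparkMaster actor (CPU maps to num_cpus), and the remaining keys are forwarded to SparkSession.builder. The resource names, Spark settings, and the assumption that the Ray cluster exposes a matching custom resource are all illustrative.

    import ray

    from raydp.spark.ray_cluster import SparkCluster  # module path inferred from the relative imports in this diff

    ray.init()

    configs = {
        # Reserved for the RayDPSparkMaster actor (the prefix is stripped).
        "spark.ray.master.actor.resource.CPU": "1",
        "spark.ray.master.actor.resource.spark_master": "1",
        # Ordinary Spark settings, passed through to SparkSession.builder.config().
        "spark.executor.instances": "2",
        "spark.executor.cores": "1",
        "spark.executor.memory": "1g",
    }

    cluster = SparkCluster(app_name="demo", configs=configs, logging_level="WARN")
    spark = cluster.get_spark_session()
    print(spark.range(10).count())
    cluster.stop(cleanup_data=True)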
raydp/spark/ray_cluster_master.py ADDED
@@ -0,0 +1,102 @@
1
+ #
2
+ # Licensed to the Apache Software Foundation (ASF) under one or more
3
+ # contributor license agreements. See the NOTICE file distributed with
4
+ # this work for additional information regarding copyright ownership.
5
+ # The ASF licenses this file to You under the Apache License, Version 2.0
6
+ # (the "License"); you may not use this file except in compliance with
7
+ # the License. You may obtain a copy of the License at
8
+ #
9
+ # http://www.apache.org/licenses/LICENSE-2.0
10
+ #
11
+ # Unless required by applicable law or agreed to in writing, software
12
+ # distributed under the License is distributed on an "AS IS" BASIS,
13
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14
+ # See the License for the specific language governing permissions and
15
+ # limitations under the License.
16
+ #
17
+
18
+ import json
19
+ import logging
20
+
21
+ import ray
22
+ import ray.cross_language
23
+ from ray.util.scheduling_strategies import (
24
+ PlacementGroupSchedulingStrategy,
25
+ )
26
+
27
+
28
+ from .ray_pyworker import PyWorker
29
+
30
+ RAYDP_SPARK_MASTER_SUFFIX = "_SPARK_MASTER"
31
+
32
+
33
+ @ray.remote
34
+ class RayDPSparkMaster:
35
+ def __init__(self, app_name, configs, logging_level: str):
36
+ self._logger = logging.getLogger(__file__)
37
+
38
+ self._app_name = app_name
39
+ self._ray_java_master = None
40
+ self._started_up = False
41
+ self._configs = configs
42
+ self._logging_level = logging_level
43
+ self._objects = {}
44
+
45
+ def start_up(self, resources=None):
46
+ if self._started_up:
47
+ self._logger.warning("The RayDPSparkMaster has already started; ignoring this duplicate start_up call")
48
+ return
49
+ ray_app_master_class = ray.cross_language.java_actor_class(
50
+ "org.apache.spark.deploy.raydp.RayAppMaster",
51
+ # {
52
+ # "runtime_env": {
53
+ # "java_executable": f"java -Dray.logging.level={self._logging_level} -cp {':'.join([p + '/*' for p in code_search_path()])}",
54
+ # },
55
+ # },
56
+ )
57
+ self._logger.info(f"Start the RayClusterMaster with configs: {self._configs}")
58
+ self._ray_java_master = ray_app_master_class.options(resources=resources).remote()
59
+ self._logger.info("The RayClusterMaster has started")
60
+ self._started_up = True
61
+
62
+ def get_app_id(self) -> str:
63
+ assert self._started_up
64
+ return f"raydp-{self._ray_java_master._actor_id.hex()}"
65
+
66
+ def get_master_url(self) -> str:
67
+ assert self._started_up
68
+ url = ray.get(self._ray_java_master.getMasterUrl.remote())
69
+ self._logger.info(f"The master url is {url}")
70
+ return url
71
+
72
+ def create_pyworker(self, worker_id: str, node_id: str, env_vars: str) -> str:
73
+ self._logger.info(
74
+ f"Create a PyWorker with node_id: {node_id}, env_vars: {env_vars}, runtime_env: {ray.get_runtime_context().namespace}"
75
+ )
76
+ envs = json.loads(env_vars)
77
+ pg_name = f"raydp-executor-{self._app_name}-{worker_id}-pg"
78
+ # get placement group by name
79
+ pg = ray.util.get_placement_group(pg_name)
80
+ worker = PyWorker.options(
81
+ runtime_env={
82
+ "env_vars": envs,
83
+ },
84
+ max_concurrency=2,
85
+ scheduling_strategy=PlacementGroupSchedulingStrategy(pg),
86
+ ).remote()
87
+ ray.get(worker.heartbeat.remote())
88
+ return worker
89
+
90
+ def add_objects(self, timestamp, objects):
91
+ self._objects[timestamp] = objects
92
+
93
+ def get_object(self, timestamp, idx):
94
+ return self._objects[timestamp][idx]
95
+
96
+ def get_ray_address(self):
97
+ return ray.worker.global_worker.node.address
98
+
99
+ def stop(self, cleanup_data):
100
+ self._started_up = False
101
+ if cleanup_data:
102
+ ray.actor.exit_actor()
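A small sketch of how the object-holder side of this actor is used, mirroring what dataset.py does in _convert_by_udf: the master is looked up by name, a batch of ObjectRefs is pinned under an id, and individual refs are fetched back by index. The app name and payload are illustrative, and the actor is assumed to already be running in the same Ray namespace.

    import uuid

    import ray

    from raydp.spark.ray_cluster_master import RAYDP_SPARK_MASTER_SUFFIX

    # The master actor is registered as <app_name> + RAYDP_SPARK_MASTER_SUFFIX.
    master = ray.get_actor("demo" + RAYDP_SPARK_MASTER_SUFFIX)

    # Pin a batch of refs on the master so they outlive the caller, then read one back.
    refs = [ray.put(b"block-0"), ray.put(b"block-1")]
    batch_id = uuid.uuid4()
    ray.get(master.add_objects.remote(batch_id, refs))

    # get_object returns the stored ObjectRef itself, so a second ray.get resolves the value.
    inner_ref = ray.get(master.get_object.remote(batch_id, 0))
    print(ray.get(inner_ref))  # b"block-0"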
raydp/spark/ray_pyworker.py ADDED
@@ -0,0 +1,135 @@
1
+ #
2
+ # Licensed to the Apache Software Foundation (ASF) under one or more
3
+ # contributor license agreements. See the NOTICE file distributed with
4
+ # this work for additional information regarding copyright ownership.
5
+ # The ASF licenses this file to You under the Apache License, Version 2.0
6
+ # (the "License"); you may not use this file except in compliance with
7
+ # the License. You may obtain a copy of the License at
8
+ #
9
+ # http://www.apache.org/licenses/LICENSE-2.0
10
+ #
11
+ # Unless required by applicable law or agreed to in writing, software
12
+ # distributed under the License is distributed on an "AS IS" BASIS,
13
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14
+ # See the License for the specific language governing permissions and
15
+ # limitations under the License.
16
+ #
17
+
18
+ import logging
19
+ import numbers
20
+ import os
21
+ import select
22
+ import socket
23
+ from errno import EINTR
24
+ from socket import AF_INET, SOCK_STREAM, SOMAXCONN
25
+
26
+ import ray
27
+
28
+ logger = logging.getLogger(__file__)
29
+
30
+
31
+ def compute_real_exit_code(exit_code):
32
+ # SystemExit's code can be integer or string, but os._exit only accepts integers
33
+ if isinstance(exit_code, numbers.Integral):
34
+ return exit_code
35
+ else:
36
+ return 1
37
+
38
+
39
+ @ray.remote
40
+ class PyWorker:
41
+ def __init__(self):
42
+ logger.info("PyWorker is created")
43
+
44
+ self.listen_sock = socket.socket(AF_INET, SOCK_STREAM)
45
+ self.listen_sock.bind(("127.0.0.1", 0))
46
+ self.listen_sock.listen(max(1024, SOMAXCONN))
47
+ listen_host, self.listen_port = self.listen_sock.getsockname()
48
+
49
+ def heartbeat(self):
50
+ return f"{os.getpid()} is alive"
51
+
52
+ def get_port(self) -> int:
53
+ return self.listen_port
54
+
55
+ def start(self):
56
+ import time
57
+ # Most of the code is copied from PySpark's daemon.py
58
+
59
+ from pyspark.serializers import (
60
+ UTF8Deserializer,
61
+ write_int,
62
+ write_with_length,
63
+ )
64
+ from pyspark.worker import main as worker_main
65
+
66
+ logger.info(f"Starting PyWorker with pid: {os.getpid()}, listen_port={self.listen_port}")
67
+ while True:
68
+ try:
69
+ logger.info("Waiting for connection")
70
+ ready_fds = select.select([0, self.listen_sock], [], [], 1)[0]
71
+ except select.error as ex:
72
+ logger.error(f"select error: {ex}")
73
+ if ex.errno == EINTR:
74
+ continue
75
+ else:
76
+ raise
77
+
78
+ logger.info(f"ready_fds: {ready_fds}")
79
+ if self.listen_sock in ready_fds:
80
+ try:
81
+ sock, _ = self.listen_sock.accept()
82
+ except OSError as e:
83
+ logger.error(f"Failed to accept connection: {e}")
84
+ if e.errno == EINTR:
85
+ continue
86
+ raise
87
+
88
+ try:
89
+ logger.info("Connection accepted")
90
+ # Acknowledge that the fork was successful
91
+ outfile = sock.makefile(mode="wb")
92
+ write_int(os.getpid(), outfile)
93
+ outfile.flush()
94
+ outfile.close()
95
+ while True:
96
+ buffer_size = int(os.environ.get("SPARK_BUFFER_SIZE", 65536))
97
+ infile = os.fdopen(os.dup(sock.fileno()), "rb", buffer_size)
98
+ outfile = os.fdopen(os.dup(sock.fileno()), "wb", buffer_size)
99
+ client_secret = UTF8Deserializer().loads(infile)
100
+ if os.environ["PYTHON_WORKER_FACTORY_SECRET"] == client_secret:
101
+ write_with_length("ok".encode("utf-8"), outfile)
102
+ outfile.flush()
103
+ else:
104
+ write_with_length("err".encode("utf-8"), outfile)
105
+ outfile.flush()
106
+ sock.close()
107
+ return 1
108
+
109
+ try:
110
+ code = worker_main(infile, outfile)
111
+ logger.info(f"normal exit code: {code}")
112
+ except SystemExit as exc:
113
+ code = compute_real_exit_code(exc.code)
114
+ finally:
115
+ try:
116
+ outfile.flush()
117
+ except Exception:
118
+ pass
119
+ # wait for closing
120
+ logger.info(f"exit code: {code}")
121
+ # logger.info("Waiting for closing")
122
+ # try:
123
+ # while sock.recv(1024):
124
+ # pass
125
+ # except Exception:
126
+ # pass
127
+ logger.info("Closing. Waiting for next loop")
128
+ break
129
+ except BaseException as e:
130
+ logger.error(f"PyWorker failed with exception: {e}")
131
+ return 1
132
+ # else:
133
+ # return 0
134
+ else:
135
+ time.sleep(0.5)
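To make the accept loop above concrete, here is a rough client-side sketch of the handshake it expects, mirroring what Spark's PythonWorkerFactory does: read the worker pid, send the length-prefixed shared secret, and check for the "ok" acknowledgement. The helper name, and driving the handshake by hand at all, are illustrative; after the acknowledgement the worker hands the socket to pyspark.worker.main, which expects the full Spark worker protocol.

    import socket
    import struct

    def handshake(port: int, secret: str) -> int:
        """Illustrative peer for PyWorker.start(); returns the worker's pid."""
        sock = socket.create_connection(("127.0.0.1", port))
        infile = sock.makefile("rb")
        outfile = sock.makefile("wb")

        # The worker first sends its pid via pyspark's write_int (big-endian int32).
        (worker_pid,) = struct.unpack(">i", infile.read(4))

        # The worker then reads the secret with UTF8Deserializer (length-prefixed
        # UTF-8) and compares it against PYTHON_WORKER_FACTORY_SECRET.
        payload = secret.encode("utf-8")
        outfile.write(struct.pack(">i", len(payload)) + payload)
        outfile.flush()

        # Expect a length-prefixed "ok" (or "err") written with write_with_length.
        (n,) = struct.unpack(">i", infile.read(4))
        reply = infile.read(n).decode("utf-8")
        if reply != "ok":
            raise RuntimeError(f"worker rejected the secret: {reply}")
        return worker_pid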