spark-connect 0.2.0
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- checksums.yaml +7 -0
- data/CHANGELOG.md +82 -0
- data/LICENSE +202 -0
- data/NOTICE +16 -0
- data/README.md +166 -0
- data/lib/spark-connect.rb +5 -0
- data/lib/spark_connect/arrow.rb +115 -0
- data/lib/spark_connect/catalog.rb +190 -0
- data/lib/spark_connect/channel_builder.rb +134 -0
- data/lib/spark_connect/client.rb +264 -0
- data/lib/spark_connect/column.rb +379 -0
- data/lib/spark_connect/conf.rb +79 -0
- data/lib/spark_connect/data_frame.rb +828 -0
- data/lib/spark_connect/errors.rb +58 -0
- data/lib/spark_connect/functions.rb +903 -0
- data/lib/spark_connect/grouped_data.rb +101 -0
- data/lib/spark_connect/na_functions.rb +98 -0
- data/lib/spark_connect/observation.rb +61 -0
- data/lib/spark_connect/pipelines.rb +221 -0
- data/lib/spark_connect/plan.rb +39 -0
- data/lib/spark_connect/proto/spark/connect/base_pb.rb +118 -0
- data/lib/spark_connect/proto/spark/connect/base_services_pb.rb +82 -0
- data/lib/spark_connect/proto/spark/connect/catalog_pb.rb +46 -0
- data/lib/spark_connect/proto/spark/connect/commands_pb.rb +67 -0
- data/lib/spark_connect/proto/spark/connect/common_pb.rb +32 -0
- data/lib/spark_connect/proto/spark/connect/expressions_pb.rb +63 -0
- data/lib/spark_connect/proto/spark/connect/ml_common_pb.rb +22 -0
- data/lib/spark_connect/proto/spark/connect/ml_pb.rb +32 -0
- data/lib/spark_connect/proto/spark/connect/pipelines_pb.rb +45 -0
- data/lib/spark_connect/proto/spark/connect/relations_pb.rb +102 -0
- data/lib/spark_connect/proto/spark/connect/types_pb.rb +46 -0
- data/lib/spark_connect/proto.rb +32 -0
- data/lib/spark_connect/reader.rb +98 -0
- data/lib/spark_connect/row.rb +105 -0
- data/lib/spark_connect/session.rb +317 -0
- data/lib/spark_connect/stat_functions.rb +109 -0
- data/lib/spark_connect/streaming.rb +351 -0
- data/lib/spark_connect/types.rb +490 -0
- data/lib/spark_connect/version.rb +11 -0
- data/lib/spark_connect/window.rb +119 -0
- data/lib/spark_connect/writer.rb +208 -0
- data/lib/spark_connect.rb +58 -0
- data/proto/spark/connect/base.proto +1275 -0
- data/proto/spark/connect/catalog.proto +243 -0
- data/proto/spark/connect/commands.proto +553 -0
- data/proto/spark/connect/common.proto +179 -0
- data/proto/spark/connect/expressions.proto +557 -0
- data/proto/spark/connect/ml.proto +147 -0
- data/proto/spark/connect/ml_common.proto +64 -0
- data/proto/spark/connect/pipelines.proto +307 -0
- data/proto/spark/connect/relations.proto +1252 -0
- data/proto/spark/connect/types.proto +227 -0
- metadata +149 -0
|
@@ -0,0 +1,190 @@
|
|
|
1
|
+
# frozen_string_literal: true
|
|
2
|
+
|
|
3
|
+
module SparkConnect
  # The catalog interface for inspecting and managing databases, tables,
  # functions, and the query cache. Returned by {SparkSession#catalog}. Mirrors
  # PySpark's `Catalog`.
  #
  # Every method builds a `Catalog` relation and executes it immediately by
  # collecting a {DataFrame} over it. Methods that return rows
  # ({#list_databases}, {#list_tables}, ...) return arrays of {Row}; predicate
  # methods return booleans.
  #
  # All name-like arguments (catalog, database, table, function and view names,
  # plus the `path`/`source`/`description` of the table-creation methods) are
  # converted with `#to_s` before being placed in the request, so Symbols and
  # Pathnames are accepted uniformly.
  #
  # @example
  #   spark.catalog.list_tables.each { |t| puts t["name"] }
  #   spark.catalog.table_exists("my_table") #=> true
  class Catalog
    Proto = SparkConnect::Proto
    C = Proto::Catalog

    # @param session [SparkSession]
    def initialize(session)
      @session = session
    end

    # @return [String, nil] the current default catalog (nil if no row came back).
    def current_catalog
      scalar(C.new(current_catalog: Proto::CurrentCatalog.new))
    end

    # Set the current catalog.
    #
    # @param name [#to_s]
    # @return [void]
    def set_current_catalog(name)
      run(C.new(set_current_catalog: Proto::SetCurrentCatalog.new(catalog_name: name.to_s)))
    end

    # @return [Array<Row>] all catalogs.
    def list_catalogs
      rows(C.new(list_catalogs: Proto::ListCatalogs.new))
    end

    # @return [String, nil] the current default database (nil if no row came back).
    def current_database
      scalar(C.new(current_database: Proto::CurrentDatabase.new))
    end

    # Set the current database.
    #
    # @param name [#to_s]
    # @return [void]
    def set_current_database(name)
      run(C.new(set_current_database: Proto::SetCurrentDatabase.new(db_name: name.to_s)))
    end

    # @return [Array<Row>] all databases.
    def list_databases
      rows(C.new(list_databases: Proto::ListDatabases.new))
    end

    # @param db_name [#to_s, nil] restrict to a database; when nil the field is
    #   left unset in the request.
    # @return [Array<Row>] tables (and views).
    def list_tables(db_name = nil)
      lt = Proto::ListTables.new
      lt.db_name = db_name.to_s if db_name
      rows(C.new(list_tables: lt))
    end

    # @param db_name [#to_s, nil] restrict to a database; when nil the field is
    #   left unset in the request.
    # @return [Array<Row>] functions registered in the catalog.
    def list_functions(db_name = nil)
      lf = Proto::ListFunctions.new
      lf.db_name = db_name.to_s if db_name
      rows(C.new(list_functions: lf))
    end

    # @param table_name [#to_s]
    # @param db_name [#to_s, nil] database containing the table.
    # @return [Array<Row>] the columns of a table.
    def list_columns(table_name, db_name = nil)
      lc = Proto::ListColumns.new(table_name: table_name.to_s)
      lc.db_name = db_name.to_s if db_name
      rows(C.new(list_columns: lc))
    end

    # @param table_name [#to_s]
    # @param db_name [#to_s, nil]
    # @return [Boolean] whether a table or view exists.
    def table_exists(table_name, db_name = nil)
      te = Proto::TableExists.new(table_name: table_name.to_s)
      te.db_name = db_name.to_s if db_name
      # `== true` collapses a missing/nil scalar into a strict boolean.
      scalar(C.new(table_exists: te)) == true
    end

    # @param db_name [#to_s]
    # @return [Boolean] whether a database exists.
    def database_exists(db_name)
      scalar(C.new(database_exists: Proto::DatabaseExists.new(db_name: db_name.to_s))) == true
    end

    # @param function_name [#to_s]
    # @param db_name [#to_s, nil]
    # @return [Boolean] whether a function exists.
    def function_exists(function_name, db_name = nil)
      fe = Proto::FunctionExists.new(function_name: function_name.to_s)
      fe.db_name = db_name.to_s if db_name
      scalar(C.new(function_exists: fe)) == true
    end

    # Drop a session-local temporary view.
    #
    # @param view_name [#to_s]
    # @return [Boolean] the server's answer (false when no value came back).
    def drop_temp_view(view_name)
      scalar(C.new(drop_temp_view: Proto::DropTempView.new(view_name: view_name.to_s))) == true
    end

    # Drop a global temporary view.
    #
    # @param view_name [#to_s]
    # @return [Boolean] the server's answer (false when no value came back).
    def drop_global_temp_view(view_name)
      scalar(C.new(drop_global_temp_view: Proto::DropGlobalTempView.new(view_name: view_name.to_s))) == true
    end

    # @param table_name [#to_s]
    # @return [Boolean] whether the table is cached.
    def cached?(table_name)
      scalar(C.new(is_cached: Proto::IsCached.new(table_name: table_name.to_s))) == true
    end

    # Cache a table in memory.
    #
    # @param table_name [#to_s]
    # @return [void]
    def cache_table(table_name)
      run(C.new(cache_table: Proto::CacheTable.new(table_name: table_name.to_s)))
    end

    # Remove a table from the cache.
    #
    # @param table_name [#to_s]
    # @return [void]
    def uncache_table(table_name)
      run(C.new(uncache_table: Proto::UncacheTable.new(table_name: table_name.to_s)))
    end

    # Clear all cached tables.
    #
    # @return [void]
    def clear_cache
      run(C.new(clear_cache: Proto::ClearCache.new))
    end

    # Invalidate and refresh cached metadata for a table.
    #
    # @param table_name [#to_s]
    # @return [void]
    def refresh_table(table_name)
      run(C.new(refresh_table: Proto::RefreshTable.new(table_name: table_name.to_s)))
    end

    # Recover all partitions of a table.
    #
    # @param table_name [#to_s]
    # @return [void]
    def recover_partitions(table_name)
      run(C.new(recover_partitions: Proto::RecoverPartitions.new(table_name: table_name.to_s)))
    end

    # Create a managed table and return a {DataFrame} over it.
    #
    # @param table_name [#to_s]
    # @param path [#to_s, nil]
    # @param source [#to_s, nil] the data source/format.
    # @param schema [Types::StructType, nil]
    # @param description [#to_s, nil]
    # @param options [Hash] keys and values are converted with `#to_s`.
    # @return [DataFrame]
    def create_table(table_name, path: nil, source: nil, schema: nil, description: nil, options: {})
      ct = Proto::CreateTable.new(table_name: table_name.to_s, options: stringify(options))
      ct.path = path.to_s if path
      ct.source = source.to_s if source
      ct.description = description.to_s if description
      ct.schema = schema.to_proto if schema
      catalog_df(C.new(create_table: ct)).collect # eagerly create the table
      @session.table(table_name.to_s)
    end

    # Create a table backed by data at `path` (an external/unmanaged table).
    #
    # @param table_name [#to_s]
    # @param path [#to_s, nil]
    # @param source [#to_s, nil] the data source/format.
    # @param schema [Types::StructType, nil]
    # @param options [Hash] keys and values are converted with `#to_s`.
    # @return [DataFrame]
    def create_external_table(table_name, path: nil, source: nil, schema: nil, options: {})
      ct = Proto::CreateExternalTable.new(table_name: table_name.to_s, options: stringify(options))
      ct.path = path.to_s if path
      ct.source = source.to_s if source
      ct.schema = schema.to_proto if schema
      catalog_df(C.new(create_external_table: ct)).collect # eagerly create the table
      @session.table(table_name.to_s)
    end

    private

    # Convert an options hash to the `map<string, string>` shape protobuf wants.
    def stringify(options)
      options.to_h { |k, v| [k.to_s, v.to_s] }
    end

    # Wrap a `Catalog` message in a relation and a DataFrame over it.
    def catalog_df(catalog)
      DataFrame.new(@session, PlanBuilder.relation(@session, catalog: catalog))
    end

    # Execute the catalog relation and return every row.
    def rows(catalog)
      catalog_df(catalog).collect
    end

    # Execute the catalog relation and return the first column of the first
    # row, or nil when no row came back.
    def scalar(catalog)
      row = catalog_df(catalog).collect.first
      row&.[](0)
    end

    # Execute the catalog relation for its side effect only.
    def run(catalog)
      catalog_df(catalog).collect
      nil
    end
  end
end
|
|
@@ -0,0 +1,134 @@
|
|
|
1
|
+
# frozen_string_literal: true
|
|
2
|
+
|
|
3
|
+
require "grpc"
|
|
4
|
+
require "uri"
|
|
5
|
+
|
|
6
|
+
module SparkConnect
  # Parses a Spark Connect connection string (`sc://...`) and builds the gRPC
  # stub, credentials, and per-request metadata.
  #
  # The connection string grammar mirrors the official Spark Connect clients:
  #
  #   sc://host[:port][/;param=value;param=value...]
  #
  # The port defaults to {DEFAULT_PORT} and, when given, must be a decimal
  # number in 1..65535. IPv6 literals must be bracketed, e.g. `sc://[::1]:15002`.
  #
  # Parameter keys and values are percent-decoded. A literal `+` is kept as-is
  # (it is NOT turned into a space, unlike HTML form decoding), so base64-style
  # tokens can be passed without extra escaping.
  #
  # Recognised parameters:
  # * `token` - bearer token; implies TLS and adds an `authorization` header
  # * `user_id` - the Spark user id
  # * `user_agent` - client user agent (default `spark-connect-ruby/<version>`)
  # * `use_ssl` - `true`/`1`/`yes` (case-insensitive) turns TLS on; any other
  #   value leaves it off unless a `token` is present, which always enables TLS
  # * `session_id` - client session id (UUID) to use instead of a generated one
  #
  # Any parameter whose name begins with `x-` is forwarded as gRPC request
  # metadata.
  #
  # @example
  #   cb = SparkConnect::ChannelBuilder.new("sc://localhost:15002")
  #   cb.host #=> "localhost"
  #   cb.port #=> 15002
  class ChannelBuilder
    DEFAULT_PORT = 15_002
    PARAM_PREFIX = "x-"

    # Shorthand for the generated protobuf/gRPC namespace. Declared with the
    # other constants because `private` has no effect on constants.
    Proto = SparkConnect::Proto

    # @return [String]
    attr_reader :host
    # @return [Integer]
    attr_reader :port
    # @return [Hash{String=>String}] decoded connection parameters.
    attr_reader :params
    # @return [String, nil]
    attr_reader :token
    # @return [String, nil]
    attr_reader :user_id
    # @return [String, nil]
    attr_reader :session_id

    # @param url [String] an `sc://` connection string.
    # @raise [ConnectionError] if the string is nil, lacks the `sc://` prefix,
    #   has an empty host, an invalid port, a parameter without `=`, or an
    #   invalid percent-escape.
    def initialize(url)
      raise ConnectionError, "Connection string must not be nil" if url.nil?
      raise ConnectionError, "Connection string must start with 'sc://', got: #{url.inspect}" unless url.start_with?("sc://")

      body = url.delete_prefix("sc://")
      endpoint, _, param_str = body.partition("/")
      @params = parse_params(param_str)
      parse_endpoint(endpoint)

      @token = @params["token"]
      @user_id = @params["user_id"]
      @session_id = @params["session_id"]
      # A token always turns TLS on so the bearer header is not sent in plaintext.
      @use_ssl = parse_bool(@params["use_ssl"]) || !@token.nil?
    end

    # @return [Boolean] whether the channel uses TLS.
    def ssl?
      @use_ssl
    end

    # @return [String] the gRPC target, e.g. `"localhost:15002"`.
    def target
      "#{host}:#{port}"
    end

    # @return [String] the effective user agent.
    def user_agent
      @params["user_agent"] || "spark-connect-ruby/#{SparkConnect::VERSION}"
    end

    # Per-request gRPC metadata derived from the connection string (bearer token
    # plus any parameters whose name starts with the lowercase `x-` prefix).
    #
    # @return [Hash{String=>String}]
    def metadata
      md = {}
      md["authorization"] = "Bearer #{@token}" if @token
      @params.each { |k, v| md[k] = v if k.start_with?(PARAM_PREFIX) }
      md
    end

    # Build the gRPC stub for the {Spark::Connect::SparkConnectService}.
    #
    # @param channel_args [Hash] extra gRPC channel arguments; they override the
    #   default `grpc.primary_user_agent`.
    # @return [Spark::Connect::SparkConnectService::Stub]
    def build_stub(channel_args: {})
      creds = @use_ssl ? GRPC::Core::ChannelCredentials.new : :this_channel_is_insecure
      args = { "grpc.primary_user_agent" => user_agent }.merge(channel_args)
      Proto::SparkConnectService::Stub.new(target, creds, channel_args: args)
    end

    private

    # Set @host and @port from the "host[:port]" part of the connection string.
    def parse_endpoint(endpoint)
      raise ConnectionError, "Missing host in connection string" if endpoint.nil? || endpoint.empty?

      host, port = split_host_port(endpoint)
      raise ConnectionError, "Missing host in connection string" if host.empty?

      @host = host
      @port = port.nil? ? DEFAULT_PORT : parse_port(port)
    end

    # Split "host[:port]" into [host, port_string_or_nil]. Bracketed IPv6
    # literals ("[::1]", "[::1]:15002") contain colons of their own, so for them
    # the port separator is only looked for after the closing bracket.
    def split_host_port(endpoint)
      if endpoint.start_with?("[")
        close = endpoint.index("]")
        raise ConnectionError, "Invalid IPv6 host in connection string: #{endpoint.inspect}" if close.nil?

        host = endpoint[0..close]
        rest = endpoint[(close + 1)..]
        return [host, nil] if rest.empty?
        return [host, rest[1..]] if rest.start_with?(":")

        raise ConnectionError, "Invalid port in connection string: #{rest.inspect}"
      end

      host, sep, port = endpoint.rpartition(":")
      sep.empty? ? [endpoint, nil] : [host, port]
    end

    # @return [Integer] the port; it must be plain decimal digits in 1..65535
    #   (Kernel#Integer would also accept hex, underscores and whitespace).
    def parse_port(port)
      number = port.match?(/\A\d+\z/) ? port.to_i : nil
      return number if number&.between?(1, 65_535)

      raise ConnectionError, "Invalid port in connection string: #{port.inspect}"
    end

    # Parse ";k=v;k=v" into a Hash. Only the first "=" splits a pair, so values
    # may themselves contain "=" (e.g. base64 padding). Later keys win.
    def parse_params(param_str)
      param_str.split(";").each_with_object({}) do |kv, params|
        next if kv.empty?

        key, sep, value = kv.partition("=")
        raise ConnectionError, "Malformed parameter (expected key=value): #{kv.inspect}" if sep.empty?

        params[percent_decode(key)] = percent_decode(value)
      end
    end

    # Percent-decode one connection-string component while keeping `+`
    # literally (the form decoder would turn it into a space and corrupt
    # tokens). Invalid escapes are reported as a ConnectionError.
    def percent_decode(str)
      URI.decode_www_form_component(str.gsub("+", "%2B"))
    rescue ArgumentError
      raise ConnectionError, "Invalid percent-encoding in connection string: #{str.inspect}"
    end

    # @return [Boolean, nil] nil when absent, otherwise whether the value is
    #   one of "true", "1", "yes" (case-insensitive).
    def parse_bool(value)
      return nil if value.nil?

      %w[true 1 yes].include?(value.to_s.downcase)
    end
  end
end
|
|
@@ -0,0 +1,264 @@
|
|
|
1
|
+
# frozen_string_literal: true
|
|
2
|
+
|
|
3
|
+
require "securerandom"
|
|
4
|
+
|
|
5
|
+
module SparkConnect
  # The low-level Spark Connect client. Wraps the gRPC stub and exposes the four
  # core RPC families used by the high-level API: {#execute_plan},
  # {#execute_command}, {#analyze}, and {#config}. Higher layers
  # ({SparkSession}, {DataFrame}) never touch the stub directly.
  #
  # Transient transport failures (status codes in {RETRYABLE_CODES}, e.g.
  # `GRPC::Unavailable`) are retried with exponential backoff and jitter. For
  # streamed executions a retry only happens before any response has been
  # observed; once part of the result has arrived a failure is raised instead,
  # because replaying the stream would duplicate the data already collected.
  class SparkConnectClient
    Proto = SparkConnect::Proto

    # Accumulated result of an `ExecutePlan` stream.
    #
    # @!attribute [r] arrow_batches
    #   @return [Array<String>] each element is one Arrow IPC stream chunk.
    # @!attribute [r] schema
    #   @return [Spark::Connect::DataType, nil] result schema, if returned.
    # @!attribute [r] metrics
    #   @return [Spark::Connect::ExecutePlanResponse::Metrics, nil]
    # @!attribute [r] observed_metrics
    #   @return [Array] observed (named) metrics.
    # @!attribute [r] sql_command_result
    #   @return [Spark::Connect::Relation, nil] relation produced by a SQL command.
    # @!attribute [r] row_count
    #   @return [Integer] sum of the `row_count` of every Arrow batch received.
    # @!attribute [r] pipeline_events
    #   @return [Array] pipeline events, in arrival order.
    ExecuteResult = Struct.new(
      :arrow_batches, :schema, :metrics, :observed_metrics, :sql_command_result, :row_count,
      :write_stream_result, :streaming_query_result, :streaming_manager_result, :checkpoint_relation,
      :pipeline_command_result, :pipeline_events
    )

    # @return [String] the client-side session id (UUID v4).
    attr_reader :session_id
    # @return [String] the user agent / client type.
    attr_reader :client_type
    # @return [ChannelBuilder]
    attr_reader :channel_builder

    # @param channel_builder [ChannelBuilder]
    # @param session_id [String, nil] reuse a session id; otherwise the channel
    #   builder's `session_id`, otherwise a fresh UUID.
    # @param max_retries [Integer] retries per call after the first attempt.
    # @param retry_base_delay [Float] base backoff in seconds.
    # @param max_retry_delay [Float] cap on the exponential part of the backoff.
    def initialize(channel_builder, session_id: nil, max_retries: 10, retry_base_delay: 0.05, max_retry_delay: 10.0)
      @channel_builder = channel_builder
      @stub = channel_builder.build_stub
      @metadata = channel_builder.metadata
      @session_id = session_id || channel_builder.session_id || SecureRandom.uuid
      @client_type = channel_builder.user_agent
      @user_context = Proto::UserContext.new(user_id: channel_builder.user_id || "")
      @max_retries = max_retries
      @retry_base_delay = retry_base_delay
      @max_retry_delay = max_retry_delay
      @server_side_session_id = nil
      @tags = []
    end

    # Operation tags attached to every subsequent execution (used with
    # {#interrupt} `type: :tag`).
    # @return [Array<String>]
    attr_reader :tags

    # Add an operation tag. Tags must be non-empty and contain no commas.
    # @return [void]
    # @raise [IllegalArgumentError] for an empty tag or one containing ','.
    def add_tag(tag)
      tag = tag.to_s
      raise IllegalArgumentError, "Tag must not be empty" if tag.empty?
      raise IllegalArgumentError, "Tag must not contain ','" if tag.include?(",")

      @tags << tag unless @tags.include?(tag)
      nil
    end

    # Remove an operation tag. @return [void]
    def remove_tag(tag)
      @tags.delete(tag.to_s)
      nil
    end

    # Remove all operation tags. @return [void]
    def clear_tags
      @tags.clear
      nil
    end

    # Execute a relation plan and accumulate the streamed response.
    #
    # @param relation [Spark::Connect::Relation]
    # @return [ExecuteResult]
    def execute_plan(relation)
      execute(PlanBuilder.root_plan(relation))
    end

    # Execute a command plan (side-effecting, e.g. write/SQL DML).
    #
    # @param command [Spark::Connect::Command]
    # @return [ExecuteResult]
    def execute_command(command)
      execute(PlanBuilder.command_plan(command))
    end

    # Run an `AnalyzePlan` request.
    #
    # @param analyze_kw [Hash] exactly one `analyze` oneof keyword, e.g.
    #   `schema:`, `explain:`, `tree_string:`, `is_local:`, `spark_version:`.
    # @return [Spark::Connect::AnalyzePlanResponse]
    def analyze(**analyze_kw)
      req = Proto::AnalyzePlanRequest.new(
        session_id: @session_id,
        user_context: @user_context,
        client_type: @client_type,
        **analyze_kw
      )
      with_retries { @stub.analyze_plan(req, metadata: @metadata) }
    end

    # Run a `Config` request.
    #
    # @param operation [Spark::Connect::ConfigRequest::Operation]
    # @return [Spark::Connect::ConfigResponse]
    def config(operation)
      req = Proto::ConfigRequest.new(
        session_id: @session_id,
        user_context: @user_context,
        client_type: @client_type,
        operation: operation
      )
      with_retries { @stub.config(req, metadata: @metadata) }
    end

    # Interrupt types accepted by {#interrupt}.
    INTERRUPT_TYPES = %i[all tag operation_id].freeze

    # Interrupt running operations.
    #
    # @param type [Symbol, String] `:all`, `:tag`, or `:operation_id`.
    # @param value [String, nil] the tag (for `:tag`) or operation id (for
    #   `:operation_id`); ignored for `:all`.
    # @return [Spark::Connect::InterruptResponse]
    # @raise [IllegalArgumentError] for an unknown type, or when `:tag` /
    #   `:operation_id` is given without a value.
    def interrupt(type: :all, value: nil)
      # Normalise so that "tag" and :tag behave the same when choosing the field.
      type = type.to_s.to_sym
      unless INTERRUPT_TYPES.include?(type)
        raise IllegalArgumentError, "Unknown interrupt type: #{type.inspect}; expected one of #{INTERRUPT_TYPES.inspect}"
      end
      raise IllegalArgumentError, "Interrupt type #{type.inspect} requires a value" if type != :all && value.to_s.empty?

      kw = { interrupt_type: :"INTERRUPT_TYPE_#{type.upcase}" }
      kw[:operation_tag] = value.to_s if type == :tag
      kw[:operation_id] = value.to_s if type == :operation_id
      req = Proto::InterruptRequest.new(
        session_id: @session_id, user_context: @user_context, client_type: @client_type, **kw
      )
      with_retries { @stub.interrupt(req, metadata: @metadata) }
    end

    # Release this client's server-side session.
    # @return [void]
    def release_session
      req = Proto::ReleaseSessionRequest.new(
        session_id: @session_id, user_context: @user_context, client_type: @client_type
      )
      # Best-effort and non-retrying: this runs on teardown, so a dead server
      # must not block the caller with the retry/backoff loop.
      @stub.release_session(req, metadata: @metadata)
      nil
    rescue StandardError
      nil
    end

    private

    # Send an ExecutePlan request and fold every streamed response into one
    # ExecuteResult.
    def execute(plan)
      req = Proto::ExecutePlanRequest.new(
        session_id: @session_id,
        user_context: @user_context,
        operation_id: SecureRandom.uuid,
        plan: plan,
        client_type: @client_type,
        tags: @tags
      )

      result = ExecuteResult.new([], nil, nil, [], nil, 0)
      result.pipeline_events = []
      # Retrying re-sends the request and restarts the stream from its first
      # response. Once anything has been accumulated into `result`, a replay
      # would append the same batches/events a second time, so a failure after
      # that point is raised instead of retried.
      observed = false
      with_retries(retry_if: -> { !observed }) do
        @stub.execute_plan(req, metadata: @metadata).each do |resp|
          observed = true
          @server_side_session_id = resp.server_side_session_id unless resp.server_side_session_id.empty?
          accumulate(result, resp)
        end
      end
      result
    end

    # Merge one ExecutePlanResponse into the running result.
    def accumulate(result, resp)
      result.schema = resp.schema if resp.schema
      result.metrics = resp.metrics if resp.metrics
      result.observed_metrics += resp.observed_metrics.to_a unless resp.observed_metrics.empty?

      case resp.response_type
      when :arrow_batch
        batch = resp.arrow_batch
        result.arrow_batches << batch.data unless batch.data.empty?
        result.row_count += batch.row_count
      when :sql_command_result
        result.sql_command_result = resp.sql_command_result.relation
      when :write_stream_operation_start_result
        result.write_stream_result = resp.write_stream_operation_start_result
      when :streaming_query_command_result
        result.streaming_query_result = resp.streaming_query_command_result
      when :streaming_query_manager_command_result
        result.streaming_manager_result = resp.streaming_query_manager_command_result
      when :checkpoint_command_result
        result.checkpoint_relation = resp.checkpoint_command_result.relation
      when :pipeline_command_result
        result.pipeline_command_result = resp.pipeline_command_result
      when :pipeline_event_result
        result.pipeline_events << resp.pipeline_event_result.event
      end
    end

    # Run the block, retrying retryable gRPC failures with exponential backoff.
    #
    # @param retry_if [#call, nil] extra guard consulted after a retryable
    #   failure; the block is retried only when it returns truthy.
    # @raise [SparkConnectError] (or a subclass) translated from the gRPC error.
    def with_retries(retry_if: nil)
      attempt = 0
      begin
        yield
      rescue GRPC::BadStatus => e
        if retryable?(e) && attempt < @max_retries && (retry_if.nil? || retry_if.call)
          delay = backoff(attempt)
          attempt += 1
          sleep(delay)
          retry
        end
        raise translate_error(e)
      end
    end

    RETRYABLE_CODES = [
      GRPC::Core::StatusCodes::UNAVAILABLE,
      GRPC::Core::StatusCodes::DEADLINE_EXCEEDED,
      GRPC::Core::StatusCodes::ABORTED,
      GRPC::Core::StatusCodes::RESOURCE_EXHAUSTED,
    ].freeze

    def retryable?(error)
      RETRYABLE_CODES.include?(error.code)
    end

    # Exponential delay capped at @max_retry_delay, plus up to 50% random
    # jitter on top of the (capped) value.
    def backoff(attempt)
      delay = @retry_base_delay * (2**attempt)
      delay = [delay, @max_retry_delay].min
      delay + (rand * delay * 0.5)
    end

    # Map a gRPC error onto the gem's error hierarchy. Classification is a
    # heuristic on the message text; the bracketed Spark error class (e.g.
    # "[TABLE_OR_VIEW_NOT_FOUND]") is extracted when present.
    def translate_error(error)
      message = error.respond_to?(:details) ? error.details : error.message
      code = grpc_code_name(error)
      text = message.to_s
      klass =
        if /\[PARSE_|ParseException/.match?(text)
          ParseError
        elsif /AnalysisException|UNRESOLVED_|TABLE_OR_VIEW_NOT_FOUND|\[.*\] /.match?(text)
          AnalysisError
        else
          SparkConnectError
        end
      error_class = text[/\[([A-Z0-9_.]+)\]/, 1]
      klass.new(message, error_class: error_class, grpc_code: code)
    end

    # @return [String, nil] the StatusCodes constant name for the error's code.
    def grpc_code_name(error)
      GRPC::Core::StatusCodes.constants.find { |c| GRPC::Core::StatusCodes.const_get(c) == error.code }&.to_s
    end
  end
end
|