spark-connect 0.2.0

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (53)
  1. checksums.yaml +7 -0
  2. data/CHANGELOG.md +82 -0
  3. data/LICENSE +202 -0
  4. data/NOTICE +16 -0
  5. data/README.md +166 -0
  6. data/lib/spark-connect.rb +5 -0
  7. data/lib/spark_connect/arrow.rb +115 -0
  8. data/lib/spark_connect/catalog.rb +190 -0
  9. data/lib/spark_connect/channel_builder.rb +134 -0
  10. data/lib/spark_connect/client.rb +264 -0
  11. data/lib/spark_connect/column.rb +379 -0
  12. data/lib/spark_connect/conf.rb +79 -0
  13. data/lib/spark_connect/data_frame.rb +828 -0
  14. data/lib/spark_connect/errors.rb +58 -0
  15. data/lib/spark_connect/functions.rb +903 -0
  16. data/lib/spark_connect/grouped_data.rb +101 -0
  17. data/lib/spark_connect/na_functions.rb +98 -0
  18. data/lib/spark_connect/observation.rb +61 -0
  19. data/lib/spark_connect/pipelines.rb +221 -0
  20. data/lib/spark_connect/plan.rb +39 -0
  21. data/lib/spark_connect/proto/spark/connect/base_pb.rb +118 -0
  22. data/lib/spark_connect/proto/spark/connect/base_services_pb.rb +82 -0
  23. data/lib/spark_connect/proto/spark/connect/catalog_pb.rb +46 -0
  24. data/lib/spark_connect/proto/spark/connect/commands_pb.rb +67 -0
  25. data/lib/spark_connect/proto/spark/connect/common_pb.rb +32 -0
  26. data/lib/spark_connect/proto/spark/connect/expressions_pb.rb +63 -0
  27. data/lib/spark_connect/proto/spark/connect/ml_common_pb.rb +22 -0
  28. data/lib/spark_connect/proto/spark/connect/ml_pb.rb +32 -0
  29. data/lib/spark_connect/proto/spark/connect/pipelines_pb.rb +45 -0
  30. data/lib/spark_connect/proto/spark/connect/relations_pb.rb +102 -0
  31. data/lib/spark_connect/proto/spark/connect/types_pb.rb +46 -0
  32. data/lib/spark_connect/proto.rb +32 -0
  33. data/lib/spark_connect/reader.rb +98 -0
  34. data/lib/spark_connect/row.rb +105 -0
  35. data/lib/spark_connect/session.rb +317 -0
  36. data/lib/spark_connect/stat_functions.rb +109 -0
  37. data/lib/spark_connect/streaming.rb +351 -0
  38. data/lib/spark_connect/types.rb +490 -0
  39. data/lib/spark_connect/version.rb +11 -0
  40. data/lib/spark_connect/window.rb +119 -0
  41. data/lib/spark_connect/writer.rb +208 -0
  42. data/lib/spark_connect.rb +58 -0
  43. data/proto/spark/connect/base.proto +1275 -0
  44. data/proto/spark/connect/catalog.proto +243 -0
  45. data/proto/spark/connect/commands.proto +553 -0
  46. data/proto/spark/connect/common.proto +179 -0
  47. data/proto/spark/connect/expressions.proto +557 -0
  48. data/proto/spark/connect/ml.proto +147 -0
  49. data/proto/spark/connect/ml_common.proto +64 -0
  50. data/proto/spark/connect/pipelines.proto +307 -0
  51. data/proto/spark/connect/relations.proto +1252 -0
  52. data/proto/spark/connect/types.proto +227 -0
  53. metadata +149 -0
@@ -0,0 +1,190 @@
1
+ # frozen_string_literal: true
2
+
3
+ module SparkConnect
4
+ # The catalog interface for inspecting and managing databases, tables,
5
+ # functions, and the query cache. Returned by {SparkSession#catalog}. Mirrors
6
+ # PySpark's `Catalog`.
7
+ #
8
+ # Methods that return rows ({#list_databases}, {#list_tables}, ...) return
9
+ # arrays of {Row}; predicate methods return booleans.
10
+ #
11
+ # @example
12
+ # spark.catalog.list_tables.each { |t| puts t["name"] }
13
+ # spark.catalog.table_exists("my_table") #=> true
14
class Catalog
  Proto = SparkConnect::Proto
  C = Proto::Catalog

  # Bind the catalog to the session used to build and run its queries.
  #
  # @param session [SparkSession]
  def initialize(session)
    @session = session
  end

  # @return [String] the current default catalog.
  def current_catalog
    request = Proto::CurrentCatalog.new
    scalar(C.new(current_catalog: request))
  end

  # Set the current catalog.
  #
  # @param name [#to_s] catalog name.
  # @return [void]
  def set_current_catalog(name)
    request = Proto::SetCurrentCatalog.new(catalog_name: name.to_s)
    run(C.new(set_current_catalog: request))
  end

  # @return [Array<Row>] all catalogs.
  def list_catalogs
    request = Proto::ListCatalogs.new
    rows(C.new(list_catalogs: request))
  end

  # @return [String] the current default database.
  def current_database
    request = Proto::CurrentDatabase.new
    scalar(C.new(current_database: request))
  end

  # Set the current database.
  #
  # @param name [#to_s] database name.
  # @return [void]
  def set_current_database(name)
    request = Proto::SetCurrentDatabase.new(db_name: name.to_s)
    run(C.new(set_current_database: request))
  end

  # @return [Array<Row>] all databases.
  def list_databases
    request = Proto::ListDatabases.new
    rows(C.new(list_databases: request))
  end

  # @param db_name [String, nil] restrict to a database.
  # @return [Array<Row>] tables (and views).
  def list_tables(db_name = nil)
    request = with_database(Proto::ListTables.new, db_name)
    rows(C.new(list_tables: request))
  end

  # @param db_name [String, nil] restrict to a database.
  # @return [Array<Row>] functions registered in the catalog.
  def list_functions(db_name = nil)
    request = with_database(Proto::ListFunctions.new, db_name)
    rows(C.new(list_functions: request))
  end

  # @param table_name [String]
  # @param db_name [String, nil] database holding the table.
  # @return [Array<Row>] the columns of a table.
  def list_columns(table_name, db_name = nil)
    request = with_database(Proto::ListColumns.new(table_name: table_name.to_s), db_name)
    rows(C.new(list_columns: request))
  end

  # @param table_name [String]
  # @param db_name [String, nil] database to look in.
  # @return [Boolean] whether a table or view exists.
  def table_exists(table_name, db_name = nil)
    request = with_database(Proto::TableExists.new(table_name: table_name.to_s), db_name)
    truthy?(C.new(table_exists: request))
  end

  # @param db_name [String]
  # @return [Boolean] whether a database exists.
  def database_exists(db_name)
    request = Proto::DatabaseExists.new(db_name: db_name.to_s)
    truthy?(C.new(database_exists: request))
  end

  # @param function_name [String]
  # @param db_name [String, nil] database to look in.
  # @return [Boolean] whether a function exists.
  def function_exists(function_name, db_name = nil)
    request = with_database(Proto::FunctionExists.new(function_name: function_name.to_s), db_name)
    truthy?(C.new(function_exists: request))
  end

  # Drop a session-local temporary view.
  #
  # @param view_name [String]
  # @return [Boolean] `true` only when the server answered exactly `true`.
  def drop_temp_view(view_name)
    request = Proto::DropTempView.new(view_name: view_name.to_s)
    truthy?(C.new(drop_temp_view: request))
  end

  # Drop a global temporary view.
  #
  # @param view_name [String]
  # @return [Boolean] `true` only when the server answered exactly `true`.
  def drop_global_temp_view(view_name)
    request = Proto::DropGlobalTempView.new(view_name: view_name.to_s)
    truthy?(C.new(drop_global_temp_view: request))
  end

  # @param table_name [String]
  # @return [Boolean] whether the table is cached.
  def cached?(table_name)
    request = Proto::IsCached.new(table_name: table_name.to_s)
    truthy?(C.new(is_cached: request))
  end

  # Cache a table in memory.
  #
  # @param table_name [String]
  # @return [void]
  def cache_table(table_name)
    request = Proto::CacheTable.new(table_name: table_name.to_s)
    run(C.new(cache_table: request))
  end

  # Remove a table from the cache.
  #
  # @param table_name [String]
  # @return [void]
  def uncache_table(table_name)
    request = Proto::UncacheTable.new(table_name: table_name.to_s)
    run(C.new(uncache_table: request))
  end

  # Clear all cached tables.
  #
  # @return [void]
  def clear_cache
    request = Proto::ClearCache.new
    run(C.new(clear_cache: request))
  end

  # Invalidate and refresh cached metadata for a table.
  #
  # @param table_name [String]
  # @return [void]
  def refresh_table(table_name)
    request = Proto::RefreshTable.new(table_name: table_name.to_s)
    run(C.new(refresh_table: request))
  end

  # Recover all partitions of a table.
  #
  # @param table_name [String]
  # @return [void]
  def recover_partitions(table_name)
    request = Proto::RecoverPartitions.new(table_name: table_name.to_s)
    run(C.new(recover_partitions: request))
  end

  # Create a managed table and return a {DataFrame} over it.
  #
  # @param table_name [String]
  # @param path [String, nil]
  # @param source [String, nil] the data source/format.
  # @param schema [Types::StructType, nil]
  # @param description [String, nil]
  # @param options [Hash{String=>String}] keys and values are stringified.
  # @return [DataFrame]
  def create_table(table_name, path: nil, source: nil, schema: nil, description: nil, options: {})
    name = table_name.to_s
    request = Proto::CreateTable.new(table_name: name, options: stringify(options)).tap do |msg|
      msg.path = path if path
      msg.source = source if source
      msg.description = description if description
      msg.schema = schema.to_proto if schema
    end
    # Collecting forces the catalog relation to run, i.e. creates the table now.
    catalog_df(C.new(create_table: request)).collect
    @session.table(name)
  end

  # Create a table backed by data at `path` (an external/unmanaged table).
  #
  # @param table_name [String]
  # @param path [String, nil]
  # @param source [String, nil] the data source/format.
  # @param schema [Types::StructType, nil]
  # @param options [Hash{String=>String}] keys and values are stringified.
  # @return [DataFrame]
  def create_external_table(table_name, path: nil, source: nil, schema: nil, options: {})
    name = table_name.to_s
    request = Proto::CreateExternalTable.new(table_name: name, options: stringify(options)).tap do |msg|
      msg.path = path if path
      msg.source = source if source
      msg.schema = schema.to_proto if schema
    end
    # Collecting forces the catalog relation to run, i.e. creates the table now.
    catalog_df(C.new(create_external_table: request)).collect
    @session.table(name)
  end

  private

  # Set `db_name` on a request message when one was given; returns the message.
  def with_database(message, db_name)
    message.db_name = db_name if db_name
    message
  end

  # Run a scalar catalog query and report whether the answer is exactly `true`.
  def truthy?(catalog)
    scalar(catalog) == true
  end

  # Convert every option key and value to a String.
  def stringify(options)
    options.to_h { |key, value| [key.to_s, value.to_s] }
  end

  # Wrap a catalog message in a relation-backed {DataFrame}.
  def catalog_df(catalog)
    relation = PlanBuilder.relation(@session, catalog: catalog)
    DataFrame.new(@session, relation)
  end

  # Collect every row produced by a catalog query.
  def rows(catalog)
    catalog_df(catalog).collect
  end

  # First column of the first row, or nil when the query returned no rows.
  def scalar(catalog)
    first_row = catalog_df(catalog).collect.first
    first_row.nil? ? nil : first_row[0]
  end

  # Execute a catalog query for its side effect only.
  def run(catalog)
    catalog_df(catalog).collect
    nil
  end
end
190
+ end
@@ -0,0 +1,134 @@
1
+ # frozen_string_literal: true
2
+
3
+ require "grpc"
4
+ require "uri"
5
+
6
+ module SparkConnect
7
+ # Parses a Spark Connect connection string (`sc://...`) and builds the gRPC
8
+ # stub, credentials, and per-request metadata.
9
+ #
10
+ # The connection string grammar mirrors the official Spark Connect clients:
11
+ #
12
+ # sc://host[:port][/;param=value;param=value...]
13
+ #
14
+ # Recognised parameters:
15
+ # * `token` - bearer token; implies TLS and adds an `authorization` header
16
+ # * `user_id` - the Spark user id
17
+ # * `user_agent` - client user agent (default `spark-connect-ruby/<version>`)
18
+ # * `use_ssl` - `true`/`false`; enable TLS (a `token` always enables TLS, even with `use_ssl=false`)
19
+ # * `session_id` - reuse a specific server-side session id (UUID)
20
+ #
21
+ # Any parameter whose name begins with `x-` is forwarded verbatim as gRPC
22
+ # request metadata.
23
+ #
24
+ # @example
25
+ # cb = SparkConnect::ChannelBuilder.new("sc://localhost:15002")
26
+ # cb.host #=> "localhost"
27
+ # cb.port #=> 15002
28
+ class ChannelBuilder
29
+ DEFAULT_PORT = 15_002
30
+ PARAM_PREFIX = "x-"
31
+
32
+ # @return [String]
33
+ attr_reader :host
34
+ # @return [Integer]
35
+ attr_reader :port
36
+ # @return [Hash{String=>String}] raw connection parameters.
37
+ attr_reader :params
38
+ # @return [String, nil]
39
+ attr_reader :token
40
+ # @return [String, nil]
41
+ attr_reader :user_id
42
+ # @return [String, nil]
43
+ attr_reader :session_id
44
+
45
# Parse an `sc://` connection string into host, port and parameters.
#
# @param url [String] an `sc://` connection string.
# @raise [ConnectionError] if `url` is nil, lacks the `sc://` scheme, or
#   its endpoint or parameters are rejected by the parsing helpers.
def initialize(url)
  raise ConnectionError, "Connection string must not be nil" if url.nil?

  unless url.start_with?("sc://")
    raise ConnectionError, "Connection string must start with 'sc://', got: #{url.inspect}"
  end

  # Everything before the first "/" is host[:port]; the rest is ";key=value" pairs.
  endpoint, _slash, raw_params = url.delete_prefix("sc://").partition("/")
  @params = parse_params(raw_params)
  parse_endpoint(endpoint)

  @token, @user_id, @session_id = @params.values_at("token", "user_id", "session_id")
  # A token always switches TLS on; `use_ssl` can only enable TLS in addition.
  @use_ssl = parse_bool(@params["use_ssl"]) || !@token.nil?
end
60
+
61
# Whether TLS was selected when the connection string was parsed.
#
# @return [Boolean] whether the channel uses TLS.
def ssl?
  @use_ssl
end

# Host and port joined with a colon.
#
# @return [String] the gRPC target, e.g. `"localhost:15002"`.
def target
  [host, port].join(":")
end

# The `user_agent` connection parameter, or the library default.
#
# @return [String] the effective user agent.
def user_agent
  custom = @params["user_agent"]
  return custom if custom

  "spark-connect-ruby/#{SparkConnect::VERSION}"
end
75
+
76
# Per-request gRPC metadata derived from the connection string: an
# `authorization` bearer header when a token is set, followed by every
# parameter whose name starts with `x-`.
#
# @return [Hash{String=>String}]
def metadata
  auth = @token ? { "authorization" => "Bearer #{@token}" } : {}
  forwarded = @params.select { |name, _value| name.start_with?(PARAM_PREFIX) }
  auth.merge(forwarded)
end
86
+
87
# Build the gRPC stub for the {Spark::Connect::SparkConnectService}.
#
# TLS channel credentials are used when {#ssl?} is true, otherwise the
# channel is explicitly marked insecure. The user agent is passed as the
# `grpc.primary_user_agent` channel argument unless `channel_args`
# overrides that key.
#
# @param channel_args [Hash] extra gRPC channel arguments.
# @return [Spark::Connect::SparkConnectService::Stub]
def build_stub(channel_args: {})
  credentials =
    if @use_ssl
      GRPC::Core::ChannelCredentials.new
    else
      :this_channel_is_insecure
    end
  merged_args = { "grpc.primary_user_agent" => user_agent }.merge(channel_args)
  Proto::SparkConnectService::Stub.new(target, credentials, channel_args: merged_args)
end
96
+
97
+ private
98
+
99
+ Proto = SparkConnect::Proto
100
+
101
# Split `host[:port]` into `@host` and `@port`.
#
# A bracketed IPv6 literal without a port (e.g. `[::1]`) is taken whole as
# the host; otherwise the text after the last colon is the port. When no
# port is given, `DEFAULT_PORT` is used.
#
# @param endpoint [String, nil] the part of the URL between `sc://` and the first `/`.
# @return [void]
# @raise [ConnectionError] if the host is missing or the port is not a
#   decimal number in 1..65535.
def parse_endpoint(endpoint)
  raise ConnectionError, "Missing host in connection string" if endpoint.nil? || endpoint.empty?

  if endpoint.start_with?("[") && endpoint.end_with?("]")
    # The colons belong to the IPv6 address itself, not to a port separator.
    host = endpoint
    port = nil
  else
    host, sep, port = endpoint.rpartition(":")
    if sep.empty?
      host = endpoint
      port = nil
    end
  end
  raise ConnectionError, "Missing host in connection string" if host.empty?

  @host = host
  if port.nil?
    @port = DEFAULT_PORT
  else
    # Digits only: Integer() would also accept "0x10", "1_5" or negative values.
    number = port.match?(/\A\d{1,5}\z/) ? port.to_i : nil
    unless number && number.between?(1, 65_535)
      raise ConnectionError, "Invalid port in connection string: #{port.inspect}"
    end
    @port = number
  end
end
114
+
115
# Parse the `;key=value;key=value` tail of a connection string.
#
# Keys and values are percent-decoded. A literal `+` is preserved rather
# than decoded as a space, so tokens that contain `+` (e.g. base64)
# survive unchanged. Empty segments are skipped; on duplicate keys the
# last value wins.
#
# @param param_str [String] text after the first `/` of the URL.
# @return [Hash{String=>String}]
# @raise [ConnectionError] if a segment has no `=`.
def parse_params(param_str)
  # decode_www_form_component treats "+" as a space (form encoding), so
  # protect literal plus signs before decoding.
  decode = ->(text) { URI.decode_www_form_component(text.gsub("+", "%2B")) }

  param_str.split(";").each_with_object({}) do |kv, params|
    next if kv.empty?

    key, sep, value = kv.partition("=")
    raise ConnectionError, "Malformed parameter (expected key=value): #{kv.inspect}" if sep.empty?

    params[decode.call(key)] = decode.call(value)
  end
end
127
+
128
# Interpret a connection-string flag.
#
# @param value [Object, nil] raw parameter value.
# @return [Boolean, nil] nil when the parameter is absent; otherwise true
#   for `true`, `1` or `yes` (case-insensitive) and false for anything else.
def parse_bool(value)
  unless value.nil?
    normalized = value.to_s.downcase
    normalized == "true" || normalized == "1" || normalized == "yes"
  end
end
133
+ end
134
+ end
@@ -0,0 +1,264 @@
1
+ # frozen_string_literal: true
2
+
3
+ require "securerandom"
4
+
5
+ module SparkConnect
6
+ # The low-level Spark Connect client. Wraps the gRPC stub and exposes the four
7
+ # core RPC families used by the high-level API: {#execute_plan},
8
+ # {#execute_command}, {#analyze}, and {#config}. Higher layers
9
+ # ({SparkSession}, {DataFrame}) never touch the stub directly.
10
+ #
11
+ # Transient transport failures (e.g. `GRPC::Unavailable`) are retried with
12
+ # exponential backoff and jitter before any response data has been observed.
13
+ class SparkConnectClient
14
+ Proto = SparkConnect::Proto
15
+
16
# Accumulated result of an `ExecutePlan` stream. Members are positional, so
# their order is part of the interface.
#
# @!attribute [r] arrow_batches
#   @return [Array<String>] each element is one Arrow IPC stream chunk.
# @!attribute [r] schema
#   @return [Spark::Connect::DataType, nil] result schema, if returned.
# @!attribute [r] metrics
#   @return [Spark::Connect::ExecutePlanResponse::Metrics, nil]
# @!attribute [r] observed_metrics
#   @return [Array] observed (named) metrics.
# @!attribute [r] sql_command_result
#   @return [Spark::Connect::Relation, nil] relation produced by a SQL command.
# @!attribute [r] row_count
#   @return [Integer] sum of `row_count` over the received Arrow batches.
# @!attribute [r] write_stream_result
#   @return [Object, nil] the response's `write_stream_operation_start_result`.
# @!attribute [r] streaming_query_result
#   @return [Object, nil] the response's `streaming_query_command_result`.
# @!attribute [r] streaming_manager_result
#   @return [Object, nil] the response's `streaming_query_manager_command_result`.
# @!attribute [r] checkpoint_relation
#   @return [Object, nil] the `relation` of a `checkpoint_command_result`.
# @!attribute [r] pipeline_command_result
#   @return [Object, nil] the response's `pipeline_command_result`.
# @!attribute [r] pipeline_events
#   @return [Array, nil] events taken from `pipeline_event_result` responses.
ExecuteResult = Struct.new(
  :arrow_batches, :schema, :metrics, :observed_metrics, :sql_command_result, :row_count,
  :write_stream_result, :streaming_query_result, :streaming_manager_result, :checkpoint_relation,
  :pipeline_command_result, :pipeline_events
)
33
+
34
+ # @return [String] the client-side session id (UUID v4).
35
+ attr_reader :session_id
36
+ # @return [String] the user agent / client type.
37
+ attr_reader :client_type
38
+ # @return [ChannelBuilder]
39
+ attr_reader :channel_builder
40
+
41
# Build the gRPC stub from the channel builder and capture the session
# identity and retry settings.
#
# @param channel_builder [ChannelBuilder]
# @param session_id [String, nil] reuse a session id; falls back to the
#   channel builder's `session_id`, otherwise a UUID is generated.
# @param max_retries [Integer] maximum number of retries per call.
# @param retry_base_delay [Float] base backoff in seconds.
# @param max_retry_delay [Float] cap applied to the exponential delay, in
#   seconds, before jitter is added.
def initialize(channel_builder, session_id: nil, max_retries: 10, retry_base_delay: 0.05, max_retry_delay: 10.0)
  @channel_builder = channel_builder
  @stub = channel_builder.build_stub
  @metadata = channel_builder.metadata

  # Identity sent with every request.
  @session_id = session_id || channel_builder.session_id || SecureRandom.uuid
  @client_type = channel_builder.user_agent
  @user_context = Proto::UserContext.new(user_id: channel_builder.user_id || "")

  # Retry policy.
  @max_retries = max_retries
  @retry_base_delay = retry_base_delay
  @max_retry_delay = max_retry_delay

  @server_side_session_id = nil
  @tags = []
end
58
+
59
+ # Operation tags attached to every subsequent execution (used with
60
+ # {#interrupt} `type: :tag`).
61
+ # @return [Array<String>]
62
+ attr_reader :tags
63
+
64
# Add an operation tag. The tag is converted to a String; adding a tag
# that is already present is a no-op.
#
# @param tag [#to_s]
# @return [void]
# @raise [IllegalArgumentError] if the tag is empty or contains a comma.
def add_tag(tag)
  label = tag.to_s
  if label.empty?
    raise IllegalArgumentError, "Tag must not be empty"
  elsif label.include?(",")
    raise IllegalArgumentError, "Tag must not contain ','"
  end

  @tags.push(label) unless @tags.include?(label)
  nil
end

# Remove an operation tag; unknown tags are ignored.
#
# @param tag [#to_s]
# @return [void]
def remove_tag(tag)
  label = tag.to_s
  @tags.delete_if { |existing| existing == label }
  nil
end

# Remove all operation tags.
#
# @return [void]
def clear_tags
  @tags.clear
  nil
end
86
+
87
# Execute a relation plan and accumulate the streamed response.
#
# @param relation [Spark::Connect::Relation]
# @return [ExecuteResult]
def execute_plan(relation)
  plan = PlanBuilder.root_plan(relation)
  execute(plan)
end

# Execute a command plan (side-effecting, e.g. write/SQL DML).
#
# @param command [Spark::Connect::Command]
# @return [ExecuteResult]
def execute_command(command)
  plan = PlanBuilder.command_plan(command)
  execute(plan)
end
102
+
103
# Run an `AnalyzePlan` request, retrying transient failures.
#
# @param analyze_kw [Hash] exactly one `analyze` oneof keyword, e.g.
#   `schema:`, `explain:`, `tree_string:`, `is_local:`, `spark_version:`.
# @return [Spark::Connect::AnalyzePlanResponse]
def analyze(**analyze_kw)
  fields = {
    session_id: @session_id,
    user_context: @user_context,
    client_type: @client_type
  }.merge(analyze_kw)
  request = Proto::AnalyzePlanRequest.new(**fields)

  with_retries do
    @stub.analyze_plan(request, metadata: @metadata)
  end
end
117
+
118
# Run a `Config` request, retrying transient failures.
#
# @param operation [Spark::Connect::ConfigRequest::Operation]
# @return [Spark::Connect::ConfigResponse]
def config(operation)
  fields = {
    session_id: @session_id,
    user_context: @user_context,
    client_type: @client_type,
    operation: operation
  }
  request = Proto::ConfigRequest.new(**fields)

  with_retries do
    @stub.config(request, metadata: @metadata)
  end
end
131
+
132
# Interrupt running operations.
#
# `type` is normalised, so `"tag"` or `:TAG` behave like `:tag`; the value
# is converted to a String before it is sent.
#
# @param type [Symbol, String] `:all`, `:tag`, or `:operation_id`.
# @param value [String, nil] the tag or operation id when applicable.
# @return [Spark::Connect::InterruptResponse]
# @raise [IllegalArgumentError] if `type` is unknown, or `value` is missing
#   for `:tag` / `:operation_id`.
def interrupt(type: :all, value: nil)
  kind = type.to_s.downcase.to_sym
  unless %i[all tag operation_id].include?(kind)
    raise IllegalArgumentError,
          "Unknown interrupt type: #{type.inspect} (expected :all, :tag, or :operation_id)"
  end
  if kind != :all && value.to_s.empty?
    raise IllegalArgumentError, "Interrupt type #{kind.inspect} requires a non-empty value"
  end

  fields = {
    session_id: @session_id,
    user_context: @user_context,
    client_type: @client_type,
    interrupt_type: :"INTERRUPT_TYPE_#{kind.to_s.upcase}"
  }
  # Only the field matching the interrupt type is populated.
  fields[:operation_tag] = value.to_s if kind == :tag
  fields[:operation_id] = value.to_s if kind == :operation_id

  request = Proto::InterruptRequest.new(**fields)
  with_retries { @stub.interrupt(request, metadata: @metadata) }
end
146
+
147
# Release this client's server-side session.
#
# Best-effort: any StandardError raised while building or sending the
# request is swallowed and the call still returns nil.
#
# @return [void]
def release_session
  begin
    request = Proto::ReleaseSessionRequest.new(
      session_id: @session_id,
      user_context: @user_context,
      client_type: @client_type
    )
    # Deliberately not wrapped in the retry loop: this runs on teardown, so
    # a dead server must not block the caller with retry/backoff.
    @stub.release_session(request, metadata: @metadata)
  rescue StandardError
    # Ignored on purpose; teardown must not raise.
  end
  nil
end
160
+
161
+ private
162
+
163
# Send an `ExecutePlan` request and fold the streamed responses into an
# {ExecuteResult}.
#
# Transient gRPC failures are retried only while no response has been
# received. Once part of the stream has been accumulated, a failure is
# translated and raised immediately, because replaying the request would
# append the same batches and row counts to the result a second time.
#
# @param plan [Spark::Connect::Plan]
# @return [ExecuteResult]
# @raise [SparkConnectError] the translated error, when the call fails
#   and is not (or can no longer be) retried.
def execute(plan)
  request = Proto::ExecutePlanRequest.new(
    session_id: @session_id,
    user_context: @user_context,
    operation_id: SecureRandom.uuid,
    plan: plan,
    client_type: @client_type,
    tags: @tags
  )

  result = ExecuteResult.new([], nil, nil, [], nil, 0)
  result.pipeline_events = []

  with_retries do
    received_any = false
    begin
      @stub.execute_plan(request, metadata: @metadata).each do |resp|
        received_any = true
        @server_side_session_id = resp.server_side_session_id unless resp.server_side_session_id.empty?
        accumulate(result, resp)
      end
    rescue GRPC::BadStatus => e
      # Untouched result: hand the raw error to with_retries so it may retry.
      raise e unless received_any

      # Partial result: raise the translated error, which is assumed not to
      # be a GRPC::BadStatus and therefore bypasses the retry loop.
      raise translate_error(e)
    end
  end
  result
end
185
+
186
# Fold one streamed `ExecutePlan` response into the running result.
#
# Schema and metrics overwrite earlier values when present, observed
# metrics are appended, and the response payload is stored according to
# its `response_type`. Unrecognised response types are ignored.
#
# @param result [ExecuteResult] accumulator, mutated in place.
# @param resp [Spark::Connect::ExecutePlanResponse]
# @return [void]
def accumulate(result, resp)
  schema = resp.schema
  result.schema = schema if schema
  metrics = resp.metrics
  result.metrics = metrics if metrics
  observed = resp.observed_metrics
  result.observed_metrics += observed.to_a unless observed.empty?

  case resp.response_type
  when :arrow_batch
    chunk = resp.arrow_batch
    payload = chunk.data
    # Empty payloads carry no Arrow data but may still report rows.
    result.arrow_batches.push(payload) unless payload.empty?
    result.row_count += chunk.row_count
  when :sql_command_result
    result.sql_command_result = resp.sql_command_result.relation
  when :write_stream_operation_start_result
    result.write_stream_result = resp.write_stream_operation_start_result
  when :streaming_query_command_result
    result.streaming_query_result = resp.streaming_query_command_result
  when :streaming_query_manager_command_result
    result.streaming_manager_result = resp.streaming_query_manager_command_result
  when :checkpoint_command_result
    result.checkpoint_relation = resp.checkpoint_command_result.relation
  when :pipeline_command_result
    result.pipeline_command_result = resp.pipeline_command_result
  when :pipeline_event_result
    result.pipeline_events.push(resp.pipeline_event_result.event)
  end
end
212
+
213
# Run the block, retrying retryable gRPC failures with backoff.
#
# A `GRPC::BadStatus` is retried while it is retryable and fewer than
# `@max_retries` retries have been made; otherwise it is translated and
# raised. Other exceptions propagate untouched.
#
# @yield the gRPC call to attempt.
# @return [Object] the block's value.
def with_retries
  retries_used = 0
  begin
    yield
  rescue GRPC::BadStatus => e
    raise translate_error(e) unless retryable?(e) && retries_used < @max_retries

    pause = backoff(retries_used)
    retries_used += 1
    sleep(pause)
    retry
  end
end
227
+
228
# gRPC status codes that are retried.
RETRYABLE_CODES = %i[UNAVAILABLE DEADLINE_EXCEEDED ABORTED RESOURCE_EXHAUSTED].map do |name|
  GRPC::Core::StatusCodes.const_get(name)
end.freeze

# @param error [GRPC::BadStatus]
# @return [Boolean] whether the error's status code is in {RETRYABLE_CODES}.
def retryable?(error)
  status = error.code
  RETRYABLE_CODES.include?(status)
end
238
+
239
# Delay before the next retry: the base delay doubled per attempt, capped
# at `@max_retry_delay`, plus random jitter of up to half the capped delay
# (so the result may exceed the cap by up to 50%).
#
# @param attempt [Integer] zero-based number of retries already made.
# @return [Float] seconds to sleep.
def backoff(attempt)
  exponential = @retry_base_delay * (2**attempt)
  capped = [exponential, @max_retry_delay].min
  jitter = rand * capped * 0.5
  capped + jitter
end
244
+
245
# Convert a gRPC error into the library's exception hierarchy.
#
# The class is chosen from the message text: parse markers map to
# {ParseError}, analysis markers to {AnalysisError}, everything else to
# {SparkConnectError}. The first bracketed upper-case token in the message
# becomes `error_class`.
#
# NOTE(review): the `\[.*\] ` alternative matches any message with a
# bracketed token followed by a space, so non-analysis error classes are
# also reported as AnalysisError - confirm this is intended.
#
# @param error [GRPC::BadStatus, StandardError]
# @return [SparkConnectError] a new, unraised exception instance.
def translate_error(error)
  message = error.respond_to?(:details) ? error.details : error.message
  text = message.to_s
  code = grpc_code_name(error)

  klass =
    case text
    when /\[(PARSE_SYNTAX_ERROR|PARSE_)/, /ParseException/
      ParseError
    when /AnalysisException|UNRESOLVED_|TABLE_OR_VIEW_NOT_FOUND|\[.*\] /
      AnalysisError
    else
      SparkConnectError
    end

  klass.new(message, error_class: text[/\[([A-Z0-9_.]+)\]/, 1], grpc_code: code)
end
259
+
260
# Look up the constant name of the error's gRPC status code.
#
# @param error [#code]
# @return [String, nil] e.g. `"UNAVAILABLE"`, or nil when no constant in
#   `GRPC::Core::StatusCodes` has that value.
def grpc_code_name(error)
  codes = GRPC::Core::StatusCodes
  match = codes.constants.find { |const_name| codes.const_get(const_name) == error.code }
  match&.to_s
end
263
+ end
264
+ end