schwab 0.2.0
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- checksums.yaml +7 -0
- data/.brakeman.yml +75 -0
- data/.claude/commands/release-pr.md +120 -0
- data/.env.example +15 -0
- data/.rspec +3 -0
- data/.rubocop.yml +25 -0
- data/CHANGELOG.md +115 -0
- data/LICENSE.txt +21 -0
- data/README.md +133 -0
- data/Rakefile +12 -0
- data/docs/resource_objects.md +474 -0
- data/lib/schwab/account_number_resolver.rb +123 -0
- data/lib/schwab/accounts.rb +331 -0
- data/lib/schwab/client.rb +266 -0
- data/lib/schwab/configuration.rb +140 -0
- data/lib/schwab/connection.rb +81 -0
- data/lib/schwab/error.rb +51 -0
- data/lib/schwab/market_data.rb +179 -0
- data/lib/schwab/middleware/authentication.rb +100 -0
- data/lib/schwab/middleware/rate_limit.rb +119 -0
- data/lib/schwab/oauth.rb +95 -0
- data/lib/schwab/resources/account.rb +272 -0
- data/lib/schwab/resources/base.rb +300 -0
- data/lib/schwab/resources/order.rb +441 -0
- data/lib/schwab/resources/position.rb +318 -0
- data/lib/schwab/resources/strategy.rb +410 -0
- data/lib/schwab/resources/transaction.rb +333 -0
- data/lib/schwab/version.rb +6 -0
- data/lib/schwab.rb +46 -0
- data/sig/schwab.rbs +4 -0
- data/tasks/prd-accounts-trading-api.md +302 -0
- data/tasks/tasks-prd-accounts-trading-api-reordered.md +140 -0
- data/tasks/tasks-prd-accounts-trading-api.md +106 -0
- metadata +146 -0
@@ -0,0 +1,140 @@
|
|
1
|
+
# frozen_string_literal: true
|
2
|
+
|
3
|
+
module Schwab
|
4
|
+
# Configuration storage for Schwab SDK
#
# @example Configure the SDK
#   Schwab.configure do |config|
#     config.client_id = "your_client_id"
#     config.client_secret = "your_client_secret"
#     config.redirect_uri = "http://localhost:3000/callback"
#     config.response_format = :hash # or :resource
#   end
class Configuration
  # Response formats accepted by #response_format= and re-checked by #validate!.
  VALID_RESPONSE_FORMATS = [:hash, :resource].freeze

  # OAuth settings that must be present (non-blank) before the OAuth flow can run.
  REQUIRED_OAUTH_SETTINGS = [:client_id, :client_secret, :redirect_uri].freeze

  # @!attribute client_id
  #   @return [String] OAuth client ID from Schwab developer portal
  # @!attribute client_secret
  #   @return [String] OAuth client secret from Schwab developer portal
  # @!attribute redirect_uri
  #   @return [String] OAuth callback URL configured in Schwab developer portal
  # @!attribute api_base_url
  #   @return [String] Base URL for Schwab API (default: https://api.schwabapi.com)
  # @!attribute api_version
  #   @return [String] API version used by #api_endpoint (default: v1)
  # @!attribute logger
  #   @return [Logger, nil] Logger instance for debugging (default: nil)
  # @!attribute timeout
  #   @return [Integer] Request timeout in seconds (default: 30)
  # @!attribute open_timeout
  #   @return [Integer] Connection open timeout in seconds (default: 30)
  # @!attribute faraday_adapter
  #   @return [Symbol] Faraday adapter to use (default: Faraday.default_adapter)
  # @!attribute max_retries
  #   @return [Integer] Maximum number of retries for failed requests (default: 3)
  # @!attribute retry_delay
  #   @return [Integer] Delay in seconds between retries (default: 1)
  attr_accessor :client_id,
    :client_secret,
    :redirect_uri,
    :api_base_url,
    :api_version,
    :logger,
    :timeout,
    :open_timeout,
    :faraday_adapter,
    :max_retries,
    :retry_delay

  # @!attribute [r] response_format
  #   @return [Symbol] Response format (:hash or :resource, default: :hash)
  #     - :hash returns plain Ruby hashes (default, backward compatible)
  #     - :resource returns Sawyer::Resource-like objects with method access
  attr_reader :response_format

  # Build a configuration populated with the SDK defaults.
  # OAuth credentials (client_id, client_secret, redirect_uri) start out nil.
  def initialize
    @api_base_url = "https://api.schwabapi.com"
    @api_version = "v1"
    @timeout = 30
    @open_timeout = 30
    @faraday_adapter = Faraday.default_adapter
    @max_retries = 3
    @retry_delay = 1
    @logger = nil
    @response_format = :hash
  end

  # Set response format with validation
  #
  # Strings are converted to symbols first, so values read from ENV
  # (e.g. "resource") are accepted.
  #
  # @param format [Symbol, String] The response format to use (:hash or :resource)
  # @raise [ArgumentError] if format is not :hash or :resource
  # @example Set response format to resource objects
  #   config.response_format = :resource
  def response_format=(format)
    format = format.to_sym if format.is_a?(String)
    unless VALID_RESPONSE_FORMATS.include?(format)
      raise ArgumentError, "Invalid response_format: #{format}. Must be :hash or :resource"
    end

    @response_format = format
  end

  # Get the full API endpoint URL with version
  # @return [String] e.g. "https://api.schwabapi.com/v1"
  def api_endpoint
    "#{api_base_url}/#{api_version}"
  end

  # OAuth authorization endpoint.
  # The path is pinned to /v1 independently of #api_version.
  # @return [String] The OAuth authorization URL
  def oauth_authorize_url
    "#{api_base_url}/v1/oauth/authorize"
  end

  # Get the OAuth token endpoint URL
  # The path is pinned to /v1 independently of #api_version.
  # @return [String] The OAuth token URL
  def oauth_token_url
    "#{api_base_url}/v1/oauth/token"
  end

  # Validate that required OAuth parameters are present and the response format is legal.
  #
  # Nil, empty and whitespace-only values count as missing.
  #
  # @return [true] when the configuration is valid
  # @raise [Error] listing every missing setting, or naming an invalid response_format
  def validate!
    missing = REQUIRED_OAUTH_SETTINGS.select { |name| setting_missing?(public_send(name)) }
    raise Error, "Missing required configuration: #{missing.join(", ")}" unless missing.empty?

    # Re-checked here in case @response_format was set without going through the writer.
    unless VALID_RESPONSE_FORMATS.include?(response_format)
      raise Error, "Invalid response_format: #{response_format}. Must be :hash or :resource"
    end

    true
  end

  # Check if OAuth credentials are configured.
  # Uses the same "blank means missing" rule as #validate!.
  # @return [Boolean]
  def oauth_configured?
    REQUIRED_OAUTH_SETTINGS.none? { |name| setting_missing?(public_send(name)) }
  end

  # Convert configuration to a hash
  # @return [Hash{Symbol => Object}] every configurable setting
  def to_h
    {
      client_id: client_id,
      client_secret: client_secret,
      redirect_uri: redirect_uri,
      api_base_url: api_base_url,
      api_version: api_version,
      timeout: timeout,
      open_timeout: open_timeout,
      faraday_adapter: faraday_adapter,
      max_retries: max_retries,
      retry_delay: retry_delay,
      logger: logger,
      response_format: response_format,
    }
  end

  private

  # True when a setting is nil, empty, or only whitespace.
  # Deliberately not named `blank?`: a one-argument private `blank?` would
  # break ActiveSupport's Object#blank?/#present? on instances of this class.
  def setting_missing?(value)
    value.nil? || value.to_s.strip.empty?
  end
end
|
140
|
+
end
|
@@ -0,0 +1,81 @@
|
|
1
|
+
# frozen_string_literal: true
|
2
|
+
|
3
|
+
require "faraday"
|
4
|
+
require "faraday/middleware"
|
5
|
+
require_relative "middleware/authentication"
|
6
|
+
|
7
|
+
module Schwab
|
8
|
+
# HTTP connection builder for Schwab API
#
# Both builders produce a Faraday connection with the same response-side stack
# (JSON parsing, raise_error, optional logger), adapter and timeouts; they differ
# only in how the Authorization header is supplied.
module Connection
  class << self
    # Build a Faraday connection with the configured middleware stack
    #
    # @param access_token [String, nil] OAuth access token; when nil no Authorization header is added
    # @param config [Configuration, nil] settings to use; falls back to Schwab.configuration,
    #   then to a fresh Configuration
    # @return [Faraday::Connection] Configured Faraday connection
    def build(access_token: nil, config: nil)
      settings = config || Schwab.configuration || Configuration.new

      Faraday.new(url: settings.api_base_url) do |conn|
        # Request side: encode bodies as JSON, then attach a static bearer token.
        conn.request(:json)
        conn.request(:authorization, "Bearer", access_token) if access_token

        attach_common_stack(conn, settings)
      end
    end

    # Build a connection with automatic token refresh capability
    #
    # With a refresh_token, Middleware::TokenRefresh manages the Authorization header;
    # otherwise this behaves like #build with a static bearer token.
    #
    # @param access_token [String] Initial access token
    # @param refresh_token [String, nil] Refresh token for automatic refresh
    # @param on_token_refresh [Proc, nil] Callback invoked with the refresh result
    # @param config [Configuration, nil] settings to use (same fallback as #build)
    # @return [Faraday::Connection] Configured connection with refresh capability
    def build_with_refresh(access_token:, refresh_token: nil, on_token_refresh: nil, config: nil)
      settings = config || Schwab.configuration || Configuration.new

      Faraday.new(url: settings.api_base_url) do |conn|
        conn.request(:json)

        if refresh_token
          conn.use(
            Middleware::TokenRefresh,
            access_token: access_token,
            refresh_token: refresh_token,
            client_id: settings.client_id,
            client_secret: settings.client_secret,
            on_token_refresh: on_token_refresh,
          )
        else
          conn.request(:authorization, "Bearer", access_token)
        end

        attach_common_stack(conn, settings)
      end
    end

    private

    # Append the response middleware, adapter and timeouts shared by both builders.
    # Call order matters: response middleware is registered after request middleware,
    # and the adapter is registered last.
    def attach_common_stack(conn, settings)
      conn.response(:json, content_type: /\bjson$/)
      conn.response(:raise_error)
      conn.response(:logger, settings.logger, { headers: false, bodies: false }) if settings.logger

      conn.adapter(settings.faraday_adapter)

      conn.options.timeout = settings.timeout
      conn.options.open_timeout = settings.open_timeout
    end
  end
end
|
81
|
+
end
|
data/lib/schwab/error.rb
ADDED
@@ -0,0 +1,51 @@
|
|
1
|
+
# frozen_string_literal: true
|
2
|
+
|
3
|
+
module Schwab
  # Root of the SDK's exception hierarchy; rescuing it catches every Schwab-specific error.
  class Error < StandardError
  end

  # Parent of all errors that originate from an HTTP response.
  # Carries the failing response's status, body and headers for inspection.
  class ApiError < Error
    attr_reader :status, :response_body, :response_headers

    # @param message [String, nil] description; when nil, Exception#message falls back to the class name
    # @param status [Integer, nil] HTTP status code
    # @param response_body [Object, nil] body of the failing response
    # @param response_headers [Hash, nil] headers of the failing response
    def initialize(message = nil, status: nil, response_body: nil, response_headers: nil)
      @status = status
      @response_body = response_body
      @response_headers = response_headers
      super(message)
    end
  end

  # Raised when API returns 401 Unauthorized
  class AuthenticationError < ApiError
  end

  # Raised when the access token has expired (a specialised AuthenticationError)
  class TokenExpiredError < AuthenticationError
  end

  # Raised when API returns 403 Forbidden
  class AuthorizationError < ApiError
  end

  # Raised when API returns 404 Not Found
  class NotFoundError < ApiError
  end

  # Raised when API returns 429 Too Many Requests
  class RateLimitError < ApiError
    attr_reader :retry_after

    # @param message [String, nil] description
    # @param retry_after [Object, nil] server-suggested wait before retrying
    # @param options [Hash] forwarded to ApiError (status:, response_body:, response_headers:)
    def initialize(message = nil, retry_after: nil, **options)
      @retry_after = retry_after
      super(message, **options)
    end
  end

  # Raised when API returns 5xx Server Error
  class ServerError < ApiError
  end

  # Raised when API returns 400 Bad Request.
  # Unlike other API errors, its response_body may be reassigned after construction
  # (the reader is inherited from ApiError).
  class BadRequestError < ApiError
    attr_writer :response_body
  end

  # Raised when API returns an unexpected status code
  class UnexpectedResponseError < ApiError
  end
end
|
@@ -0,0 +1,179 @@
|
|
1
|
+
# frozen_string_literal: true
|
2
|
+
|
3
|
+
require "date"
require "time"
require "uri"
|
4
|
+
|
5
|
+
module Schwab
|
6
|
+
# Market Data API endpoints for retrieving quotes, price history, and market information
#
# Every endpoint accepts an optional `client:` (any object responding to
# `get(path, params)`); when omitted, a client built from the global
# configuration is used.
module MarketData
  class << self
    # Get quotes for one or more symbols
    #
    # @param symbols [String, Array<String>] Symbol(s) to get quotes for
    # @param fields [String, Array<String>, nil] Quote fields to include (e.g., "quote", "fundamental")
    # @param indicative [Boolean] Whether to include indicative quotes (always sent)
    # @param client [Schwab::Client, nil] Optional client instance (uses default if not provided)
    # @return [Hash] Quote data for the requested symbols
    # @example Get quotes for multiple symbols
    #   Schwab::MarketData.get_quotes(["AAPL", "MSFT"])
    # @example Get quotes with specific fields
    #   Schwab::MarketData.get_quotes("AAPL", fields: ["quote", "fundamental"])
    def get_quotes(symbols, fields: nil, indicative: false, client: nil)
      client ||= default_client
      params = { symbols: join_list(symbols), indicative: indicative }
      params[:fields] = join_list(fields) if fields

      client.get("/marketdata/v1/quotes", params)
    end

    # Get detailed quote for a single symbol
    #
    # @param symbol [String] The symbol to get a quote for (URL-encoded into the path)
    # @param fields [String, Array<String>, nil] Quote fields to include
    # @param client [Schwab::Client, nil] Optional client instance
    # @return [Hash] Detailed quote data for the symbol
    # @example Get a single quote
    #   Schwab::MarketData.get_quote("AAPL")
    def get_quote(symbol, fields: nil, client: nil)
      client ||= default_client
      path = "/marketdata/v1/#{URI.encode_www_form_component(symbol)}/quotes"
      params = {}
      params[:fields] = join_list(fields) if fields

      client.get(path, params)
    end

    # Get price history for a symbol
    #
    # @param symbol [String] The symbol to get price history for
    # @param period_type [String, nil] The type of period ("day", "month", "year", "ytd")
    # @param period [Integer, nil] The number of periods
    # @param frequency_type [String, nil] The type of frequency ("minute", "daily", "weekly", "monthly")
    # @param frequency [Integer, nil] The frequency value
    # @param start_date [Time, Date, DateTime, String, Integer, nil] Start of the range;
    #   Integers are passed through as epoch milliseconds, other values are converted to them
    # @param end_date [Time, Date, DateTime, String, Integer, nil] End of the range (same rules)
    # @param need_extended_hours [Boolean] Include extended hours data
    # @param need_previous_close [Boolean] Include previous close data
    # @param client [Schwab::Client, nil] Optional client instance
    # @return [Hash] Price history data with candles
    # @raise [ArgumentError] if a date argument has an unsupported type
    # @example Get 5 days of history
    #   Schwab::MarketData.get_quote_history("AAPL", period_type: "day", period: 5)
    def get_quote_history(symbol, period_type: nil, period: nil, frequency_type: nil,
      frequency: nil, start_date: nil, end_date: nil,
      need_extended_hours: true, need_previous_close: false, client: nil)
      client ||= default_client
      path = "/marketdata/v1/pricehistory"

      params = { symbol: symbol }
      params[:periodType] = period_type if period_type
      params[:period] = period if period
      params[:frequencyType] = frequency_type if frequency_type
      params[:frequency] = frequency if frequency
      params[:startDate] = format_timestamp(start_date) if start_date
      params[:endDate] = format_timestamp(end_date) if end_date
      params[:needExtendedHoursData] = need_extended_hours
      params[:needPreviousClose] = need_previous_close

      client.get(path, params)
    end

    # Get market movers for an index
    #
    # @param index [String] The index symbol (e.g., "$SPX", "$DJI"); URL-encoded into the path
    # @param direction [String, nil] Direction of movement ("up" or "down")
    # @param change [String, nil] Type of change ("percent" or "value")
    # @param client [Schwab::Client, nil] Optional client instance
    # @return [Hash] Market movers data
    # @example Get top movers for S&P 500
    #   Schwab::MarketData.get_movers("$SPX", direction: "up", change: "percent")
    # NOTE(review): direction/change mirror the older TD Ameritrade movers API;
    # confirm the parameter names Schwab's /movers endpoint expects.
    def get_movers(index, direction: nil, change: nil, client: nil)
      client ||= default_client
      path = "/marketdata/v1/movers/#{URI.encode_www_form_component(index)}"

      params = {}
      params[:direction] = direction if direction
      params[:change] = change if change

      client.get(path, params)
    end

    # Get market hours for one or more markets
    #
    # @param markets [String, Array<String>] Market(s) to get hours for (e.g., "EQUITY", "OPTION")
    # @param date [Date, Time, String, nil] Date to get market hours for (sent as YYYY-MM-DD)
    # @param client [Schwab::Client, nil] Optional client instance
    # @return [Hash] Market hours information
    # @raise [ArgumentError] if date has an unsupported type
    # @example Get equity market hours
    #   Schwab::MarketData.get_market_hours("EQUITY")
    def get_market_hours(markets, date: nil, client: nil)
      client ||= default_client
      params = { markets: join_list(markets) }
      params[:date] = format_date(date) if date

      client.get("/marketdata/v1/markets", params)
    end

    # Get market hours for a single market.
    # Delegates to #get_market_hours with a single market.
    # NOTE(review): Schwab may also expose a dedicated /marketdata/v1/markets/{market_id}
    # endpoint; confirm before relying on this delegation.
    def get_market_hour(market_id, date: nil, client: nil)
      get_market_hours(market_id, date: date, client: client)
    end

    private

    # Build (once, then memoize on this module) a client from the global configuration.
    #
    # @raise [Error] when there is no global configuration, or it provides no access_token
    def default_client
      config = Schwab.configuration
      raise Error, "No client provided and no global configuration available" unless config

      # NOTE(review): the Configuration class in this gem does not declare
      # access_token/refresh_token; fail with a clear message instead of NoMethodError.
      unless config.respond_to?(:access_token)
        raise Error, "No client provided and the global configuration does not supply an access_token"
      end

      @default_client ||= Client.new(
        access_token: config.access_token,
        refresh_token: config.respond_to?(:refresh_token) ? config.refresh_token : nil,
      )
    end

    # Join a scalar or list into the comma-separated form the API expects.
    def join_list(values)
      Array(values).join(",")
    end

    # Convert a date-like value to epoch milliseconds.
    # Integer is checked first so integer input never needs the date library loaded.
    def format_timestamp(time)
      case time
      when Integer then time
      when Time then epoch_millis(time)
      when String then epoch_millis(Time.parse(time))
      # DateTime is a Date subclass and has no #to_f, so both go through #to_time.
      when Date then epoch_millis(time.to_time)
      else
        raise ArgumentError, "Invalid timestamp format: #{time.class}"
      end
    end

    # Exact epoch milliseconds; Time#to_r avoids the float rounding of to_f * 1000.
    def epoch_millis(time)
      (time.to_r * 1000).to_i
    end

    # Format a date-like value as YYYY-MM-DD (Time values use their own zone's date).
    def format_date(date)
      case date
      when Time, Date then date.strftime("%Y-%m-%d")
      when String then Date.parse(date).strftime("%Y-%m-%d")
      else
        raise ArgumentError, "Invalid date format: #{date.class}"
      end
    end
  end
end
|
179
|
+
end
|
@@ -0,0 +1,100 @@
|
|
1
|
+
# frozen_string_literal: true
|
2
|
+
|
3
|
+
require "faraday"
|
4
|
+
|
5
|
+
module Schwab
|
6
|
+
module Middleware
|
7
|
+
# Faraday middleware for automatic token refresh
#
# Stamps the current access token on every request. When the API answers 401
# (either as a response or as Faraday::UnauthorizedError from a :raise_error
# middleware below this one) and a refresh token is available, it refreshes the
# access token once and replays the request.
class TokenRefresh < Faraday::Middleware
  # @param app [#call] next middleware / adapter
  # @param options [Hash]
  # @option options [String] :access_token initial access token
  # @option options [String, nil] :refresh_token refresh token; without it 401s are not retried
  # @option options [String] :client_id OAuth client ID used for the refresh call
  # @option options [String] :client_secret OAuth client secret used for the refresh call
  # @option options [#call, nil] :on_token_refresh invoked with the refresh result after tokens are updated
  def initialize(app, options = {})
    super(app)
    @access_token = options[:access_token]
    @refresh_token = options[:refresh_token]
    @client_id = options[:client_id]
    @client_secret = options[:client_secret]
    @on_token_refresh = options[:on_token_refresh]
    # Serialises token refreshes between threads sharing this middleware instance.
    @mutex = Mutex.new
  end

  # Process the request with automatic token refresh on 401
  # @param env [Faraday::Env] The request environment
  # @return [Faraday::Response] The response
  # @raise [Schwab::TokenExpiredError] if the refresh call fails
  def call(env)
    # Faraday replaces env[:body] with the response body once a response arrives,
    # so keep the original payload for the replay.
    request_body = env[:body]
    token_used = @access_token

    begin
      response = send_with_token(env, token_used, request_body)
    rescue Faraday::UnauthorizedError
      # A :raise_error middleware below this one turned the 401 into an exception.
      raise unless @refresh_token

      return retry_with_fresh_token(env, token_used, request_body)
    end

    # Without a :raise_error below us, the 401 arrives as a plain response.
    return response unless response.status == 401 && @refresh_token

    retry_with_fresh_token(env, token_used, request_body)
  end

  private

  # Send the request with the given bearer token and original body.
  def send_with_token(env, token, request_body)
    env[:body] = request_body
    env[:request_headers]["Authorization"] = "Bearer #{token}"
    @app.call(env)
  end

  # Refresh the token (unless another thread already did since stale_token was
  # sent) and replay the request once with the current token.
  def retry_with_fresh_token(env, stale_token, request_body)
    token = @mutex.synchronize do
      refresh_access_token! if @access_token == stale_token
      @access_token
    end

    send_with_token(env, token, request_body)
  end

  # Fetch new tokens, store them, then notify the callback.
  # @return [Hash] the result returned by Schwab::OAuth.refresh_token
  def refresh_access_token!
    result = fetch_refreshed_tokens
    # Called outside the rescue in fetch_refreshed_tokens: the refresh already
    # succeeded, so errors from the callback propagate unchanged instead of
    # being reported as TokenExpiredError.
    @on_token_refresh&.call(result)
    result
  end

  # Call the OAuth refresh endpoint and update the stored tokens.
  # @raise [Schwab::TokenExpiredError] wrapping any failure of the refresh call
  def fetch_refreshed_tokens
    result = Schwab::OAuth.refresh_token(
      refresh_token: @refresh_token,
      client_id: @client_id,
      client_secret: @client_secret,
    )

    @access_token = result[:access_token]
    # The server may or may not rotate the refresh token.
    @refresh_token = result[:refresh_token] if result[:refresh_token]
    result
  rescue StandardError => e
    raise Schwab::TokenExpiredError, "Failed to refresh access token: #{e.message}"
  end
end
|
83
|
+
|
84
|
+
# Minimal Faraday middleware that puts a fixed bearer token on every request.
# When constructed with a nil/false token it passes requests through untouched.
class Authentication < Faraday::Middleware
  # @param app [#call] next middleware / adapter
  # @param token [String, nil] bearer token to send
  def initialize(app, token)
    super(app)
    @token = token
  end

  # Add bearer token to the request, then delegate to the rest of the stack
  # @param env [Faraday::Env] The request environment
  # @return [Faraday::Response] The response
  def call(env)
    stamp_authorization(env[:request_headers]) if @token
    @app.call(env)
  end

  private

  # Write the Authorization header for the configured token.
  def stamp_authorization(headers)
    headers["Authorization"] = "Bearer #{@token}"
  end
end
|
99
|
+
end
|
100
|
+
end
|
@@ -0,0 +1,119 @@
|
|
1
|
+
# frozen_string_literal: true
|
2
|
+
|
3
|
+
require "faraday"
|
4
|
+
|
5
|
+
module Schwab
|
6
|
+
# Middleware components for the HTTP client
|
7
|
+
module Middleware
|
8
|
+
# Faraday middleware for handling rate limits with exponential backoff
#
# Retries responses whose status is in RETRY_STATUSES, and requests that fail
# with Faraday::TimeoutError / Faraday::ConnectionFailed. Both kinds share one
# retry budget of max_retries. A Retry-After header, when present, overrides the
# backoff delay for that attempt.
# NOTE(review): if a :raise_error middleware sits below this one, 429/503 arrive
# as exceptions and are not retried here — confirm the stack order in the client.
class RateLimit < Faraday::Middleware
  # Default maximum number of retries for rate-limited requests
  DEFAULT_MAX_RETRIES = 3
  # Default initial retry delay in seconds
  DEFAULT_RETRY_DELAY = 1 # seconds
  # Default exponential backoff factor for retries
  DEFAULT_BACKOFF_FACTOR = 2
  RETRY_STATUSES = [429, 503].freeze # Rate limited and Service Unavailable

  # @param app [#call] next middleware / adapter
  # @param options [Hash]
  # @option options [Integer] :max_retries (DEFAULT_MAX_RETRIES)
  # @option options [Numeric] :retry_delay initial delay in seconds (DEFAULT_RETRY_DELAY)
  # @option options [Numeric] :backoff_factor delay multiplier per attempt (DEFAULT_BACKOFF_FACTOR)
  # @option options [Logger, nil] :logger receives an info line per retry
  def initialize(app, options = {})
    super(app)
    @max_retries = options[:max_retries] || DEFAULT_MAX_RETRIES
    @retry_delay = options[:retry_delay] || DEFAULT_RETRY_DELAY
    @backoff_factor = options[:backoff_factor] || DEFAULT_BACKOFF_FACTOR
    @logger = options[:logger]
  end

  # Process the request with rate limit handling
  # @param env [Faraday::Env] The request environment
  # @return [Faraday::Response] The response (possibly still a 429/503 once retries are exhausted)
  # @raise [Faraday::TimeoutError, Faraday::ConnectionFailed] when network retries are exhausted
  def call(env)
    # Faraday replaces env[:body] with the response body after each attempt,
    # so capture the request payload once and restore it before every send.
    request_body = env[:body]
    retries = 0
    delay = @retry_delay

    loop do
      begin
        env[:body] = request_body
        response = @app.call(env)
      rescue Faraday::TimeoutError, Faraday::ConnectionFailed => e
        raise if retries >= @max_retries

        retries += 1
        log_retry_error(env, e, retries, delay)
        sleep(delay)
        delay *= @backoff_factor
        next
      end

      return response unless should_retry?(response) && retries < @max_retries

      retries += 1
      retry_after = response.headers["retry-after"]
      wait_time = retry_after ? parse_retry_after(retry_after) : delay
      log_retry(env, response, retries, wait_time)
      sleep(wait_time)
      # Backoff still grows even when Retry-After dictated this wait.
      delay *= @backoff_factor
    end
  end

  private

  # True when the response status is one we retry.
  def should_retry?(response)
    RETRY_STATUSES.include?(response.status)
  end

  # Seconds to wait according to a Retry-After header value.
  # Accepts delta-seconds or an HTTP-date (RFC 9110). Falls back to the
  # configured retry delay for past dates or unparseable values.
  def parse_retry_after(value)
    text = value.to_s.strip
    # \A..\z anchors the whole value (^/$ would match a single line of it).
    return text.to_i if text.match?(/\A\d+\z/)

    # NOTE(review): Time.httpdate comes from the "time" stdlib, which this file
    # does not require itself; presumably loaded by Faraday or the gem entry point.
    wait_seconds = Time.httpdate(text) - Time.now
    wait_seconds.positive? ? wait_seconds : @retry_delay
  rescue ArgumentError
    @retry_delay
  end

  # Log a retry caused by a retryable HTTP status.
  def log_retry(env, response, attempt, wait_time)
    return unless @logger

    @logger.info(
      "[RateLimit] Retrying request to #{env[:url].path} " \
        "(attempt #{attempt}/#{@max_retries}, status: #{response.status}, " \
        "waiting: #{wait_time}s)",
    )
  end

  # Log a retry caused by a network-level error.
  def log_retry_error(env, error, attempt, wait_time)
    return unless @logger

    @logger.info(
      "[RateLimit] Retrying request to #{env[:url].path} after error " \
        "(attempt #{attempt}/#{@max_retries}, error: #{error.class}, " \
        "waiting: #{wait_time}s)",
    )
  end
end
|
118
|
+
end
|
119
|
+
end
|