atproto_auth 0.0.1
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- checksums.yaml +7 -0
 - data/.rubocop.yml +16 -0
 - data/CHANGELOG.md +5 -0
 - data/LICENSE.txt +21 -0
 - data/README.md +179 -0
 - data/Rakefile +16 -0
 - data/examples/confidential_client/Gemfile +12 -0
 - data/examples/confidential_client/Gemfile.lock +84 -0
 - data/examples/confidential_client/README.md +110 -0
 - data/examples/confidential_client/app.rb +136 -0
 - data/examples/confidential_client/config/client-metadata.json +25 -0
 - data/examples/confidential_client/config.ru +4 -0
 - data/examples/confidential_client/public/client-metadata.json +24 -0
 - data/examples/confidential_client/public/styles.css +70 -0
 - data/examples/confidential_client/scripts/generate_keys.rb +15 -0
 - data/examples/confidential_client/views/authorized.erb +29 -0
 - data/examples/confidential_client/views/index.erb +44 -0
 - data/examples/confidential_client/views/layout.erb +11 -0
 - data/lib/atproto_auth/client.rb +410 -0
 - data/lib/atproto_auth/client_metadata.rb +264 -0
 - data/lib/atproto_auth/configuration.rb +17 -0
 - data/lib/atproto_auth/dpop/client.rb +122 -0
 - data/lib/atproto_auth/dpop/key_manager.rb +235 -0
 - data/lib/atproto_auth/dpop/nonce_manager.rb +138 -0
 - data/lib/atproto_auth/dpop/proof_generator.rb +112 -0
 - data/lib/atproto_auth/errors.rb +47 -0
 - data/lib/atproto_auth/http_client.rb +227 -0
 - data/lib/atproto_auth/identity/document.rb +104 -0
 - data/lib/atproto_auth/identity/resolver.rb +221 -0
 - data/lib/atproto_auth/identity.rb +24 -0
 - data/lib/atproto_auth/par/client.rb +203 -0
 - data/lib/atproto_auth/par/client_assertion.rb +50 -0
 - data/lib/atproto_auth/par/request.rb +140 -0
 - data/lib/atproto_auth/par/response.rb +23 -0
 - data/lib/atproto_auth/par.rb +40 -0
 - data/lib/atproto_auth/pkce.rb +105 -0
 - data/lib/atproto_auth/server_metadata/authorization_server.rb +175 -0
 - data/lib/atproto_auth/server_metadata/origin_url.rb +51 -0
 - data/lib/atproto_auth/server_metadata/resource_server.rb +71 -0
 - data/lib/atproto_auth/server_metadata.rb +24 -0
 - data/lib/atproto_auth/state/session.rb +117 -0
 - data/lib/atproto_auth/state/session_manager.rb +75 -0
 - data/lib/atproto_auth/state/token_set.rb +68 -0
 - data/lib/atproto_auth/state.rb +54 -0
 - data/lib/atproto_auth/version.rb +5 -0
 - data/lib/atproto_auth.rb +56 -0
 - data/sig/atproto_auth/client_metadata.rbs +95 -0
 - data/sig/atproto_auth/dpop/client.rbs +38 -0
 - data/sig/atproto_auth/dpop/key_manager.rbs +33 -0
 - data/sig/atproto_auth/dpop/nonce_manager.rbs +48 -0
 - data/sig/atproto_auth/dpop/proof_generator.rbs +42 -0
 - data/sig/atproto_auth/http_client.rbs +58 -0
 - data/sig/atproto_auth/identity/document.rbs +31 -0
 - data/sig/atproto_auth/identity/resolver.rbs +41 -0
 - data/sig/atproto_auth/par/client.rbs +31 -0
 - data/sig/atproto_auth/par/request.rbs +73 -0
 - data/sig/atproto_auth/par/response.rbs +17 -0
 - data/sig/atproto_auth/pkce.rbs +24 -0
 - data/sig/atproto_auth/server_metadata/authorization_server.rbs +69 -0
 - data/sig/atproto_auth/server_metadata/origin_url.rbs +21 -0
 - data/sig/atproto_auth/server_metadata/resource_server.rbs +27 -0
 - data/sig/atproto_auth/state/session.rbs +50 -0
 - data/sig/atproto_auth/state/session_manager.rbs +26 -0
 - data/sig/atproto_auth/state/token_set.rbs +40 -0
 - data/sig/atproto_auth/version.rbs +3 -0
 - data/sig/atproto_auth.rbs +39 -0
 - metadata +142 -0
 
| 
         @@ -0,0 +1,138 @@ 
     | 
|
| 
      
 1 
     | 
    
         
            +
            # frozen_string_literal: true
         
     | 
| 
      
 2 
     | 
    
         
            +
             
     | 
| 
      
 3 
     | 
    
         
            +
            require "monitor"
         
     | 
| 
      
 4 
     | 
    
         
            +
             
     | 
| 
      
 5 
     | 
    
         
            +
            module AtprotoAuth
              module DPoP
                # Manages DPoP nonces provided by servers during the OAuth flow.
                # Nonces are stored per server origin (scheme://host[:port]), which keeps
                # the Resource Server's and Authorization Server's nonces separate.
                # Every read and write of the nonce table happens inside a Monitor, so a
                # single instance can be shared by concurrent requests.
                class NonceManager
                  # Error for nonce-related issues (missing nonce, invalid server URL)
                  class NonceError < AtprotoAuth::Error; end

                  # A stored nonce: its value, the origin it belongs to, and the
                  # wall-clock time (Unix seconds) at which it was recorded.
                  class StoredNonce
                    attr_reader :value, :timestamp, :server_url

                    # @param value [String] the nonce value
                    # @param server_url [String] normalized origin the nonce belongs to
                    def initialize(value, server_url)
                      @value = value
                      @server_url = server_url
                      @timestamp = Time.now.to_i
                    end

                    # @param ttl [Integer, nil] lifetime in seconds; nil disables expiry
                    # @return [Boolean] true once more than +ttl+ seconds have passed since storage
                    def expired?(ttl = nil)
                      return false unless ttl

                      (Time.now.to_i - @timestamp) > ttl
                    end
                  end

                  # Maximum time in seconds a nonce is considered valid
                  DEFAULT_TTL = 300 # 5 minutes

                  # @param ttl [Integer, nil] nonce lifetime in seconds; falls back to DEFAULT_TTL
                  def initialize(ttl: nil)
                    @ttl = ttl || DEFAULT_TTL
                    @nonces = {}
                    @monitor = Monitor.new
                  end

                  # Updates the stored nonce for a server
                  # @param nonce [String] The new nonce value
                  # @param server_url [String] The server's URL; only its origin is used as the key
                  # @return [StoredNonce] the newly stored entry
                  # @raise [NonceError] if inputs are invalid
                  def update(nonce:, server_url:)
                    validate_inputs!(nonce, server_url)
                    origin = normalize_server_url(server_url)

                    @monitor.synchronize do
                      @nonces[origin] = StoredNonce.new(nonce, origin)
                    end
                  end

                  # Gets the current nonce for a server
                  # @param server_url [String] The server's URL
                  # @return [String, nil] The current nonce or nil if none exists/expired
                  # @raise [NonceError] if server_url is invalid
                  def get(server_url)
                    validate_server_url!(server_url)
                    origin = normalize_server_url(server_url)

                    @monitor.synchronize do
                      stored = @nonces[origin]
                      return nil if stored.nil? || stored.expired?(@ttl)

                      stored.value
                    end
                  end

                  # Clears the stored nonce for a server.
                  # The URL is reduced to its origin, the same key #update stores under, so any
                  # URL on that server clears the entry. Unparseable input is not an error; it
                  # simply matches nothing.
                  # @param server_url [String] The server's URL
                  # @return [StoredNonce, nil] the removed entry, or nil if none was stored
                  def clear(server_url)
                    key = lookup_key(server_url)

                    @monitor.synchronize do
                      @nonces.delete(key)
                    end
                  end

                  # Clears all stored nonces
                  def clear_all
                    @monitor.synchronize do
                      @nonces.clear
                    end
                  end

                  # Get all currently stored server origins
                  # @return [Array<String>] normalized origins that have a stored nonce (expired or not)
                  def server_urls
                    @monitor.synchronize do
                      @nonces.keys
                    end
                  end

                  # Check if a server has a valid nonce
                  # @param server_url [String] The server's URL (normalized to its origin for lookup)
                  # @return [Boolean] true if server has a valid nonce
                  # @raise [NonceError] if server_url is invalid
                  def valid_nonce?(server_url)
                    validate_server_url!(server_url)
                    origin = normalize_server_url(server_url)

                    @monitor.synchronize do
                      stored = @nonces[origin]
                      !stored.nil? && !stored.expired?(@ttl)
                    end
                  end

                  private

                  # Reduces a URL to "scheme://host[:port]", omitting default ports and
                  # lowercasing the host (hosts are case-insensitive in an origin).
                  def normalize_server_url(url)
                    uri = URI(url)
                    port = uri.port
                    port = nil if (uri.scheme == "https" && port == 443) ||
                                  (uri.scheme == "http" && port == 80)

                    origin = "#{uri.scheme}://#{uri.host&.downcase}"
                    origin = "#{origin}:#{port}" if port
                    origin
                  end

                  # Storage key for +url+ without raising: falls back to the raw value when
                  # it cannot be parsed, so #clear stays a harmless no-op for bad input.
                  def lookup_key(url)
                    normalize_server_url(url)
                  rescue URI::Error, ArgumentError, TypeError
                    url
                  end

                  def validate_inputs!(nonce, server_url)
                    raise NonceError, "nonce is required" if nonce.nil? || nonce.empty?

                    validate_server_url!(server_url)
                  end

                  # Requires a non-empty http(s) URL; plain HTTP is accepted only for "localhost".
                  def validate_server_url!(server_url)
                    raise NonceError, "server_url is required" if server_url.nil? || server_url.empty?

                    uri = URI(server_url)
                    raise NonceError, "server_url must be HTTP(S)" unless uri.is_a?(URI::HTTP)

                    # Allow HTTP for localhost only
                    if uri.host != "localhost" && uri.scheme != "https"
                      raise NonceError, "server_url must be HTTPS (except for localhost)"
                    end
                  rescue URI::InvalidURIError => e
                    raise NonceError, "invalid server_url: #{e.message}"
                  end
                end
              end
            end
         
     | 
| 
         @@ -0,0 +1,112 @@ 
     | 
|
| 
      
 1 
     | 
    
         
            +
            # frozen_string_literal: true
         
     | 
| 
      
 2 
     | 
    
         
            +
             
     | 
| 
      
 3 
     | 
    
         
            +
            require "securerandom"
         
     | 
| 
      
 4 
     | 
    
         
            +
            require "time"
         
     | 
| 
      
 5 
     | 
    
         
            +
             
     | 
| 
      
 6 
     | 
    
         
            +
            module AtprotoAuth
              module DPoP
                # Creates DPoP proof JWTs according to RFC 9449.
                # A DPoP proof demonstrates possession of a key when making an HTTP request:
                # it is a JWT carrying the request method and URI (plus optional nonce and
                # access-token hash), signed through the configured KeyManager.
                class ProofGenerator
                  # Error raised for proof generation/validation issues
                  class ProofError < AtprotoAuth::Error; end

                  # @return [KeyManager] The key manager used for signing proofs
                  attr_reader :key_manager

                  # Creates a new ProofGenerator instance
                  # @param key_manager [KeyManager] Key manager to use for signing proofs
                  # @raise [ProofError] if key_manager is missing or not a KeyManager
                  def initialize(key_manager)
                    raise ProofError, "key_manager is required" unless key_manager
                    raise ProofError, "invalid key_manager type" unless key_manager.is_a?(KeyManager)

                    @key_manager = key_manager
                  end

                  # Generates a new DPoP proof JWT for an HTTP request
                  # @param http_method [String] HTTP method (e.g. "POST"); upcased in the proof
                  # @param http_uri [String] Full HTTP URI for the request
                  # @param nonce [String, nil] Server-provided nonce (required if available)
                  # @param access_token [String, nil] Access token being used (if any)
                  # @param ath [Boolean, nil] Whether to include the access token hash;
                  #   nil means "include it when access_token is given"
                  # @return [String] The signed DPoP proof JWT (as returned by key_manager.sign_segments)
                  # @raise [ProofError] if generation fails or parameters are invalid
                  def generate(http_method:, http_uri:, nonce: nil, access_token: nil, ath: nil)
                    validate_inputs!(http_method, http_uri)
                    ath = !access_token.nil? if ath.nil?

                    header = build_header
                    payload = build_payload(
                      http_method: http_method,
                      http_uri: http_uri,
                      nonce: nonce,
                      access_token: access_token,
                      include_ath: ath
                    )

                    key_manager.sign_segments(header, payload)
                  rescue StandardError => e
                    raise ProofError, "Failed to generate proof: #{e.message}"
                  end

                  private

                  # Rejects a blank method/URI and any URI that is not http(s).
                  def validate_inputs!(http_method, http_uri)
                    raise ProofError, "http_method is required" if http_method.nil? || http_method.empty?
                    raise ProofError, "http_uri is required" if http_uri.nil? || http_uri.empty?

                    uri = URI(http_uri)
                    raise ProofError, "invalid http_uri" unless uri.is_a?(URI::HTTP)
                  rescue URI::InvalidURIError => e
                    raise ProofError, "invalid http_uri: #{e.message}"
                  end

                  # JOSE header: DPoP type, ES256 algorithm, and the public key as a JWK.
                  def build_header
                    {
                      typ: "dpop+jwt",
                      alg: "ES256",
                      jwk: key_manager.public_jwk.to_h
                    }
                  end

                  # Claims set: unique jti, request method/URI, issue time, plus optional
                  # nonce and access-token hash.
                  def build_payload(http_method:, http_uri:, nonce: nil, access_token: nil, include_ath: nil)
                    payload = {
                      "jti" => SecureRandom.uuid,
                      "htm" => http_method.upcase,
                      "htu" => normalize_uri(http_uri),
                      "iat" => Time.now.to_i
                    }

                    # Add the nonce if provided
                    payload["nonce"] = nonce if nonce

                    # Add access token hash if needed
                    payload["ath"] = generate_access_token_hash(access_token) if access_token && include_ath

                    payload
                  end

                  # Builds the `htu` claim. RFC 9449 §4.2 defines htu as the target URI
                  # *without query and fragment parts*, so both are dropped; default ports
                  # are removed as well.
                  def normalize_uri(uri)
                    uri = URI(uri)
                    uri.port = nil if (uri.scheme == "https" && uri.port == 443) || (uri.scheme == "http" && uri.port == 80)
                    uri.query = nil
                    uri.fragment = nil
                    uri.to_s
                  end

                  # RFC 9449 §4.2: `ath` is the base64url (unpadded) encoding of the FULL
                  # SHA-256 digest of the access token. Unlike the OIDC `at_hash`, the
                  # digest is not truncated to its left half.
                  def generate_access_token_hash(access_token)
                    digest = OpenSSL::Digest::SHA256.digest(access_token)
                    Base64.urlsafe_encode64(digest, padding: false)
                  end
                end
              end
            end
         
     | 
| 
         @@ -0,0 +1,47 @@ 
     | 
|
| 
      
 1 
     | 
    
         
            +
            # frozen_string_literal: true
         
     | 
| 
      
 2 
     | 
    
         
            +
             
     | 
| 
      
 3 
     | 
    
         
            +
            module AtprotoAuth
              # Base class for AtprotoAuth errors.
              class Error < StandardError
              end

              # Base class for AT Protocol OAuth errors.
              # Pairs the human-readable message with a machine-readable +error_code+.
              class OAuthError < Error
                # @return [String] machine-readable error code (e.g. "invalid_client_metadata")
                attr_reader :error_code

                # @param message [String] human-readable description of the failure
                # @param error_code [String] machine-readable code exposed via #error_code
                def initialize(message, error_code)
                  # @type-ignore
                  super(message)
                  @error_code = error_code
                end
              end

              # Error raised when client metadata is invalid or cannot be retrieved.
              # This can occur during client metadata fetching, parsing, or validation.
              #
              # @example Handling client metadata errors
              #   begin
              #     client = AtprotoAuth::Client.new(client_id: "https://myapp.com/metadata.json")
              #   rescue AtprotoAuth::InvalidClientMetadata => e
              #     puts "Failed to validate client metadata: #{e.message}"
              #   end
              class InvalidClientMetadata < OAuthError
                # @param message [String] description of the metadata problem
                def initialize(message) = super(message, "invalid_client_metadata")
              end

              # Error raised when authorization server metadata is invalid or cannot be retrieved.
              # This includes issues with server metadata fetching, parsing, or validation against
              # the AT Protocol OAuth requirements.
              #
              # @example Handling authorization server errors
              #   begin
              #     server = AtprotoAuth::AuthorizationServer.new(issuer: "https://auth.example.com")
              #   rescue AtprotoAuth::InvalidAuthorizationServer => e
              #     puts "Failed to validate authorization server: #{e.message}"
              #   end
              class InvalidAuthorizationServer < OAuthError
                # @param message [String] description of the authorization server problem
                def initialize(message) = super(message, "invalid_authorization_server")
              end
            end
         
     | 
| 
         @@ -0,0 +1,227 @@ 
     | 
|
| 
      
 1 
     | 
    
         
            +
            # frozen_string_literal: true
         
     | 
| 
      
 2 
     | 
    
         
            +
             
     | 
| 
      
 3 
     | 
    
         
            +
            require "net/http"
         
     | 
| 
      
 4 
     | 
    
         
            +
            require "uri"
         
     | 
| 
      
 5 
     | 
    
         
            +
            require "ipaddr"
         
     | 
| 
      
 6 
     | 
    
         
            +
             
     | 
| 
      
 7 
     | 
    
         
            +
            module AtprotoAuth
              # A secure HTTP client for making OAuth-related requests.
              # Implements protections against SSRF attacks and enforces security headers.
              #
              # Before every connection, including each redirect hop:
              # - the destination host is resolved;
              # - every resolved address is checked against FORBIDDEN_IP_RANGES;
              # - the socket is pinned to the vetted address (Net::HTTP#ipaddr=), so a
              #   second DNS lookup cannot substitute a different host between the
              #   check and the connect.
              class HttpClient
                FORBIDDEN_IP_RANGES = [
                  IPAddr.new("0.0.0.0/8"),      # Current network
                  IPAddr.new("10.0.0.0/8"),     # Private network
                  IPAddr.new("127.0.0.0/8"),    # Loopback
                  IPAddr.new("169.254.0.0/16"), # Link-local
                  IPAddr.new("172.16.0.0/12"),  # Private network
                  IPAddr.new("192.168.0.0/16"), # Private network
                  IPAddr.new("::/128"),         # Unspecified address
                  IPAddr.new("::1/128"),        # Loopback
                  IPAddr.new("fc00::/7"),       # Unique local address
                  IPAddr.new("fe80::/10")       # Link-local address
                ].freeze

                ALLOWED_SCHEMES = ["https"].freeze
                DEFAULT_TIMEOUT = 10 # seconds
                MAX_REDIRECTS = 5
                MAX_RESPONSE_SIZE = 10 * 1024 * 1024 # 10MB

                # Error raised when a request is blocked due to SSRF protection
                class SSRFError < Error; end

                # Error raised when an HTTP request fails
                class HttpError < Error
                  # @return [Net::HTTPResponse, nil] the response that triggered the error, if any
                  attr_reader :response

                  def initialize(message, response)
                    @response = response
                    super(message)
                  end
                end

                RedirectHandlerOptions = Data.define(:original_uri, :method, :response, :headers, :redirect_count, :body)

                # @param timeout [Integer] Request timeout in seconds (used for both open and read timeouts)
                # @param verify_ssl [Boolean] Whether to verify SSL certificates
                def initialize(timeout: DEFAULT_TIMEOUT, verify_ssl: true)
                  @timeout = timeout
                  @verify_ssl = verify_ssl
                end

                # Makes a secure HTTP GET request
                # @param url [String] URL to request
                # @param headers [Hash] Additional headers to send
                # @return [Hash] Response with :status, :headers, and :body
                # @raise [SSRFError] If the request, or any redirect it follows, would be unsafe
                # @raise [HttpError] If the request fails
                def get(url, headers = {})
                  uri = validate_uri!(url)
                  # make_request validates (and pins) the destination IP itself.
                  response = make_request(uri, headers)
                  validate_response!(response)
                  to_result(response)
                end

                # Makes a secure HTTP POST request
                # @param url [String] URL to request
                # @param body [String, Hash, nil] Request body; a Hash is form-encoded
                # @param headers [Hash] Additional headers to send
                # @return [Hash] Response with :status, :headers, and :body
                # @raise [SSRFError] If the request, or any redirect it follows, would be unsafe
                # @raise [HttpError] If the request fails
                def post(url, body: nil, headers: {})
                  uri = validate_uri!(url)
                  # make_post_request validates (and pins) the destination IP itself.
                  response = make_post_request(uri, body, headers)
                  validate_response!(response)
                  to_result(response)
                end

                private

                # Converts a Net::HTTPResponse into the public result Hash.
                def to_result(response)
                  {
                    status: response.code.to_i,
                    headers: response.each_header.to_h,
                    body: response.body
                  }
                end

                # Parses +url+ and enforces scheme, host and no-fragment rules.
                # @return [URI::Generic]
                # @raise [SSRFError]
                def validate_uri!(url)
                  uri = URI(url)
                  unless ALLOWED_SCHEMES.include?(uri.scheme)
                    raise SSRFError, "URL scheme must be one of: #{ALLOWED_SCHEMES.join(", ")}"
                  end
                  # An empty host (e.g. "https:///path") is treated like a missing one.
                  raise SSRFError, "URL must include host" if uri.host.nil? || uri.host.empty?
                  raise SSRFError, "URL must not include fragment" if uri.fragment

                  uri
                end

                # Resolves the URI's host and rejects it if ANY resolved address is forbidden.
                # @return [IPAddr] the first resolved address, used to pin the connection
                # @raise [SSRFError] if resolution fails or an address is forbidden
                def validate_ip!(uri)
                  # #hostname strips the brackets from IPv6 literals ("[::1]" -> "::1").
                  addresses = resolve_ips(uri.hostname)
                  raise SSRFError, "Request to forbidden IP address" if addresses.any? { |ip| forbidden_ip?(ip) }

                  addresses.first
                end

                # @return [IPAddr] the first address +hostname+ resolves to
                def resolve_ip(hostname)
                  resolve_ips(hostname).first
                end

                # Resolves every address for +hostname+ so no single DNS answer can
                # slip a private address past the check.
                # @return [Array<IPAddr>] non-empty list of unique addresses
                # @raise [SSRFError] if the name cannot be resolved
                def resolve_ips(hostname)
                  addresses = Addrinfo.getaddrinfo(hostname, nil, nil, :STREAM)
                                      .map { |info| IPAddr.new(info.ip_address) }
                                      .uniq
                  raise SSRFError, "Failed to resolve hostname: #{hostname}" if addresses.empty?

                  addresses
                rescue SocketError, IPAddr::InvalidAddressError => e
                  raise SSRFError, "Failed to resolve hostname: #{e.message}"
                end

                # @return [Boolean] true when +ip+ falls inside FORBIDDEN_IP_RANGES
                def forbidden_ip?(ip)
                  # Unwrap IPv4-mapped IPv6 (e.g. ::ffff:127.0.0.1) so the IPv4 ranges apply.
                  candidate = ip.ipv4_mapped? ? ip.native : ip
                  FORBIDDEN_IP_RANGES.any? { |range| range.include?(candidate) }
                end

                # Validates the destination of +uri+ and returns a configured Net::HTTP
                # whose socket is pinned to the vetted IP address. TLS SNI and
                # certificate verification still use the hostname.
                def build_http(uri)
                  pinned_ip = validate_ip!(uri)
                  http = Net::HTTP.new(uri.hostname, uri.port)
                  http.ipaddr = pinned_ip.to_s
                  configure_http_client!(http)
                  http
                end

                # Runs the block and re-raises this client's own SSRFError/HttpError
                # unchanged; any other failure is wrapped in HttpError. Without the
                # first clause, SSRF blocks and "Too many redirects" raised during
                # redirect handling would be re-wrapped as "Request failed: ..." at
                # every recursion level.
                def translate_errors
                  yield
                rescue SSRFError, HttpError
                  raise
                rescue Net::OpenTimeout, Net::ReadTimeout => e
                  raise HttpError.new("Request timeout: #{e.message}", nil)
                rescue StandardError => e
                  raise HttpError.new("Request failed: #{e.message}", nil)
                end

                def make_request(uri, headers = {}, redirect_count = 0)
                  translate_errors do
                    http = build_http(uri)
                    request = Net::HTTP::Get.new(uri.request_uri)
                    add_security_headers!(request, headers)
                    response = perform_request(http, request)
                    handle_redirect(
                      original_uri: uri,
                      response: response,
                      headers: headers,
                      redirect_count: redirect_count
                    )
                  end
                end

                def make_post_request(uri, body, headers = {}, redirect_count = 0)
                  translate_errors do
                    http = build_http(uri)
                    request = Net::HTTP::Post.new(uri.request_uri)
                    add_security_headers!(request, headers)
                    request.body = body.is_a?(Hash) ? URI.encode_www_form(body) : body if body
                    response = perform_request(http, request)
                    handle_redirect(
                      original_uri: uri,
                      body: body,
                      method: :post,
                      response: response,
                      headers: headers,
                      redirect_count: redirect_count
                    )
                  end
                end

                # Sends +request+ and streams the body through read_limited_body!.
                # @return [Net::HTTPResponse]
                def perform_request(http, request)
                  http.request(request) { |response| read_limited_body!(response) }
                end

                # Reads the response body incrementally and aborts as soon as it exceeds
                # MAX_RESPONSE_SIZE, instead of buffering an unbounded payload first.
                # @raise [HttpError] when the declared or actual size is too large
                def read_limited_body!(response)
                  declared = response["content-length"]&.to_i
                  if declared && declared > MAX_RESPONSE_SIZE
                    raise HttpError.new("Response too large: #{declared} bytes", response)
                  end

                  buffer = +""
                  streamed = response.read_body do |chunk|
                    buffer << chunk
                    next if buffer.bytesize <= MAX_RESPONSE_SIZE

                    raise HttpError.new("Response too large: #{buffer.bytesize} bytes", response)
                  end
                  # read_body returns nil when the response has no body; leave body nil then.
                  response.body = buffer unless streamed.nil?
                end

                def configure_http_client!(http)
                  http.use_ssl = true
                  http.verify_mode = @verify_ssl ? OpenSSL::SSL::VERIFY_PEER : OpenSSL::SSL::VERIFY_NONE
                  http.read_timeout = @timeout
                  http.open_timeout = @timeout
                end

                def add_security_headers!(request, headers)
                  # Prevent caching of sensitive data
                  request["Cache-Control"] = "no-store"

                  # Add user-provided headers
                  headers.each { |k, v| request[k] = v }
                end

                # Handle HTTP redirects
                # kwargs can include:
                # - original_uri: URI of the original request
                # - method: HTTP method of the original request (:get or :post)
                # - response: Net::HTTPResponse object
                # - headers: Hash of headers from the original request
                # - redirect_count: Number of redirects so far
                # - body: Request body for POST requests
                #
                # NOTE(review): +headers+ (including any Authorization/DPoP values) are
                # forwarded unchanged to the redirect target, even across origins, and
                # POST is re-sent as POST for every 3xx. Confirm this is intended.
                def handle_redirect(**kwargs)
                  response = kwargs[:response]
                  redirect_count = kwargs[:redirect_count]

                  return response unless response.is_a?(Net::HTTPRedirection)
                  raise HttpError.new("Too many redirects", response) if redirect_count >= MAX_REDIRECTS

                  location_header = response["location"]
                  if location_header.nil? || location_header.empty?
                    raise HttpError.new("Redirect response is missing a Location header", response)
                  end

                  location = URI(location_header)
                  location = kwargs[:original_uri] + location if location.relative?
                  validate_uri!(location.to_s)
                  # The destination IP is validated and pinned by build_http in the next request.

                  if kwargs[:method] == :post
                    make_post_request(location, kwargs[:body], kwargs[:headers], redirect_count + 1)
                  else
                    make_request(location, kwargs[:headers], redirect_count + 1)
                  end
                end

                def validate_response!(response)
                  # check_success_status!(response)
                  check_content_length!(response)
                end

                def check_success_status!(response)
                  return if response.is_a?(Net::HTTPSuccess)

                  raise HttpError.new("HTTP request failed: #{response.code} #{response.message}", response)
                end

                def check_content_length!(response)
                  content_length = response["content-length"]&.to_i || response.body&.bytesize || 0
                  return unless content_length > MAX_RESPONSE_SIZE

                  raise HttpError.new("Response too large: #{content_length} bytes", response)
                end
              end
            end
         
     | 
| 
         @@ -0,0 +1,104 @@ 
     | 
|
| 
      
 1 
     | 
    
         
            +
            # frozen_string_literal: true
         
     | 
| 
      
 2 
     | 
    
         
            +
             
     | 
| 
      
 3 
     | 
    
         
            +
            module AtprotoAuth
              module Identity
                # Represents and validates a DID Document in the AT Protocol.
                #
                # DID Documents contain critical service information about user accounts, including:
                # - The Personal Data Server (PDS) hosting the account
                # - Associated handles for the account
                # - Key material for identity verification
                # - Service endpoints for various protocols
                #
                # This class handles both current and legacy DID document formats, providing
                # a consistent interface for accessing and validating document data.
                # Only did:plc identifiers are accepted.
                #
                # @example Creating a document from JSON
                #   data = {
                #     "id" => "did:plc:abc123",
                #     "alsoKnownAs" => ["at://alice.example.com"],
                #     "pds" => "https://pds.example.com"
                #   }
                #   doc = AtprotoAuth::Identity::Document.new(data)
                #
                #   puts doc.pds                    # => "https://pds.example.com"
                #   puts doc.has_handle?("alice.example.com")  # => true
                #   puts doc.has_handle?("@Alice.Example.com") # => true (case-insensitive)
                #
                # @example Handling legacy format
                #   legacy_data = {
                #     "id" => "did:plc:abc123",
                #     "service" => [{
                #       "id" => "#atproto_pds",
                #       "type" => "AtprotoPersonalDataServer",
                #       "serviceEndpoint" => "https://pds.example.com"
                #     }]
                #   }
                #   doc = AtprotoAuth::Identity::Document.new(legacy_data)
                #   puts doc.pds  # => "https://pds.example.com"
                class Document
                  # Every accepted DID must start with this prefix.
                  DID_PLC_PREFIX = "did:plc:"
                  # URI scheme prefix marking handle entries in alsoKnownAs.
                  HANDLE_URI_PREFIX = "at://"
                  # Service type that identifies the PDS entry in the services list.
                  PDS_SERVICE_TYPE = "AtprotoPersonalDataServer"
                  private_constant :DID_PLC_PREFIX, :HANDLE_URI_PREFIX, :PDS_SERVICE_TYPE

                  # did           - the document's "id"
                  # rotation_keys - "publicKeyMultibase" of each verificationMethod entry
                  #                 (nil for entries lacking one); [] when absent
                  # also_known_as - the "alsoKnownAs" list, or [] when absent
                  # services      - the "service" list, or [] when absent
                  # pds           - PDS endpoint URL (from "pds" or the PDS service entry)
                  attr_reader :did, :rotation_keys, :also_known_as, :services, :pds

                  # Creates a new Document from parsed JSON
                  # @param data [Hash] Parsed DID document data (string keys)
                  # @raise [DocumentError] if document is invalid or has no PDS location
                  def initialize(data)
                    validate_document!(data)

                    @did = data["id"]
                    # NOTE(review): these are verificationMethod public keys; confirm that the
                    # "rotation_keys" name is intended for them.
                    @rotation_keys = data["verificationMethod"]&.map { |m| m["publicKeyMultibase"] } || []
                    @also_known_as = data["alsoKnownAs"] || []
                    @services = data["service"] || []
                    # Must run after @services is set: the legacy fallback searches it.
                    @pds = extract_pds!(data)
                  end

                  # Checks if this document contains a specific handle.
                  #
                  # Handles are compared case-insensitively, since AT Protocol handles are
                  # not case-sensitive. Non-String arguments and non-String alsoKnownAs
                  # entries never match.
                  #
                  # @param handle [String] Handle to check (with or without @ prefix)
                  # @return [Boolean] true if "at://<handle>" is listed in alsoKnownAs
                  def has_handle?(handle) # rubocop:disable Naming/PredicateName
                    return false unless handle.is_a?(String)

                    normalized = handle.delete_prefix("@").downcase
                    @also_known_as.any? do |aka|
                      aka.is_a?(String) &&
                        aka.start_with?(HANDLE_URI_PREFIX) &&
                        aka.delete_prefix(HANDLE_URI_PREFIX).downcase == normalized
                    end
                  end

                  private

                  # Validates the overall shape of the document, its DID and its services.
                  # @raise [DocumentError] on any structural problem
                  def validate_document!(data)
                    raise DocumentError, "Document cannot be nil" if data.nil?
                    raise DocumentError, "Document must be a Hash" unless data.is_a?(Hash)
                    raise DocumentError, "Document must have id" unless data["id"]

                    validate_did!(data["id"])
                    validate_services!(data["service"])
                  end

                  # Accepts only String ids of the form "did:plc:<identifier>".
                  # The type guard makes a non-String id raise DocumentError instead of
                  # NoMethodError. The length check rejects a bare prefix.
                  # @raise [DocumentError] if the DID is not a non-empty did:plc identifier
                  def validate_did!(did)
                    return if did.is_a?(String) && did.start_with?(DID_PLC_PREFIX) && did.length > DID_PLC_PREFIX.length

                    raise DocumentError, "Invalid DID format (must be did:plc:): #{did}"
                  end

                  # Ensures services, when present, is an Array of Hashes that each carry
                  # "id", "type" and "serviceEndpoint".
                  # @raise [DocumentError] on a non-Array list or a malformed entry
                  def validate_services!(services) # rubocop:disable Metrics/CyclomaticComplexity
                    return if services.nil?
                    raise DocumentError, "services must be an array" unless services.is_a?(Array)

                    services.each do |svc|
                      unless svc.is_a?(Hash) && svc["id"] && svc["type"] && svc["serviceEndpoint"]
                        raise DocumentError, "Invalid service entry format"
                      end
                    end
                  end

                  # Returns the PDS endpoint: the top-level "pds" value when present,
                  # otherwise the serviceEndpoint of the first service typed as a PDS.
                  # @raise [DocumentError] if neither source provides a PDS location
                  def extract_pds!(data)
                    pds = data["pds"] # New format
                    return pds if pds

                    # Legacy format - look through services
                    service = @services.find { |s| s["type"] == PDS_SERVICE_TYPE }
                    raise DocumentError, "No PDS location found in document" unless service

                    service["serviceEndpoint"]
                  end
                end
              end
            end
         
     |