rager 0.2.1 → 0.3.0

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -15,12 +15,6 @@ module Rager
15
15
  api_key = options.api_key || ENV["REPLICATE_API_KEY"]
16
16
  raise Rager::Errors::MissingCredentialsError.new("Replicate", "REPLICATE_API_KEY") if api_key.nil?
17
17
 
18
- headers = {
19
- "Authorization" => "Bearer #{api_key}",
20
- "Content-Type" => "application/json",
21
- "Prefer" => "wait"
22
- }
23
-
24
18
  body = {
25
19
  version: options.version,
26
20
  input: {
@@ -31,29 +25,29 @@ module Rager
31
25
  save_gaussian_ply: true,
32
26
  ss_sampling_steps: 38
33
27
  }
34
- }
35
- body[:input][:seed] = options.seed unless options.seed.nil?
28
+ }.tap do |b|
29
+ b[:input][:seed] = options.seed if options.seed
30
+ end
36
31
 
37
32
  request = Rager::Http::Request.new(
38
33
  url: "https://api.replicate.com/v1/predictions",
39
34
  verb: Rager::Http::Verb::Post,
40
- headers: headers,
35
+ headers: {
36
+ "Authorization" => "Bearer #{api_key}",
37
+ "Content-Type" => "application/json"
38
+ },
41
39
  body: body.to_json
42
40
  )
43
41
 
44
42
  response = Rager.config.http_adapter.make_request(request)
45
43
  response_body = T.cast(T.must(response.body), String)
46
44
 
47
- raise Rager::Errors::HttpError.new(Rager.config.http_adapter, response.status, response_body) unless [200,
48
- 201, 202].include?(response.status)
45
+ raise Rager::Errors::HttpError.new(Rager.config.http_adapter, response.status, response_body) unless [200, 201, 202].include?(response.status)
49
46
 
50
- begin
51
- json = JSON.parse(response_body)
52
- model_file = json.dig("output", "model_file")
53
- model_file || json.fetch("urls").fetch("get")
54
- rescue JSON::ParserError, KeyError => e
55
- raise Rager::Errors::ParseError.new(e.message, response_body)
56
- end
47
+ json = JSON.parse(response_body)
48
+ json.fetch("urls").fetch("get")
49
+ rescue JSON::ParserError, KeyError => e
50
+ raise Rager::Errors::ParseError.new(e.message, response_body)
57
51
  end
58
52
  end
59
53
  end
@@ -9,7 +9,7 @@ module Rager
9
9
 
10
10
  sig do
11
11
  params(
12
- image_url: String,
12
+ image_url: Rager::Types::MeshGenInput,
13
13
  options: Rager::MeshGen::Options
14
14
  ).returns(Rager::Types::MeshGenOutput)
15
15
  end
data/lib/rager/options.rb CHANGED
@@ -8,13 +8,15 @@ module Rager
8
8
  extend T::Sig
9
9
  extend T::Helpers
10
10
  requires_ancestor { T::Struct }
11
- interface!
12
11
 
13
- sig { abstract.returns(T::Hash[String, T.untyped]) }
12
+ sig { returns(T::Hash[String, T.untyped]) }
14
13
  def serialize_safe
14
+ result = T.cast(self, T::Struct).serialize
15
+ result["api_key"] = "[REDACTED]" if result.key?("api_key")
16
+ result
15
17
  end
16
18
 
17
- sig { abstract.void }
19
+ sig { void }
18
20
  def validate
19
21
  end
20
22
  end
@@ -14,17 +14,6 @@ module Rager
14
14
  const :model, T.nilable(String)
15
15
  const :n, T.nilable(Integer)
16
16
  const :api_key, T.nilable(String)
17
-
18
- sig { override.returns(T::Hash[String, T.untyped]) }
19
- def serialize_safe
20
- result = serialize
21
- result["api_key"] = "[REDACTED]" if result.key?("api_key")
22
- result
23
- end
24
-
25
- sig { override.void }
26
- def validate
27
- end
28
17
  end
29
18
  end
30
19
  end
@@ -20,50 +20,36 @@ module Rager
20
20
  api_key = options.api_key || ENV["COHERE_API_KEY"]
21
21
  raise Rager::Errors::MissingCredentialsError.new("Cohere", "COHERE_API_KEY") if api_key.nil?
22
22
 
23
- url = options.url || ENV["COHERE_URL"] || "https://api.cohere.com/v2/rerank"
24
- model = options.model || "rerank-v3.5"
25
-
26
- headers = {
27
- "Content-Type" => "application/json"
28
- }
29
- headers["Authorization"] = "Bearer #{api_key}" if api_key
30
-
31
23
  body = {
32
- model: model,
24
+ model: options.model || "rerank-v3.5",
33
25
  query: query.query,
34
26
  documents: query.documents
35
- }
36
- body[:top_n] = options.n if options.n
27
+ }.tap do |b|
28
+ b[:top_n] = options.n if options.n
29
+ end
30
+
31
+ headers = {"Content-Type" => "application/json"}
32
+ headers["Authorization"] = "Bearer #{api_key}" if api_key
37
33
 
38
34
  request = Rager::Http::Request.new(
39
35
  verb: Rager::Http::Verb::Post,
40
- url: url,
36
+ url: options.url || ENV["COHERE_URL"] || "https://api.cohere.com/v2/rerank",
41
37
  headers: headers,
42
38
  body: body.to_json
43
39
  )
44
40
 
45
- http_adapter = Rager.config.http_adapter
46
- response = http_adapter.make_request(request)
41
+ response = Rager.config.http_adapter.make_request(request)
47
42
  response_body = T.cast(T.must(response.body), String)
48
43
 
49
- if response.status != 200
50
- raise Rager::Errors::HttpError.new(
51
- http_adapter,
52
- response.status,
53
- response_body
54
- )
55
- end
44
+ raise Rager::Errors::HttpError.new(Rager.config.http_adapter, response.status, response_body) if response.status != 200
56
45
 
57
46
  parsed_response = JSON.parse(response_body)
58
-
59
47
  results = parsed_response["results"] || []
60
- results.map do |result|
61
- index = result["index"]
62
- score = result["relevance_score"]
63
48
 
49
+ results.map do |result|
64
50
  Rager::Rerank::Result.new(
65
- score: score,
66
- index: index
51
+ score: result["relevance_score"],
52
+ index: result["index"]
67
53
  )
68
54
  end
69
55
  end
data/lib/rager/rerank.rb CHANGED
@@ -9,7 +9,7 @@ module Rager
9
9
 
10
10
  sig do
11
11
  params(
12
- query: Rager::Rerank::Query,
12
+ query: Rager::Types::RerankInput,
13
13
  options: Rager::Rerank::Options
14
14
  ).returns(Rager::Types::RerankOutput)
15
15
  end
data/lib/rager/result.rb CHANGED
@@ -2,6 +2,7 @@
2
2
  # frozen_string_literal: true
3
3
 
4
4
  require "json"
5
+ require "logger"
5
6
  require "securerandom"
6
7
  require "sorbet-runtime"
7
8
 
@@ -12,6 +13,9 @@ module Rager
12
13
  sig { returns(String) }
13
14
  attr_reader :id
14
15
 
16
+ sig { returns(String) }
17
+ attr_reader :context_id
18
+
15
19
  sig { returns(Rager::Operation) }
16
20
  attr_reader :operation
17
21
 
@@ -21,6 +25,9 @@ module Rager
21
25
  sig { returns(Rager::Types::Input) }
22
26
  attr_reader :input
23
27
 
28
+ sig { returns(T.nilable(Rager::Types::Output)) }
29
+ attr_reader :output
30
+
24
31
  sig { returns(Integer) }
25
32
  attr_reader :start_time
26
33
 
@@ -28,57 +35,62 @@ module Rager
28
35
  attr_reader :end_time
29
36
 
30
37
  sig { returns(T.nilable(String)) }
31
- attr_reader :result_id
38
+ attr_reader :name
32
39
 
33
- sig { returns(T.nilable(String)) }
34
- attr_reader :context_id
40
+ sig { returns(T.nilable(T::Array[String])) }
41
+ attr_reader :iids
35
42
 
36
43
  sig { returns(T.nilable(String)) }
37
- attr_reader :hash
38
-
39
- sig { returns(T.nilable(T::Array[String])) }
40
- attr_reader :input_ids
44
+ attr_reader :error
41
45
 
42
46
  sig do
43
47
  params(
44
48
  id: String,
49
+ context_id: String,
45
50
  operation: Rager::Operation,
46
- options: Rager::Options,
47
51
  input: Rager::Types::Input,
48
- output: Rager::Types::Output,
52
+ output: T.nilable(Rager::Types::Output),
53
+ options: Rager::Options,
49
54
  start_time: Integer,
50
55
  end_time: Integer,
51
- context_id: T.nilable(String),
52
- hash: T.nilable(String),
53
- input_ids: T.nilable(T::Array[String])
56
+ name: T.nilable(String),
57
+ iids: T.nilable(T::Array[String]),
58
+ error: T.nilable(String)
54
59
  ).void
55
60
  end
56
61
  def initialize(
57
62
  id:,
63
+ context_id:,
58
64
  operation:,
59
- options:,
60
65
  input:,
61
66
  output:,
67
+ options:,
62
68
  start_time:,
63
69
  end_time:,
64
- context_id: nil,
65
- hash: nil,
66
- input_ids: nil
70
+ name: nil,
71
+ iids: nil,
72
+ error: nil
67
73
  )
68
74
  @id = id
75
+ @context_id = context_id
69
76
  @operation = operation
70
- @options = options
71
77
  @input = input
72
78
  @output = output
79
+ @options = options
73
80
  @start_time = start_time
74
81
  @end_time = end_time
82
+ @name = T.let(name, T.nilable(String))
83
+ @iids = T.let(iids, T.nilable(T::Array[String]))
84
+ @error = T.let(error, T.nilable(String))
85
+
75
86
  @stream = T.let(nil, T.nilable(Rager::Types::Stream))
76
87
  @buffer = T.let([], Rager::Types::Buffer)
77
88
  @consumed = T.let(false, T::Boolean)
78
- @result_id = T.let(SecureRandom.uuid, T.nilable(String))
79
- @context_id = T.let(context_id, T.nilable(String))
80
- @hash = T.let(hash, T.nilable(String))
81
- @input_ids = T.let(input_ids, T.nilable(T::Array[String]))
89
+ end
90
+
91
+ sig { returns(T::Boolean) }
92
+ def success?
93
+ @error.nil?
82
94
  end
83
95
 
84
96
  sig { returns(T::Boolean) }
@@ -88,11 +100,9 @@ module Rager
88
100
 
89
101
  sig { returns(Rager::Types::Output) }
90
102
  def out
91
- return @output unless stream?
103
+ return T.must(@output) unless stream?
92
104
  return @buffer.each if @consumed
93
105
 
94
- log
95
-
96
106
  @stream = Enumerator.new do |yielder|
97
107
  T.cast(@output, Rager::Types::Stream)
98
108
  .each { |message_delta|
@@ -157,49 +167,41 @@ module Rager
157
167
  end
158
168
  end
159
169
 
160
- sig { returns(Rager::Types::NonStreamOutput) }
170
+ sig { returns(T.nilable(Rager::Types::NonStreamOutput)) }
161
171
  def serialize_output
162
- if @consumed
163
- mat
164
- elsif stream?
165
- "[STREAM]"
166
- else
167
- T.cast(@output, Rager::Types::NonStreamOutput)
168
- end
172
+ return nil unless success?
173
+ return @consumed ? mat : "[STREAM]" if stream?
174
+ T.cast(@output, Rager::Types::NonStreamOutput)
169
175
  end
170
176
 
171
177
  sig { returns(T::Hash[String, T.untyped]) }
172
178
  def to_h
173
179
  {
174
180
  id: @id,
181
+ context_id: @context_id,
175
182
  operation: @operation.serialize,
176
- options: @options.serialize_safe,
177
183
  input: serialize_input,
178
- output: if @consumed
179
- mat
180
- elsif stream?
181
- "[STREAM]"
182
- else
183
- @output
184
- end,
184
+ output: serialize_output,
185
+ options: @options.serialize_safe,
185
186
  start_time: @start_time,
186
187
  end_time: @end_time,
187
- result_id: @result_id,
188
- context_id: @context_id,
189
- hash: @hash,
190
- input_ids: @input_ids
188
+ name: @name,
189
+ iids: @iids,
190
+ error: @error
191
191
  }
192
192
  end
193
193
 
194
194
  sig { void }
195
195
  def log
196
- return unless Rager.config.logger
196
+ return unless Rager.config.logger_type
197
197
 
198
198
  json = to_h.to_json
199
199
 
200
- case Rager.config.logger
200
+ logger = Rager.config.logger
201
+
202
+ case Rager.config.logger_type
201
203
  when Rager::Logger::Stdout
202
- puts "\nLOG: #{json}"
204
+ success? ? logger.info(json) : logger.error(json)
203
205
  when Rager::Logger::Remote
204
206
  http_adapter = Rager.config.http_adapter
205
207
  url = Rager.config.url
@@ -220,9 +222,7 @@ module Rager
220
222
 
221
223
  response = http_adapter.make_request(request)
222
224
 
223
- unless response.status >= 200 && response.status < 300
224
- warn "Remote log failed: \\#{response.status} \\#{response.body}"
225
- end
225
+ logger.warn "Remote log failed: \\#{response.status} \\#{response.body}" unless response.success?
226
226
  end
227
227
  end
228
228
  end
@@ -12,17 +12,6 @@ module Rager
12
12
  const :provider, String, default: "brave"
13
13
  const :n, T.nilable(Integer)
14
14
  const :api_key, T.nilable(String)
15
-
16
- sig { override.returns(T::Hash[String, T.untyped]) }
17
- def serialize_safe
18
- result = serialize
19
- result["api_key"] = "[REDACTED]" if result.key?("api_key")
20
- result
21
- end
22
-
23
- sig { override.void }
24
- def validate
25
- end
26
15
  end
27
16
  end
28
17
  end
@@ -21,54 +21,37 @@ module Rager
21
21
  api_key = options.api_key || ENV["BRAVE_API_KEY"]
22
22
  raise Rager::Errors::MissingCredentialsError.new("Brave", "BRAVE_API_KEY") if api_key.nil?
23
23
 
24
- params = {"q" => query}
25
- params["count"] = options.n.to_s if options.n
26
-
27
- query_string = URI.encode_www_form(params)
28
- url = "https://api.search.brave.com/res/v1/web/search?#{query_string}"
29
-
30
- headers = {
31
- "Accept" => "application/json",
32
- "Accept-Encoding" => "gzip",
33
- "x-subscription-token" => api_key
34
- }
24
+ params = {"q" => query}.tap do |p|
25
+ p["count"] = options.n.to_s if options.n
26
+ end
35
27
 
36
28
  request = Rager::Http::Request.new(
37
29
  verb: Rager::Http::Verb::Get,
38
- url: url,
39
- headers: headers
30
+ url: "https://api.search.brave.com/res/v1/web/search?#{URI.encode_www_form(params)}",
31
+ headers: {
32
+ "Accept" => "application/json",
33
+ "Accept-Encoding" => "gzip",
34
+ "x-subscription-token" => api_key
35
+ }
40
36
  )
41
37
 
42
- http_adapter = Rager.config.http_adapter
43
- response = http_adapter.make_request(request)
38
+ response = Rager.config.http_adapter.make_request(request)
44
39
  response_body = T.cast(T.must(response.body), String)
45
40
 
46
- if response.status != 200
47
- raise Rager::Errors::HttpError.new(
48
- http_adapter,
49
- response.status,
50
- response_body
51
- )
52
- end
41
+ raise Rager::Errors::HttpError.new(Rager.config.http_adapter, response.status, response_body) if response.status != 200
53
42
 
54
43
  parsed_response = JSON.parse(response_body)
44
+ web_results = parsed_response.dig("web", "results") || []
55
45
 
56
- results = []
57
-
58
- if parsed_response["web"] && parsed_response["web"]["results"]
59
- web_results = parsed_response["web"]["results"]
60
- web_results.each do |result|
61
- next unless result["title"] && result["url"] && result["description"]
46
+ web_results.filter_map do |result|
47
+ next unless result["title"] && result["url"] && result["description"]
62
48
 
63
- results << Rager::Search::Result.new(
64
- title: result["title"],
65
- url: result["url"],
66
- description: result["description"]
67
- )
68
- end
49
+ Rager::Search::Result.new(
50
+ title: result["title"],
51
+ url: result["url"],
52
+ description: result["description"]
53
+ )
69
54
  end
70
-
71
- results
72
55
  end
73
56
  end
74
57
  end
data/lib/rager/search.rb CHANGED
@@ -9,7 +9,7 @@ module Rager
9
9
 
10
10
  sig do
11
11
  params(
12
- query: String,
12
+ query: Rager::Types::SearchInput,
13
13
  options: Rager::Search::Options
14
14
  ).returns(Rager::Types::SearchOutput)
15
15
  end
@@ -8,15 +8,6 @@ module Rager
8
8
  extend T::Sig
9
9
 
10
10
  const :provider, String, default: "erb"
11
-
12
- sig { override.returns(T::Hash[String, T.untyped]) }
13
- def serialize_safe
14
- serialize.transform_keys(&:to_s)
15
- end
16
-
17
- sig { override.void }
18
- def validate
19
- end
20
11
  end
21
12
  end
22
13
  end
@@ -9,7 +9,7 @@ module Rager
9
9
 
10
10
  sig do
11
11
  params(
12
- input: Rager::Template::Input,
12
+ input: Rager::Types::TemplateInput,
13
13
  options: Rager::Template::Options
14
14
  ).returns(Rager::Types::TemplateOutput)
15
15
  end
data/lib/rager/types.rb CHANGED
@@ -52,7 +52,6 @@ module Rager
52
52
  }
53
53
 
54
54
  ChatNonStreamOutput = T.type_alias { ChatNonStream }
55
-
56
55
  NonStreamOutput = T.type_alias {
57
56
  T.any(
58
57
  ChatNonStreamOutput,
@@ -0,0 +1,39 @@
1
+ # typed: strict
2
+ # frozen_string_literal: true
3
+
4
+ require "sorbet-runtime"
5
+ require "uri"
6
+
7
+ module Rager
8
+ module Utils
9
+ module Http
10
+ extend T::Sig
11
+
12
+ sig { params(url: String, path: T.nilable(String), headers: T.nilable(T::Hash[String, String])).returns(String) }
13
+ def self.download_file(url, path = nil, headers = nil)
14
+ path ||= begin
15
+ uri_path = URI.parse(url).path
16
+ if uri_path.nil? || uri_path.empty?
17
+ "output"
18
+ else
19
+ filename = File.basename(uri_path)
20
+ filename.empty? ? "output" : filename
21
+ end
22
+ end
23
+
24
+ response = Rager.config.http_adapter.make_request(
25
+ Rager::Http::Request.new(url: url, headers: headers || {})
26
+ )
27
+
28
+ raise Rager::Errors::HttpError.new(Rager.config.http_adapter, response.status, T.cast(response.body, T.nilable(String))) unless response.success?
29
+
30
+ File.binwrite(path, T.cast(T.must(response.body), String))
31
+ path
32
+ rescue Rager::Errors::HttpError
33
+ raise
34
+ rescue => e
35
+ raise Rager::Errors::HttpError.new(Rager.config.http_adapter, 500, "Download failed: #{e.message}")
36
+ end
37
+ end
38
+ end
39
+ end
@@ -0,0 +1,72 @@
1
+ # typed: strict
2
+ # frozen_string_literal: true
3
+
4
+ require "sorbet-runtime"
5
+ require "json"
6
+
7
+ module Rager
8
+ module Utils
9
+ module Replicate
10
+ extend T::Sig
11
+
12
+ sig { params(prediction_url: String, key: T.nilable(String), path: T.nilable(String)).returns(T.nilable(String)) }
13
+ def self.download_prediction(prediction_url, key = nil, path = nil)
14
+ api_key = ENV["REPLICATE_API_KEY"]
15
+ raise Rager::Errors::MissingCredentialsError.new("Replicate", "REPLICATE_API_KEY") if api_key.nil?
16
+
17
+ response = Rager.config.http_adapter.make_request(
18
+ Rager::Http::Request.new(
19
+ url: prediction_url,
20
+ headers: {"Authorization" => "Bearer #{api_key}", "Content-Type" => "application/json"}
21
+ )
22
+ )
23
+
24
+ raise Rager::Errors::HttpError.new(Rager.config.http_adapter, response.status, T.cast(response.body, T.nilable(String))) unless response.success?
25
+
26
+ data = JSON.parse(T.cast(T.must(response.body), String))
27
+ return nil if ["starting", "processing"].include?(data["status"])
28
+
29
+ if data["status"] == "failed"
30
+ error_msg = data["error"] || "Prediction failed"
31
+ raise Rager::Errors::HttpError.new(Rager.config.http_adapter, 422, "Prediction failed: #{error_msg}")
32
+ end
33
+
34
+ return nil unless data["status"] == "succeeded"
35
+
36
+ output = data["output"]
37
+ return nil if output.nil?
38
+
39
+ download_url = if key && output.is_a?(Hash) && output[key]
40
+ output[key]
41
+ elsif output.is_a?(Hash)
42
+ output.values.compact.first
43
+ elsif output.is_a?(Array) && !output.empty?
44
+ output.first
45
+ else
46
+ output.to_s
47
+ end
48
+
49
+ return nil if download_url.nil? || download_url.empty?
50
+
51
+ Rager::Utils::Http.download_file(download_url, path)
52
+ rescue JSON::ParserError => e
53
+ raise Rager::Errors::ParseError.new("Failed to parse prediction response", e.message)
54
+ rescue Rager::Errors::HttpError, Rager::Errors::ParseError, Rager::Errors::MissingCredentialsError
55
+ raise
56
+ rescue
57
+ nil
58
+ end
59
+
60
+ sig { params(prediction_url: String, key: T.nilable(String), path: T.nilable(String), max_attempts: Integer, sleep_interval: Integer).returns(T.nilable(String)) }
61
+ def self.poll_prediction(prediction_url, key: nil, path: nil, max_attempts: 30, sleep_interval: 10)
62
+ max_attempts.times do
63
+ result = download_prediction(prediction_url, key, path)
64
+ return result unless result.nil?
65
+ sleep(sleep_interval)
66
+ end
67
+
68
+ raise Rager::Errors::HttpError.new(Rager.config.http_adapter, 408, "Prediction polling timed out after #{max_attempts} attempts")
69
+ end
70
+ end
71
+ end
72
+ end
data/lib/rager/version.rb CHANGED
@@ -2,5 +2,5 @@
2
2
  # frozen_string_literal: true
3
3
 
4
4
  module Rager
5
- VERSION = "0.2.1"
5
+ VERSION = "0.3.0"
6
6
  end
metadata CHANGED
@@ -1,7 +1,7 @@
1
1
  --- !ruby/object:Gem::Specification
2
2
  name: rager
3
3
  version: !ruby/object:Gem::Version
4
- version: 0.2.1
4
+ version: 0.3.0
5
5
  platform: ruby
6
6
  authors:
7
7
  - mvkvc
@@ -106,6 +106,7 @@ files:
106
106
  - lib/rager/http/verb.rb
107
107
  - lib/rager/image_gen.rb
108
108
  - lib/rager/image_gen/options.rb
109
+ - lib/rager/image_gen/output_format.rb
109
110
  - lib/rager/image_gen/providers/abstract.rb
110
111
  - lib/rager/image_gen/providers/replicate.rb
111
112
  - lib/rager/logger.rb
@@ -133,8 +134,10 @@ files:
133
134
  - lib/rager/template/providers/abstract.rb
134
135
  - lib/rager/template/providers/erb.rb
135
136
  - lib/rager/types.rb
137
+ - lib/rager/utils/http.rb
138
+ - lib/rager/utils/replicate.rb
136
139
  - lib/rager/version.rb
137
- homepage: https://github.com/mvkvc/rager_ruby
140
+ homepage: https://github.com/mvkvc/rager_rb
138
141
  licenses:
139
142
  - MIT
140
143
  metadata: {}