rager 0.2.1 → 0.4.0

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -12,42 +12,38 @@ module Rager
12
12
 
13
13
  sig { override.params(prompt: String, options: Rager::ImageGen::Options).returns(Rager::Types::ImageGenOutput) }
14
14
  def image_gen(prompt, options)
15
- url = "https://api.replicate.com/v1/models/#{options.model}/predictions"
16
15
  api_key = options.api_key || ENV["REPLICATE_API_KEY"]
17
16
  raise Rager::Errors::MissingCredentialsError.new("Replicate", "REPLICATE_API_KEY") if api_key.nil?
18
17
 
19
- headers = {
20
- "Authorization" => "Bearer #{api_key}",
21
- "Content-Type" => "application/json",
22
- "Prefer" => "wait"
23
- }
24
-
25
18
  body = {
26
19
  input: {
27
20
  prompt: prompt
28
21
  }
29
- }
30
- body[:input][:seed] = options.seed unless options.seed.nil?
22
+ }.tap do |b|
23
+ b[:input][:output_format] = T.must(options.output_format).serialize if options.output_format
24
+ b[:input][:seed] = options.seed if options.seed
25
+ end
31
26
 
32
27
  request = Rager::Http::Request.new(
33
- url: url,
28
+ url: "https://api.replicate.com/v1/models/#{options.model}/predictions",
34
29
  verb: Rager::Http::Verb::Post,
35
- headers: headers,
30
+ headers: {
31
+ "Authorization" => "Bearer #{api_key}",
32
+ "Content-Type" => "application/json",
33
+ "Prefer" => "wait"
34
+ },
36
35
  body: body.to_json
37
36
  )
38
37
 
39
- http_adapter = Rager.config.http_adapter
40
- response = http_adapter.make_request(request)
38
+ response = Rager.config.http_adapter.make_request(request)
41
39
  response_body = T.cast(T.must(response.body), String)
42
40
 
43
- raise Rager::Errors::HttpError.new(http_adapter, response.status, response_body) unless [200, 201].include?(response.status)
41
+ raise Rager::Errors::HttpError.new(Rager.config.http_adapter, response.status, response_body) unless [200, 201].include?(response.status)
44
42
 
45
- begin
46
- parsed = JSON.parse(response_body)
47
- parsed.fetch("output").first
48
- rescue JSON::ParserError, KeyError => e
49
- raise Rager::Errors::ParseError.new(e.message, response_body)
50
- end
43
+ parsed = JSON.parse(response_body)
44
+ parsed.fetch("output").first
45
+ rescue JSON::ParserError, KeyError => e
46
+ raise Rager::Errors::ParseError.new(e.message, response_body)
51
47
  end
52
48
  end
53
49
  end
@@ -9,7 +9,7 @@ module Rager
9
9
 
10
10
  sig do
11
11
  params(
12
- prompt: String,
12
+ prompt: Rager::Types::ImageGenInput,
13
13
  options: Rager::ImageGen::Options
14
14
  ).returns(Rager::Types::ImageGenOutput)
15
15
  end
@@ -14,17 +14,6 @@ module Rager
14
14
  const :version, String, default: "4876f2a8da1c544772dffa32e8889da4a1bab3a1f5c1937bfcfccb99ae347251"
15
15
  const :api_key, T.nilable(String)
16
16
  const :seed, T.nilable(Integer)
17
-
18
- sig { override.returns(T::Hash[String, T.untyped]) }
19
- def serialize_safe
20
- result = serialize
21
- result["api_key"] = "[REDACTED]" if result.key?("api_key")
22
- result
23
- end
24
-
25
- sig { override.void }
26
- def validate
27
- end
28
17
  end
29
18
  end
30
19
  end
@@ -15,12 +15,6 @@ module Rager
15
15
  api_key = options.api_key || ENV["REPLICATE_API_KEY"]
16
16
  raise Rager::Errors::MissingCredentialsError.new("Replicate", "REPLICATE_API_KEY") if api_key.nil?
17
17
 
18
- headers = {
19
- "Authorization" => "Bearer #{api_key}",
20
- "Content-Type" => "application/json",
21
- "Prefer" => "wait"
22
- }
23
-
24
18
  body = {
25
19
  version: options.version,
26
20
  input: {
@@ -31,29 +25,29 @@ module Rager
31
25
  save_gaussian_ply: true,
32
26
  ss_sampling_steps: 38
33
27
  }
34
- }
35
- body[:input][:seed] = options.seed unless options.seed.nil?
28
+ }.tap do |b|
29
+ b[:input][:seed] = options.seed if options.seed
30
+ end
36
31
 
37
32
  request = Rager::Http::Request.new(
38
33
  url: "https://api.replicate.com/v1/predictions",
39
34
  verb: Rager::Http::Verb::Post,
40
- headers: headers,
35
+ headers: {
36
+ "Authorization" => "Bearer #{api_key}",
37
+ "Content-Type" => "application/json"
38
+ },
41
39
  body: body.to_json
42
40
  )
43
41
 
44
42
  response = Rager.config.http_adapter.make_request(request)
45
43
  response_body = T.cast(T.must(response.body), String)
46
44
 
47
- raise Rager::Errors::HttpError.new(Rager.config.http_adapter, response.status, response_body) unless [200,
48
- 201, 202].include?(response.status)
45
+ raise Rager::Errors::HttpError.new(Rager.config.http_adapter, response.status, response_body) unless [200, 201, 202].include?(response.status)
49
46
 
50
- begin
51
- json = JSON.parse(response_body)
52
- model_file = json.dig("output", "model_file")
53
- model_file || json.fetch("urls").fetch("get")
54
- rescue JSON::ParserError, KeyError => e
55
- raise Rager::Errors::ParseError.new(e.message, response_body)
56
- end
47
+ json = JSON.parse(response_body)
48
+ json.fetch("urls").fetch("get")
49
+ rescue JSON::ParserError, KeyError => e
50
+ raise Rager::Errors::ParseError.new(e.message, response_body)
57
51
  end
58
52
  end
59
53
  end
@@ -9,7 +9,7 @@ module Rager
9
9
 
10
10
  sig do
11
11
  params(
12
- image_url: String,
12
+ image_url: Rager::Types::MeshGenInput,
13
13
  options: Rager::MeshGen::Options
14
14
  ).returns(Rager::Types::MeshGenOutput)
15
15
  end
data/lib/rager/options.rb CHANGED
@@ -8,13 +8,15 @@ module Rager
8
8
  extend T::Sig
9
9
  extend T::Helpers
10
10
  requires_ancestor { T::Struct }
11
- interface!
12
11
 
13
- sig { abstract.returns(T::Hash[String, T.untyped]) }
12
+ sig { returns(T::Hash[String, T.untyped]) }
14
13
  def serialize_safe
14
+ result = T.cast(self, T::Struct).serialize
15
+ result["api_key"] = "[REDACTED]" if result.key?("api_key")
16
+ result
15
17
  end
16
18
 
17
- sig { abstract.void }
19
+ sig { void }
18
20
  def validate
19
21
  end
20
22
  end
@@ -14,17 +14,6 @@ module Rager
14
14
  const :model, T.nilable(String)
15
15
  const :n, T.nilable(Integer)
16
16
  const :api_key, T.nilable(String)
17
-
18
- sig { override.returns(T::Hash[String, T.untyped]) }
19
- def serialize_safe
20
- result = serialize
21
- result["api_key"] = "[REDACTED]" if result.key?("api_key")
22
- result
23
- end
24
-
25
- sig { override.void }
26
- def validate
27
- end
28
17
  end
29
18
  end
30
19
  end
@@ -20,50 +20,36 @@ module Rager
20
20
  api_key = options.api_key || ENV["COHERE_API_KEY"]
21
21
  raise Rager::Errors::MissingCredentialsError.new("Cohere", "COHERE_API_KEY") if api_key.nil?
22
22
 
23
- url = options.url || ENV["COHERE_URL"] || "https://api.cohere.com/v2/rerank"
24
- model = options.model || "rerank-v3.5"
25
-
26
- headers = {
27
- "Content-Type" => "application/json"
28
- }
29
- headers["Authorization"] = "Bearer #{api_key}" if api_key
30
-
31
23
  body = {
32
- model: model,
24
+ model: options.model || "rerank-v3.5",
33
25
  query: query.query,
34
26
  documents: query.documents
35
- }
36
- body[:top_n] = options.n if options.n
27
+ }.tap do |b|
28
+ b[:top_n] = options.n if options.n
29
+ end
30
+
31
+ headers = {"Content-Type" => "application/json"}
32
+ headers["Authorization"] = "Bearer #{api_key}" if api_key
37
33
 
38
34
  request = Rager::Http::Request.new(
39
35
  verb: Rager::Http::Verb::Post,
40
- url: url,
36
+ url: options.url || ENV["COHERE_URL"] || "https://api.cohere.com/v2/rerank",
41
37
  headers: headers,
42
38
  body: body.to_json
43
39
  )
44
40
 
45
- http_adapter = Rager.config.http_adapter
46
- response = http_adapter.make_request(request)
41
+ response = Rager.config.http_adapter.make_request(request)
47
42
  response_body = T.cast(T.must(response.body), String)
48
43
 
49
- if response.status != 200
50
- raise Rager::Errors::HttpError.new(
51
- http_adapter,
52
- response.status,
53
- response_body
54
- )
55
- end
44
+ raise Rager::Errors::HttpError.new(Rager.config.http_adapter, response.status, response_body) if response.status != 200
56
45
 
57
46
  parsed_response = JSON.parse(response_body)
58
-
59
47
  results = parsed_response["results"] || []
60
- results.map do |result|
61
- index = result["index"]
62
- score = result["relevance_score"]
63
48
 
49
+ results.map do |result|
64
50
  Rager::Rerank::Result.new(
65
- score: score,
66
- index: index
51
+ score: result["relevance_score"],
52
+ index: result["index"]
67
53
  )
68
54
  end
69
55
  end
data/lib/rager/rerank.rb CHANGED
@@ -9,7 +9,7 @@ module Rager
9
9
 
10
10
  sig do
11
11
  params(
12
- query: Rager::Rerank::Query,
12
+ query: Rager::Types::RerankInput,
13
13
  options: Rager::Rerank::Options
14
14
  ).returns(Rager::Types::RerankOutput)
15
15
  end
data/lib/rager/result.rb CHANGED
@@ -2,6 +2,7 @@
2
2
  # frozen_string_literal: true
3
3
 
4
4
  require "json"
5
+ require "logger"
5
6
  require "securerandom"
6
7
  require "sorbet-runtime"
7
8
 
@@ -12,15 +13,21 @@ module Rager
12
13
  sig { returns(String) }
13
14
  attr_reader :id
14
15
 
16
+ sig { returns(String) }
17
+ attr_reader :context_id
18
+
15
19
  sig { returns(Rager::Operation) }
16
20
  attr_reader :operation
17
21
 
18
- sig { returns(Rager::Options) }
19
- attr_reader :options
20
-
21
22
  sig { returns(Rager::Types::Input) }
22
23
  attr_reader :input
23
24
 
25
+ sig { returns(T.nilable(Rager::Types::Output)) }
26
+ attr_reader :output
27
+
28
+ sig { returns(Rager::Options) }
29
+ attr_reader :options
30
+
24
31
  sig { returns(Integer) }
25
32
  attr_reader :start_time
26
33
 
@@ -28,57 +35,68 @@ module Rager
28
35
  attr_reader :end_time
29
36
 
30
37
  sig { returns(T.nilable(String)) }
31
- attr_reader :result_id
32
-
33
- sig { returns(T.nilable(String)) }
34
- attr_reader :context_id
38
+ attr_reader :name
35
39
 
36
40
  sig { returns(T.nilable(String)) }
37
- attr_reader :hash
41
+ attr_reader :context_name
38
42
 
39
43
  sig { returns(T.nilable(T::Array[String])) }
40
- attr_reader :input_ids
44
+ attr_reader :iids
45
+
46
+ sig { returns(T.nilable(String)) }
47
+ attr_reader :error
41
48
 
42
49
  sig do
43
50
  params(
44
51
  id: String,
52
+ context_id: String,
45
53
  operation: Rager::Operation,
46
- options: Rager::Options,
47
54
  input: Rager::Types::Input,
48
- output: Rager::Types::Output,
55
+ output: T.nilable(Rager::Types::Output),
56
+ options: Rager::Options,
49
57
  start_time: Integer,
50
58
  end_time: Integer,
51
- context_id: T.nilable(String),
52
- hash: T.nilable(String),
53
- input_ids: T.nilable(T::Array[String])
59
+ name: T.nilable(String),
60
+ context_name: T.nilable(String),
61
+ iids: T.nilable(T::Array[String]),
62
+ error: T.nilable(String)
54
63
  ).void
55
64
  end
56
65
  def initialize(
57
66
  id:,
67
+ context_id:,
58
68
  operation:,
59
- options:,
60
69
  input:,
61
70
  output:,
71
+ options:,
62
72
  start_time:,
63
73
  end_time:,
64
- context_id: nil,
65
- hash: nil,
66
- input_ids: nil
74
+ name: nil,
75
+ context_name: nil,
76
+ iids: nil,
77
+ error: nil
67
78
  )
68
79
  @id = id
80
+ @context_id = context_id
69
81
  @operation = operation
70
- @options = options
71
82
  @input = input
72
83
  @output = output
84
+ @options = options
73
85
  @start_time = start_time
74
86
  @end_time = end_time
87
+ @name = T.let(name, T.nilable(String))
88
+ @context_name = T.let(context_name, T.nilable(String))
89
+ @iids = T.let(iids, T.nilable(T::Array[String]))
90
+ @error = T.let(error, T.nilable(String))
91
+
75
92
  @stream = T.let(nil, T.nilable(Rager::Types::Stream))
76
93
  @buffer = T.let([], Rager::Types::Buffer)
77
94
  @consumed = T.let(false, T::Boolean)
78
- @result_id = T.let(SecureRandom.uuid, T.nilable(String))
79
- @context_id = T.let(context_id, T.nilable(String))
80
- @hash = T.let(hash, T.nilable(String))
81
- @input_ids = T.let(input_ids, T.nilable(T::Array[String]))
95
+ end
96
+
97
+ sig { returns(T::Boolean) }
98
+ def success?
99
+ @error.nil?
82
100
  end
83
101
 
84
102
  sig { returns(T::Boolean) }
@@ -88,11 +106,9 @@ module Rager
88
106
 
89
107
  sig { returns(Rager::Types::Output) }
90
108
  def out
91
- return @output unless stream?
109
+ return T.must(@output) unless stream?
92
110
  return @buffer.each if @consumed
93
111
 
94
- log
95
-
96
112
  @stream = Enumerator.new do |yielder|
97
113
  T.cast(@output, Rager::Types::Stream)
98
114
  .each { |message_delta|
@@ -157,49 +173,42 @@ module Rager
157
173
  end
158
174
  end
159
175
 
160
- sig { returns(Rager::Types::NonStreamOutput) }
176
+ sig { returns(T.nilable(Rager::Types::NonStreamOutput)) }
161
177
  def serialize_output
162
- if @consumed
163
- mat
164
- elsif stream?
165
- "[STREAM]"
166
- else
167
- T.cast(@output, Rager::Types::NonStreamOutput)
168
- end
178
+ return nil unless success?
179
+ return @consumed ? mat : "[STREAM]" if stream?
180
+ T.cast(@output, Rager::Types::NonStreamOutput)
169
181
  end
170
182
 
171
183
  sig { returns(T::Hash[String, T.untyped]) }
172
184
  def to_h
173
185
  {
174
186
  id: @id,
187
+ context_id: @context_id,
175
188
  operation: @operation.serialize,
176
- options: @options.serialize_safe,
177
189
  input: serialize_input,
178
- output: if @consumed
179
- mat
180
- elsif stream?
181
- "[STREAM]"
182
- else
183
- @output
184
- end,
190
+ output: serialize_output,
191
+ options: @options.serialize_safe,
185
192
  start_time: @start_time,
186
193
  end_time: @end_time,
187
- result_id: @result_id,
188
- context_id: @context_id,
189
- hash: @hash,
190
- input_ids: @input_ids
194
+ name: @name,
195
+ context_name: @context_name,
196
+ iids: @iids,
197
+ error: @error
191
198
  }
192
199
  end
193
200
 
194
201
  sig { void }
195
202
  def log
196
- return unless Rager.config.logger
203
+ return unless Rager.config.logger_type
197
204
 
198
205
  json = to_h.to_json
199
206
 
200
- case Rager.config.logger
207
+ logger = Rager.config.logger
208
+
209
+ case Rager.config.logger_type
201
210
  when Rager::Logger::Stdout
202
- puts "\nLOG: #{json}"
211
+ success? ? logger.info(json) : logger.error(json)
203
212
  when Rager::Logger::Remote
204
213
  http_adapter = Rager.config.http_adapter
205
214
  url = Rager.config.url
@@ -220,9 +229,7 @@ module Rager
220
229
 
221
230
  response = http_adapter.make_request(request)
222
231
 
223
- unless response.status >= 200 && response.status < 300
224
- warn "Remote log failed: \\#{response.status} \\#{response.body}"
225
- end
232
+ logger.warn "Remote log failed: \\#{response.status} \\#{response.body}" unless response.success?
226
233
  end
227
234
  end
228
235
  end
@@ -12,17 +12,6 @@ module Rager
12
12
  const :provider, String, default: "brave"
13
13
  const :n, T.nilable(Integer)
14
14
  const :api_key, T.nilable(String)
15
-
16
- sig { override.returns(T::Hash[String, T.untyped]) }
17
- def serialize_safe
18
- result = serialize
19
- result["api_key"] = "[REDACTED]" if result.key?("api_key")
20
- result
21
- end
22
-
23
- sig { override.void }
24
- def validate
25
- end
26
15
  end
27
16
  end
28
17
  end
@@ -21,54 +21,37 @@ module Rager
21
21
  api_key = options.api_key || ENV["BRAVE_API_KEY"]
22
22
  raise Rager::Errors::MissingCredentialsError.new("Brave", "BRAVE_API_KEY") if api_key.nil?
23
23
 
24
- params = {"q" => query}
25
- params["count"] = options.n.to_s if options.n
26
-
27
- query_string = URI.encode_www_form(params)
28
- url = "https://api.search.brave.com/res/v1/web/search?#{query_string}"
29
-
30
- headers = {
31
- "Accept" => "application/json",
32
- "Accept-Encoding" => "gzip",
33
- "x-subscription-token" => api_key
34
- }
24
+ params = {"q" => query}.tap do |p|
25
+ p["count"] = options.n.to_s if options.n
26
+ end
35
27
 
36
28
  request = Rager::Http::Request.new(
37
29
  verb: Rager::Http::Verb::Get,
38
- url: url,
39
- headers: headers
30
+ url: "https://api.search.brave.com/res/v1/web/search?#{URI.encode_www_form(params)}",
31
+ headers: {
32
+ "Accept" => "application/json",
33
+ "Accept-Encoding" => "gzip",
34
+ "x-subscription-token" => api_key
35
+ }
40
36
  )
41
37
 
42
- http_adapter = Rager.config.http_adapter
43
- response = http_adapter.make_request(request)
38
+ response = Rager.config.http_adapter.make_request(request)
44
39
  response_body = T.cast(T.must(response.body), String)
45
40
 
46
- if response.status != 200
47
- raise Rager::Errors::HttpError.new(
48
- http_adapter,
49
- response.status,
50
- response_body
51
- )
52
- end
41
+ raise Rager::Errors::HttpError.new(Rager.config.http_adapter, response.status, response_body) if response.status != 200
53
42
 
54
43
  parsed_response = JSON.parse(response_body)
44
+ web_results = parsed_response.dig("web", "results") || []
55
45
 
56
- results = []
57
-
58
- if parsed_response["web"] && parsed_response["web"]["results"]
59
- web_results = parsed_response["web"]["results"]
60
- web_results.each do |result|
61
- next unless result["title"] && result["url"] && result["description"]
46
+ web_results.filter_map do |result|
47
+ next unless result["title"] && result["url"] && result["description"]
62
48
 
63
- results << Rager::Search::Result.new(
64
- title: result["title"],
65
- url: result["url"],
66
- description: result["description"]
67
- )
68
- end
49
+ Rager::Search::Result.new(
50
+ title: result["title"],
51
+ url: result["url"],
52
+ description: result["description"]
53
+ )
69
54
  end
70
-
71
- results
72
55
  end
73
56
  end
74
57
  end
data/lib/rager/search.rb CHANGED
@@ -9,7 +9,7 @@ module Rager
9
9
 
10
10
  sig do
11
11
  params(
12
- query: String,
12
+ query: Rager::Types::SearchInput,
13
13
  options: Rager::Search::Options
14
14
  ).returns(Rager::Types::SearchOutput)
15
15
  end
@@ -8,15 +8,6 @@ module Rager
8
8
  extend T::Sig
9
9
 
10
10
  const :provider, String, default: "erb"
11
-
12
- sig { override.returns(T::Hash[String, T.untyped]) }
13
- def serialize_safe
14
- serialize.transform_keys(&:to_s)
15
- end
16
-
17
- sig { override.void }
18
- def validate
19
- end
20
11
  end
21
12
  end
22
13
  end
@@ -9,7 +9,7 @@ module Rager
9
9
 
10
10
  sig do
11
11
  params(
12
- input: Rager::Template::Input,
12
+ input: Rager::Types::TemplateInput,
13
13
  options: Rager::Template::Options
14
14
  ).returns(Rager::Types::TemplateOutput)
15
15
  end
data/lib/rager/types.rb CHANGED
@@ -52,7 +52,6 @@ module Rager
52
52
  }
53
53
 
54
54
  ChatNonStreamOutput = T.type_alias { ChatNonStream }
55
-
56
55
  NonStreamOutput = T.type_alias {
57
56
  T.any(
58
57
  ChatNonStreamOutput,
@@ -0,0 +1,49 @@
1
+ # typed: strict
2
+ # frozen_string_literal: true
3
+
4
+ require "sorbet-runtime"
5
+ require "uri"
6
+
7
+ module Rager
8
+ module Utils
9
+ module Http
10
+ extend T::Sig
11
+
12
+ sig { params(url: String, headers: T.nilable(T::Hash[String, String])).returns(T.nilable(String)) }
13
+ def self.download_binary(url, headers = nil)
14
+ response = Rager.config.http_adapter.make_request(
15
+ Rager::Http::Request.new(url: url, headers: headers || {})
16
+ )
17
+
18
+ if response.success? && response.body
19
+ binary_data = T.cast(response.body, String)
20
+ binary_data.force_encoding(Encoding::BINARY)
21
+ binary_data
22
+ end
23
+ rescue
24
+ nil
25
+ end
26
+
27
+ sig { params(url: String, path: T.nilable(String), headers: T.nilable(T::Hash[String, String])).returns(String) }
28
+ def self.download_file(url, path = nil, headers = nil)
29
+ path ||= begin
30
+ uri_path = URI.parse(url).path
31
+ if uri_path.nil? || uri_path.empty?
32
+ "output"
33
+ else
34
+ filename = File.basename(uri_path)
35
+ filename.empty? ? "output" : filename
36
+ end
37
+ end
38
+
39
+ binary_data = download_binary(url, headers)
40
+ raise Rager::Errors::HttpError.new(Rager.config.http_adapter, 500, "Download failed") if binary_data.nil?
41
+
42
+ File.binwrite(path, binary_data)
43
+ path
44
+ rescue => e
45
+ raise Rager::Errors::HttpError.new(Rager.config.http_adapter, 500, "Download failed: #{e.message}")
46
+ end
47
+ end
48
+ end
49
+ end