rager 0.2.0 → 0.3.0
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- checksums.yaml +4 -4
- data/README.md +1 -1
- data/lib/rager/chat/providers/openai.rb +72 -105
- data/lib/rager/chat.rb +1 -1
- data/lib/rager/config.rb +5 -1
- data/lib/rager/context.rb +42 -22
- data/lib/rager/embed/options.rb +0 -11
- data/lib/rager/embed/providers/openai.rb +5 -17
- data/lib/rager/embed.rb +1 -1
- data/lib/rager/errors/http_error.rb +2 -2
- data/lib/rager/errors/options_error.rb +1 -1
- data/lib/rager/http/response.rb +7 -0
- data/lib/rager/image_gen/options.rb +1 -11
- data/lib/rager/image_gen/output_format.rb +18 -0
- data/lib/rager/image_gen/providers/replicate.rb +16 -20
- data/lib/rager/image_gen.rb +1 -1
- data/lib/rager/mesh_gen/options.rb +0 -11
- data/lib/rager/mesh_gen/providers/replicate.rb +12 -18
- data/lib/rager/mesh_gen.rb +1 -1
- data/lib/rager/options.rb +5 -3
- data/lib/rager/rerank/options.rb +0 -11
- data/lib/rager/rerank/providers/cohere.rb +13 -27
- data/lib/rager/rerank.rb +1 -1
- data/lib/rager/result.rb +50 -50
- data/lib/rager/search/options.rb +0 -11
- data/lib/rager/search/providers/brave.rb +19 -36
- data/lib/rager/search.rb +1 -1
- data/lib/rager/template/options.rb +0 -9
- data/lib/rager/template.rb +1 -1
- data/lib/rager/types.rb +0 -1
- data/lib/rager/utils/http.rb +39 -0
- data/lib/rager/utils/replicate.rb +72 -0
- data/lib/rager/version.rb +1 -1
- metadata +19 -2
|
@@ -15,12 +15,6 @@ module Rager
|
|
|
15
15
|
api_key = options.api_key || ENV["REPLICATE_API_KEY"]
|
|
16
16
|
raise Rager::Errors::MissingCredentialsError.new("Replicate", "REPLICATE_API_KEY") if api_key.nil?
|
|
17
17
|
|
|
18
|
-
headers = {
|
|
19
|
-
"Authorization" => "Bearer #{api_key}",
|
|
20
|
-
"Content-Type" => "application/json",
|
|
21
|
-
"Prefer" => "wait"
|
|
22
|
-
}
|
|
23
|
-
|
|
24
18
|
body = {
|
|
25
19
|
version: options.version,
|
|
26
20
|
input: {
|
|
@@ -31,29 +25,29 @@ module Rager
|
|
|
31
25
|
save_gaussian_ply: true,
|
|
32
26
|
ss_sampling_steps: 38
|
|
33
27
|
}
|
|
34
|
-
}
|
|
35
|
-
|
|
28
|
+
}.tap do |b|
|
|
29
|
+
b[:input][:seed] = options.seed if options.seed
|
|
30
|
+
end
|
|
36
31
|
|
|
37
32
|
request = Rager::Http::Request.new(
|
|
38
33
|
url: "https://api.replicate.com/v1/predictions",
|
|
39
34
|
verb: Rager::Http::Verb::Post,
|
|
40
|
-
headers:
|
|
35
|
+
headers: {
|
|
36
|
+
"Authorization" => "Bearer #{api_key}",
|
|
37
|
+
"Content-Type" => "application/json"
|
|
38
|
+
},
|
|
41
39
|
body: body.to_json
|
|
42
40
|
)
|
|
43
41
|
|
|
44
42
|
response = Rager.config.http_adapter.make_request(request)
|
|
45
43
|
response_body = T.cast(T.must(response.body), String)
|
|
46
44
|
|
|
47
|
-
raise Rager::Errors::HttpError.new(Rager.config.http_adapter, response.status, response_body) unless [200,
|
|
48
|
-
201, 202].include?(response.status)
|
|
45
|
+
raise Rager::Errors::HttpError.new(Rager.config.http_adapter, response.status, response_body) unless [200, 201, 202].include?(response.status)
|
|
49
46
|
|
|
50
|
-
|
|
51
|
-
|
|
52
|
-
|
|
53
|
-
|
|
54
|
-
rescue JSON::ParserError, KeyError => e
|
|
55
|
-
raise Rager::Errors::ParseError.new(e.message, response_body)
|
|
56
|
-
end
|
|
47
|
+
json = JSON.parse(response_body)
|
|
48
|
+
json.fetch("urls").fetch("get")
|
|
49
|
+
rescue JSON::ParserError, KeyError => e
|
|
50
|
+
raise Rager::Errors::ParseError.new(e.message, response_body)
|
|
57
51
|
end
|
|
58
52
|
end
|
|
59
53
|
end
|
data/lib/rager/mesh_gen.rb
CHANGED
data/lib/rager/options.rb
CHANGED
|
@@ -8,13 +8,15 @@ module Rager
|
|
|
8
8
|
extend T::Sig
|
|
9
9
|
extend T::Helpers
|
|
10
10
|
requires_ancestor { T::Struct }
|
|
11
|
-
interface!
|
|
12
11
|
|
|
13
|
-
sig {
|
|
12
|
+
sig { returns(T::Hash[String, T.untyped]) }
|
|
14
13
|
def serialize_safe
|
|
14
|
+
result = T.cast(self, T::Struct).serialize
|
|
15
|
+
result["api_key"] = "[REDACTED]" if result.key?("api_key")
|
|
16
|
+
result
|
|
15
17
|
end
|
|
16
18
|
|
|
17
|
-
sig {
|
|
19
|
+
sig { void }
|
|
18
20
|
def validate
|
|
19
21
|
end
|
|
20
22
|
end
|
data/lib/rager/rerank/options.rb
CHANGED
|
@@ -14,17 +14,6 @@ module Rager
|
|
|
14
14
|
const :model, T.nilable(String)
|
|
15
15
|
const :n, T.nilable(Integer)
|
|
16
16
|
const :api_key, T.nilable(String)
|
|
17
|
-
|
|
18
|
-
sig { override.returns(T::Hash[String, T.untyped]) }
|
|
19
|
-
def serialize_safe
|
|
20
|
-
result = serialize
|
|
21
|
-
result["api_key"] = "[REDACTED]" if result.key?("api_key")
|
|
22
|
-
result
|
|
23
|
-
end
|
|
24
|
-
|
|
25
|
-
sig { override.void }
|
|
26
|
-
def validate
|
|
27
|
-
end
|
|
28
17
|
end
|
|
29
18
|
end
|
|
30
19
|
end
|
|
@@ -20,50 +20,36 @@ module Rager
|
|
|
20
20
|
api_key = options.api_key || ENV["COHERE_API_KEY"]
|
|
21
21
|
raise Rager::Errors::MissingCredentialsError.new("Cohere", "COHERE_API_KEY") if api_key.nil?
|
|
22
22
|
|
|
23
|
-
url = options.url || ENV["COHERE_URL"] || "https://api.cohere.com/v2/rerank"
|
|
24
|
-
model = options.model || "rerank-v3.5"
|
|
25
|
-
|
|
26
|
-
headers = {
|
|
27
|
-
"Content-Type" => "application/json"
|
|
28
|
-
}
|
|
29
|
-
headers["Authorization"] = "Bearer #{api_key}" if api_key
|
|
30
|
-
|
|
31
23
|
body = {
|
|
32
|
-
model: model,
|
|
24
|
+
model: options.model || "rerank-v3.5",
|
|
33
25
|
query: query.query,
|
|
34
26
|
documents: query.documents
|
|
35
|
-
}
|
|
36
|
-
|
|
27
|
+
}.tap do |b|
|
|
28
|
+
b[:top_n] = options.n if options.n
|
|
29
|
+
end
|
|
30
|
+
|
|
31
|
+
headers = {"Content-Type" => "application/json"}
|
|
32
|
+
headers["Authorization"] = "Bearer #{api_key}" if api_key
|
|
37
33
|
|
|
38
34
|
request = Rager::Http::Request.new(
|
|
39
35
|
verb: Rager::Http::Verb::Post,
|
|
40
|
-
url: url,
|
|
36
|
+
url: options.url || ENV["COHERE_URL"] || "https://api.cohere.com/v2/rerank",
|
|
41
37
|
headers: headers,
|
|
42
38
|
body: body.to_json
|
|
43
39
|
)
|
|
44
40
|
|
|
45
|
-
|
|
46
|
-
response = http_adapter.make_request(request)
|
|
41
|
+
response = Rager.config.http_adapter.make_request(request)
|
|
47
42
|
response_body = T.cast(T.must(response.body), String)
|
|
48
43
|
|
|
49
|
-
if response.status != 200
|
|
50
|
-
raise Rager::Errors::HttpError.new(
|
|
51
|
-
http_adapter,
|
|
52
|
-
response.status,
|
|
53
|
-
response_body
|
|
54
|
-
)
|
|
55
|
-
end
|
|
44
|
+
raise Rager::Errors::HttpError.new(Rager.config.http_adapter, response.status, response_body) if response.status != 200
|
|
56
45
|
|
|
57
46
|
parsed_response = JSON.parse(response_body)
|
|
58
|
-
|
|
59
47
|
results = parsed_response["results"] || []
|
|
60
|
-
results.map do |result|
|
|
61
|
-
index = result["index"]
|
|
62
|
-
score = result["relevance_score"]
|
|
63
48
|
|
|
49
|
+
results.map do |result|
|
|
64
50
|
Rager::Rerank::Result.new(
|
|
65
|
-
score:
|
|
66
|
-
index: index
|
|
51
|
+
score: result["relevance_score"],
|
|
52
|
+
index: result["index"]
|
|
67
53
|
)
|
|
68
54
|
end
|
|
69
55
|
end
|
data/lib/rager/rerank.rb
CHANGED
data/lib/rager/result.rb
CHANGED
|
@@ -2,6 +2,7 @@
|
|
|
2
2
|
# frozen_string_literal: true
|
|
3
3
|
|
|
4
4
|
require "json"
|
|
5
|
+
require "logger"
|
|
5
6
|
require "securerandom"
|
|
6
7
|
require "sorbet-runtime"
|
|
7
8
|
|
|
@@ -12,6 +13,9 @@ module Rager
|
|
|
12
13
|
sig { returns(String) }
|
|
13
14
|
attr_reader :id
|
|
14
15
|
|
|
16
|
+
sig { returns(String) }
|
|
17
|
+
attr_reader :context_id
|
|
18
|
+
|
|
15
19
|
sig { returns(Rager::Operation) }
|
|
16
20
|
attr_reader :operation
|
|
17
21
|
|
|
@@ -21,6 +25,9 @@ module Rager
|
|
|
21
25
|
sig { returns(Rager::Types::Input) }
|
|
22
26
|
attr_reader :input
|
|
23
27
|
|
|
28
|
+
sig { returns(T.nilable(Rager::Types::Output)) }
|
|
29
|
+
attr_reader :output
|
|
30
|
+
|
|
24
31
|
sig { returns(Integer) }
|
|
25
32
|
attr_reader :start_time
|
|
26
33
|
|
|
@@ -28,57 +35,62 @@ module Rager
|
|
|
28
35
|
attr_reader :end_time
|
|
29
36
|
|
|
30
37
|
sig { returns(T.nilable(String)) }
|
|
31
|
-
attr_reader :
|
|
38
|
+
attr_reader :name
|
|
32
39
|
|
|
33
|
-
sig { returns(T.nilable(String)) }
|
|
34
|
-
attr_reader :
|
|
40
|
+
sig { returns(T.nilable(T::Array[String])) }
|
|
41
|
+
attr_reader :iids
|
|
35
42
|
|
|
36
43
|
sig { returns(T.nilable(String)) }
|
|
37
|
-
attr_reader :
|
|
38
|
-
|
|
39
|
-
sig { returns(T.nilable(T::Array[String])) }
|
|
40
|
-
attr_reader :input_ids
|
|
44
|
+
attr_reader :error
|
|
41
45
|
|
|
42
46
|
sig do
|
|
43
47
|
params(
|
|
44
48
|
id: String,
|
|
49
|
+
context_id: String,
|
|
45
50
|
operation: Rager::Operation,
|
|
46
|
-
options: Rager::Options,
|
|
47
51
|
input: Rager::Types::Input,
|
|
48
|
-
output: Rager::Types::Output,
|
|
52
|
+
output: T.nilable(Rager::Types::Output),
|
|
53
|
+
options: Rager::Options,
|
|
49
54
|
start_time: Integer,
|
|
50
55
|
end_time: Integer,
|
|
51
|
-
|
|
52
|
-
|
|
53
|
-
|
|
56
|
+
name: T.nilable(String),
|
|
57
|
+
iids: T.nilable(T::Array[String]),
|
|
58
|
+
error: T.nilable(String)
|
|
54
59
|
).void
|
|
55
60
|
end
|
|
56
61
|
def initialize(
|
|
57
62
|
id:,
|
|
63
|
+
context_id:,
|
|
58
64
|
operation:,
|
|
59
|
-
options:,
|
|
60
65
|
input:,
|
|
61
66
|
output:,
|
|
67
|
+
options:,
|
|
62
68
|
start_time:,
|
|
63
69
|
end_time:,
|
|
64
|
-
|
|
65
|
-
|
|
66
|
-
|
|
70
|
+
name: nil,
|
|
71
|
+
iids: nil,
|
|
72
|
+
error: nil
|
|
67
73
|
)
|
|
68
74
|
@id = id
|
|
75
|
+
@context_id = context_id
|
|
69
76
|
@operation = operation
|
|
70
|
-
@options = options
|
|
71
77
|
@input = input
|
|
72
78
|
@output = output
|
|
79
|
+
@options = options
|
|
73
80
|
@start_time = start_time
|
|
74
81
|
@end_time = end_time
|
|
82
|
+
@name = T.let(name, T.nilable(String))
|
|
83
|
+
@iids = T.let(iids, T.nilable(T::Array[String]))
|
|
84
|
+
@error = T.let(error, T.nilable(String))
|
|
85
|
+
|
|
75
86
|
@stream = T.let(nil, T.nilable(Rager::Types::Stream))
|
|
76
87
|
@buffer = T.let([], Rager::Types::Buffer)
|
|
77
88
|
@consumed = T.let(false, T::Boolean)
|
|
78
|
-
|
|
79
|
-
|
|
80
|
-
|
|
81
|
-
|
|
89
|
+
end
|
|
90
|
+
|
|
91
|
+
sig { returns(T::Boolean) }
|
|
92
|
+
def success?
|
|
93
|
+
@error.nil?
|
|
82
94
|
end
|
|
83
95
|
|
|
84
96
|
sig { returns(T::Boolean) }
|
|
@@ -88,11 +100,9 @@ module Rager
|
|
|
88
100
|
|
|
89
101
|
sig { returns(Rager::Types::Output) }
|
|
90
102
|
def out
|
|
91
|
-
return @output unless stream?
|
|
103
|
+
return T.must(@output) unless stream?
|
|
92
104
|
return @buffer.each if @consumed
|
|
93
105
|
|
|
94
|
-
log
|
|
95
|
-
|
|
96
106
|
@stream = Enumerator.new do |yielder|
|
|
97
107
|
T.cast(@output, Rager::Types::Stream)
|
|
98
108
|
.each { |message_delta|
|
|
@@ -157,49 +167,41 @@ module Rager
|
|
|
157
167
|
end
|
|
158
168
|
end
|
|
159
169
|
|
|
160
|
-
sig { returns(Rager::Types::NonStreamOutput) }
|
|
170
|
+
sig { returns(T.nilable(Rager::Types::NonStreamOutput)) }
|
|
161
171
|
def serialize_output
|
|
162
|
-
|
|
163
|
-
|
|
164
|
-
|
|
165
|
-
"[STREAM]"
|
|
166
|
-
else
|
|
167
|
-
T.cast(@output, Rager::Types::NonStreamOutput)
|
|
168
|
-
end
|
|
172
|
+
return nil unless success?
|
|
173
|
+
return @consumed ? mat : "[STREAM]" if stream?
|
|
174
|
+
T.cast(@output, Rager::Types::NonStreamOutput)
|
|
169
175
|
end
|
|
170
176
|
|
|
171
177
|
sig { returns(T::Hash[String, T.untyped]) }
|
|
172
178
|
def to_h
|
|
173
179
|
{
|
|
174
180
|
id: @id,
|
|
181
|
+
context_id: @context_id,
|
|
175
182
|
operation: @operation.serialize,
|
|
176
|
-
options: @options.serialize_safe,
|
|
177
183
|
input: serialize_input,
|
|
178
|
-
output:
|
|
179
|
-
|
|
180
|
-
elsif stream?
|
|
181
|
-
"[STREAM]"
|
|
182
|
-
else
|
|
183
|
-
@output
|
|
184
|
-
end,
|
|
184
|
+
output: serialize_output,
|
|
185
|
+
options: @options.serialize_safe,
|
|
185
186
|
start_time: @start_time,
|
|
186
187
|
end_time: @end_time,
|
|
187
|
-
|
|
188
|
-
|
|
189
|
-
|
|
190
|
-
input_ids: @input_ids
|
|
188
|
+
name: @name,
|
|
189
|
+
iids: @iids,
|
|
190
|
+
error: @error
|
|
191
191
|
}
|
|
192
192
|
end
|
|
193
193
|
|
|
194
194
|
sig { void }
|
|
195
195
|
def log
|
|
196
|
-
return unless Rager.config.
|
|
196
|
+
return unless Rager.config.logger_type
|
|
197
197
|
|
|
198
198
|
json = to_h.to_json
|
|
199
199
|
|
|
200
|
-
|
|
200
|
+
logger = Rager.config.logger
|
|
201
|
+
|
|
202
|
+
case Rager.config.logger_type
|
|
201
203
|
when Rager::Logger::Stdout
|
|
202
|
-
|
|
204
|
+
success? ? logger.info(json) : logger.error(json)
|
|
203
205
|
when Rager::Logger::Remote
|
|
204
206
|
http_adapter = Rager.config.http_adapter
|
|
205
207
|
url = Rager.config.url
|
|
@@ -220,9 +222,7 @@ module Rager
|
|
|
220
222
|
|
|
221
223
|
response = http_adapter.make_request(request)
|
|
222
224
|
|
|
223
|
-
|
|
224
|
-
warn "Remote log failed: \\#{response.status} \\#{response.body}"
|
|
225
|
-
end
|
|
225
|
+
logger.warn "Remote log failed: \\#{response.status} \\#{response.body}" unless response.success?
|
|
226
226
|
end
|
|
227
227
|
end
|
|
228
228
|
end
|
data/lib/rager/search/options.rb
CHANGED
|
@@ -12,17 +12,6 @@ module Rager
|
|
|
12
12
|
const :provider, String, default: "brave"
|
|
13
13
|
const :n, T.nilable(Integer)
|
|
14
14
|
const :api_key, T.nilable(String)
|
|
15
|
-
|
|
16
|
-
sig { override.returns(T::Hash[String, T.untyped]) }
|
|
17
|
-
def serialize_safe
|
|
18
|
-
result = serialize
|
|
19
|
-
result["api_key"] = "[REDACTED]" if result.key?("api_key")
|
|
20
|
-
result
|
|
21
|
-
end
|
|
22
|
-
|
|
23
|
-
sig { override.void }
|
|
24
|
-
def validate
|
|
25
|
-
end
|
|
26
15
|
end
|
|
27
16
|
end
|
|
28
17
|
end
|
|
@@ -21,54 +21,37 @@ module Rager
|
|
|
21
21
|
api_key = options.api_key || ENV["BRAVE_API_KEY"]
|
|
22
22
|
raise Rager::Errors::MissingCredentialsError.new("Brave", "BRAVE_API_KEY") if api_key.nil?
|
|
23
23
|
|
|
24
|
-
params = {"q" => query}
|
|
25
|
-
|
|
26
|
-
|
|
27
|
-
query_string = URI.encode_www_form(params)
|
|
28
|
-
url = "https://api.search.brave.com/res/v1/web/search?#{query_string}"
|
|
29
|
-
|
|
30
|
-
headers = {
|
|
31
|
-
"Accept" => "application/json",
|
|
32
|
-
"Accept-Encoding" => "gzip",
|
|
33
|
-
"x-subscription-token" => api_key
|
|
34
|
-
}
|
|
24
|
+
params = {"q" => query}.tap do |p|
|
|
25
|
+
p["count"] = options.n.to_s if options.n
|
|
26
|
+
end
|
|
35
27
|
|
|
36
28
|
request = Rager::Http::Request.new(
|
|
37
29
|
verb: Rager::Http::Verb::Get,
|
|
38
|
-
url:
|
|
39
|
-
headers:
|
|
30
|
+
url: "https://api.search.brave.com/res/v1/web/search?#{URI.encode_www_form(params)}",
|
|
31
|
+
headers: {
|
|
32
|
+
"Accept" => "application/json",
|
|
33
|
+
"Accept-Encoding" => "gzip",
|
|
34
|
+
"x-subscription-token" => api_key
|
|
35
|
+
}
|
|
40
36
|
)
|
|
41
37
|
|
|
42
|
-
|
|
43
|
-
response = http_adapter.make_request(request)
|
|
38
|
+
response = Rager.config.http_adapter.make_request(request)
|
|
44
39
|
response_body = T.cast(T.must(response.body), String)
|
|
45
40
|
|
|
46
|
-
if response.status != 200
|
|
47
|
-
raise Rager::Errors::HttpError.new(
|
|
48
|
-
http_adapter,
|
|
49
|
-
response.status,
|
|
50
|
-
response_body
|
|
51
|
-
)
|
|
52
|
-
end
|
|
41
|
+
raise Rager::Errors::HttpError.new(Rager.config.http_adapter, response.status, response_body) if response.status != 200
|
|
53
42
|
|
|
54
43
|
parsed_response = JSON.parse(response_body)
|
|
44
|
+
web_results = parsed_response.dig("web", "results") || []
|
|
55
45
|
|
|
56
|
-
|
|
57
|
-
|
|
58
|
-
if parsed_response["web"] && parsed_response["web"]["results"]
|
|
59
|
-
web_results = parsed_response["web"]["results"]
|
|
60
|
-
web_results.each do |result|
|
|
61
|
-
next unless result["title"] && result["url"] && result["description"]
|
|
46
|
+
web_results.filter_map do |result|
|
|
47
|
+
next unless result["title"] && result["url"] && result["description"]
|
|
62
48
|
|
|
63
|
-
|
|
64
|
-
|
|
65
|
-
|
|
66
|
-
|
|
67
|
-
|
|
68
|
-
end
|
|
49
|
+
Rager::Search::Result.new(
|
|
50
|
+
title: result["title"],
|
|
51
|
+
url: result["url"],
|
|
52
|
+
description: result["description"]
|
|
53
|
+
)
|
|
69
54
|
end
|
|
70
|
-
|
|
71
|
-
results
|
|
72
55
|
end
|
|
73
56
|
end
|
|
74
57
|
end
|
data/lib/rager/search.rb
CHANGED
|
@@ -8,15 +8,6 @@ module Rager
|
|
|
8
8
|
extend T::Sig
|
|
9
9
|
|
|
10
10
|
const :provider, String, default: "erb"
|
|
11
|
-
|
|
12
|
-
sig { override.returns(T::Hash[String, T.untyped]) }
|
|
13
|
-
def serialize_safe
|
|
14
|
-
serialize.transform_keys(&:to_s)
|
|
15
|
-
end
|
|
16
|
-
|
|
17
|
-
sig { override.void }
|
|
18
|
-
def validate
|
|
19
|
-
end
|
|
20
11
|
end
|
|
21
12
|
end
|
|
22
13
|
end
|
data/lib/rager/template.rb
CHANGED
data/lib/rager/types.rb
CHANGED
|
@@ -0,0 +1,39 @@
|
|
|
1
|
+
# typed: strict
|
|
2
|
+
# frozen_string_literal: true
|
|
3
|
+
|
|
4
|
+
require "sorbet-runtime"
|
|
5
|
+
require "uri"
|
|
6
|
+
|
|
7
|
+
module Rager
  module Utils
    module Http
      extend T::Sig

      # Downloads the resource at +url+ and writes its body to disk in binary mode.
      #
      # @param url [String] absolute URL of the file to fetch.
      # @param path [String, nil] destination path; when nil, the last segment of the
      #   URL path is used, falling back to "output" when that segment is not a usable
      #   file name (empty, "/", "." or "..").
      # @param headers [Hash{String => String}, nil] extra request headers.
      # @return [String] the path the body was written to.
      # @raise [Rager::Errors::HttpError] when the response is not successful, and
      #   (with status 500) for any other failure, e.g. an unparsable URL, a missing
      #   body, or a local write error.
      sig { params(url: String, path: T.nilable(String), headers: T.nilable(T::Hash[String, String])).returns(String) }
      def self.download_file(url, path = nil, headers = nil)
        path ||= default_path(url)

        response = Rager.config.http_adapter.make_request(
          Rager::Http::Request.new(url: url, headers: headers || {})
        )

        unless response.success?
          raise Rager::Errors::HttpError.new(Rager.config.http_adapter, response.status, T.cast(response.body, T.nilable(String)))
        end

        File.binwrite(path, T.cast(T.must(response.body), String))
        path
      rescue Rager::Errors::HttpError
        raise
      rescue => e
        # Normalize every other failure to HttpError so callers need only one rescue.
        raise Rager::Errors::HttpError.new(Rager.config.http_adapter, 500, "Download failed: #{e.message}")
      end

      # Derives a local file name from the last segment of +url+'s path.
      #
      # File.basename returns "/" for a root path and keeps "." / "..", none of
      # which can be written as a regular file, so those fall back to "output"
      # just like an empty or missing path does.
      sig { params(url: String).returns(String) }
      def self.default_path(url)
        filename = File.basename(URI.parse(url).path.to_s)
        ["", "/", ".", ".."].include?(filename) ? "output" : filename
      end
      private_class_method :default_path
    end
  end
end
|
|
@@ -0,0 +1,72 @@
|
|
|
1
|
+
# typed: strict
|
|
2
|
+
# frozen_string_literal: true
|
|
3
|
+
|
|
4
|
+
require "sorbet-runtime"
|
|
5
|
+
require "json"
|
|
6
|
+
|
|
7
|
+
module Rager
  module Utils
    module Replicate
      extend T::Sig

      # Fetches a Replicate prediction and, once it has succeeded, downloads its output file.
      #
      # @param prediction_url [String] the prediction's GET endpoint (its "urls.get" value).
      # @param key [String, nil] when the output is a Hash, the key whose value to download;
      #   if that key is absent or nil, the first non-nil value is used instead.
      # @param path [String, nil] destination path, forwarded to Rager::Utils::Http.download_file.
      # @param api_key [String, nil] Replicate token; defaults to ENV["REPLICATE_API_KEY"].
      # @return [String, nil] the written path, or nil when the prediction is not finished
      #   yet (or an unexpected error occurred) — nil tells poll_prediction to try again.
      # @raise [Rager::Errors::MissingCredentialsError] when no API key is available.
      # @raise [Rager::Errors::HttpError] on a non-success response, a failed or canceled
      #   prediction, or a failed download.
      # @raise [Rager::Errors::ParseError] when the body is not valid JSON, or the prediction
      #   succeeded without a downloadable output URL.
      sig { params(prediction_url: String, key: T.nilable(String), path: T.nilable(String), api_key: T.nilable(String)).returns(T.nilable(String)) }
      def self.download_prediction(prediction_url, key = nil, path = nil, api_key: nil)
        api_key ||= ENV["REPLICATE_API_KEY"]
        raise Rager::Errors::MissingCredentialsError.new("Replicate", "REPLICATE_API_KEY") if api_key.nil?

        response = Rager.config.http_adapter.make_request(
          Rager::Http::Request.new(
            url: prediction_url,
            headers: {"Authorization" => "Bearer #{api_key}", "Content-Type" => "application/json"}
          )
        )

        unless response.success?
          raise Rager::Errors::HttpError.new(Rager.config.http_adapter, response.status, T.cast(response.body, T.nilable(String)))
        end

        body = T.cast(T.must(response.body), String)
        data = JSON.parse(body)

        case data["status"]
        when "failed"
          error_msg = data["error"] || "Prediction failed"
          raise Rager::Errors::HttpError.new(Rager.config.http_adapter, 422, "Prediction failed: #{error_msg}")
        when "canceled"
          # Terminal like "failed"; returning nil would make the poller spin until it times out.
          raise Rager::Errors::HttpError.new(Rager.config.http_adapter, 422, "Prediction canceled")
        when "succeeded"
          download_url = output_url(data["output"], key)
          # A succeeded prediction never changes, so nil here would only cause pointless retries.
          raise Rager::Errors::ParseError.new("Prediction succeeded without a downloadable output", body) if download_url.nil?

          Rager::Utils::Http.download_file(download_url, path)
        else
          # "starting", "processing" or an unrecognized status: not ready yet.
          nil
        end
      rescue JSON::ParserError => e
        raise Rager::Errors::ParseError.new("Failed to parse prediction response", e.message)
      rescue Rager::Errors::HttpError, Rager::Errors::ParseError, Rager::Errors::MissingCredentialsError
        raise
      rescue
        # Any other error (e.g. a transport failure in the HTTP adapter) is treated as
        # retryable: nil lets poll_prediction try again on its next attempt.
        nil
      end

      # Picks the URL to download from a prediction's "output" value.
      #
      # Hash: the value under +key+ when present, otherwise the first non-nil value.
      # Array: the first non-nil element. Other scalars: their string form.
      # Returns nil unless the chosen value is a non-empty String.
      sig { params(output: T.untyped, key: T.nilable(String)).returns(T.nilable(String)) }
      def self.output_url(output, key)
        candidate =
          case output
          when Hash then (key && output[key]) || output.values.compact.first
          when Array then output.compact.first
          when nil then nil
          else output.to_s
          end

        candidate.is_a?(String) && !candidate.empty? ? candidate : nil
      end
      private_class_method :output_url

      # Repeatedly calls download_prediction until it yields a path.
      #
      # @param max_attempts [Integer] number of fetch attempts before giving up.
      # @param sleep_interval [Integer] seconds to wait between attempts.
      # @param api_key [String, nil] forwarded to download_prediction.
      # @return [String, nil] the downloaded file's path.
      # @raise [Rager::Errors::HttpError] with status 408 when every attempt returned nil,
      #   plus anything download_prediction raises.
      sig { params(prediction_url: String, key: T.nilable(String), path: T.nilable(String), max_attempts: Integer, sleep_interval: Integer, api_key: T.nilable(String)).returns(T.nilable(String)) }
      def self.poll_prediction(prediction_url, key: nil, path: nil, max_attempts: 30, sleep_interval: 10, api_key: nil)
        max_attempts.times do |attempt|
          result = download_prediction(prediction_url, key, path, api_key: api_key)
          return result unless result.nil?

          # Don't wait after the final attempt; report the timeout right away.
          sleep(sleep_interval) if attempt < max_attempts - 1
        end

        raise Rager::Errors::HttpError.new(Rager.config.http_adapter, 408, "Prediction polling timed out after #{max_attempts} attempts")
      end
    end
  end
end
|
data/lib/rager/version.rb
CHANGED