rager 0.2.0 → 0.3.0
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- checksums.yaml +4 -4
- data/README.md +1 -1
- data/lib/rager/chat/providers/openai.rb +72 -105
- data/lib/rager/chat.rb +1 -1
- data/lib/rager/config.rb +5 -1
- data/lib/rager/context.rb +42 -22
- data/lib/rager/embed/options.rb +0 -11
- data/lib/rager/embed/providers/openai.rb +5 -17
- data/lib/rager/embed.rb +1 -1
- data/lib/rager/errors/http_error.rb +2 -2
- data/lib/rager/errors/options_error.rb +1 -1
- data/lib/rager/http/response.rb +7 -0
- data/lib/rager/image_gen/options.rb +1 -11
- data/lib/rager/image_gen/output_format.rb +18 -0
- data/lib/rager/image_gen/providers/replicate.rb +16 -20
- data/lib/rager/image_gen.rb +1 -1
- data/lib/rager/mesh_gen/options.rb +0 -11
- data/lib/rager/mesh_gen/providers/replicate.rb +12 -18
- data/lib/rager/mesh_gen.rb +1 -1
- data/lib/rager/options.rb +5 -3
- data/lib/rager/rerank/options.rb +0 -11
- data/lib/rager/rerank/providers/cohere.rb +13 -27
- data/lib/rager/rerank.rb +1 -1
- data/lib/rager/result.rb +50 -50
- data/lib/rager/search/options.rb +0 -11
- data/lib/rager/search/providers/brave.rb +19 -36
- data/lib/rager/search.rb +1 -1
- data/lib/rager/template/options.rb +0 -9
- data/lib/rager/template.rb +1 -1
- data/lib/rager/types.rb +0 -1
- data/lib/rager/utils/http.rb +39 -0
- data/lib/rager/utils/replicate.rb +72 -0
- data/lib/rager/version.rb +1 -1
- metadata +19 -2
data/lib/rager/mesh_gen/providers/replicate.rb
CHANGED
@@ -15,12 +15,6 @@ module Rager
|
|
15
15
|
api_key = options.api_key || ENV["REPLICATE_API_KEY"]
|
16
16
|
raise Rager::Errors::MissingCredentialsError.new("Replicate", "REPLICATE_API_KEY") if api_key.nil?
|
17
17
|
|
18
|
-
headers = {
|
19
|
-
"Authorization" => "Bearer #{api_key}",
|
20
|
-
"Content-Type" => "application/json",
|
21
|
-
"Prefer" => "wait"
|
22
|
-
}
|
23
|
-
|
24
18
|
body = {
|
25
19
|
version: options.version,
|
26
20
|
input: {
|
@@ -31,29 +25,29 @@ module Rager
|
|
31
25
|
save_gaussian_ply: true,
|
32
26
|
ss_sampling_steps: 38
|
33
27
|
}
|
34
|
-
}
|
35
|
-
|
28
|
+
}.tap do |b|
|
29
|
+
b[:input][:seed] = options.seed if options.seed
|
30
|
+
end
|
36
31
|
|
37
32
|
request = Rager::Http::Request.new(
|
38
33
|
url: "https://api.replicate.com/v1/predictions",
|
39
34
|
verb: Rager::Http::Verb::Post,
|
40
|
-
headers:
|
35
|
+
headers: {
|
36
|
+
"Authorization" => "Bearer #{api_key}",
|
37
|
+
"Content-Type" => "application/json"
|
38
|
+
},
|
41
39
|
body: body.to_json
|
42
40
|
)
|
43
41
|
|
44
42
|
response = Rager.config.http_adapter.make_request(request)
|
45
43
|
response_body = T.cast(T.must(response.body), String)
|
46
44
|
|
47
|
-
raise Rager::Errors::HttpError.new(Rager.config.http_adapter, response.status, response_body) unless [200,
|
48
|
-
201, 202].include?(response.status)
|
45
|
+
raise Rager::Errors::HttpError.new(Rager.config.http_adapter, response.status, response_body) unless [200, 201, 202].include?(response.status)
|
49
46
|
|
50
|
-
|
51
|
-
|
52
|
-
|
53
|
-
|
54
|
-
rescue JSON::ParserError, KeyError => e
|
55
|
-
raise Rager::Errors::ParseError.new(e.message, response_body)
|
56
|
-
end
|
47
|
+
json = JSON.parse(response_body)
|
48
|
+
json.fetch("urls").fetch("get")
|
49
|
+
rescue JSON::ParserError, KeyError => e
|
50
|
+
raise Rager::Errors::ParseError.new(e.message, response_body)
|
57
51
|
end
|
58
52
|
end
|
59
53
|
end
|
data/lib/rager/mesh_gen.rb
CHANGED
data/lib/rager/options.rb
CHANGED
@@ -8,13 +8,15 @@ module Rager
|
|
8
8
|
extend T::Sig
|
9
9
|
extend T::Helpers
|
10
10
|
requires_ancestor { T::Struct }
|
11
|
-
interface!
|
12
11
|
|
13
|
-
sig {
|
12
|
+
sig { returns(T::Hash[String, T.untyped]) }
|
14
13
|
def serialize_safe
|
14
|
+
result = T.cast(self, T::Struct).serialize
|
15
|
+
result["api_key"] = "[REDACTED]" if result.key?("api_key")
|
16
|
+
result
|
15
17
|
end
|
16
18
|
|
17
|
-
sig {
|
19
|
+
sig { void }
|
18
20
|
def validate
|
19
21
|
end
|
20
22
|
end
|
data/lib/rager/rerank/options.rb
CHANGED
@@ -14,17 +14,6 @@ module Rager
|
|
14
14
|
const :model, T.nilable(String)
|
15
15
|
const :n, T.nilable(Integer)
|
16
16
|
const :api_key, T.nilable(String)
|
17
|
-
|
18
|
-
sig { override.returns(T::Hash[String, T.untyped]) }
|
19
|
-
def serialize_safe
|
20
|
-
result = serialize
|
21
|
-
result["api_key"] = "[REDACTED]" if result.key?("api_key")
|
22
|
-
result
|
23
|
-
end
|
24
|
-
|
25
|
-
sig { override.void }
|
26
|
-
def validate
|
27
|
-
end
|
28
17
|
end
|
29
18
|
end
|
30
19
|
end
|
data/lib/rager/rerank/providers/cohere.rb
CHANGED
@@ -20,50 +20,36 @@ module Rager
|
|
20
20
|
api_key = options.api_key || ENV["COHERE_API_KEY"]
|
21
21
|
raise Rager::Errors::MissingCredentialsError.new("Cohere", "COHERE_API_KEY") if api_key.nil?
|
22
22
|
|
23
|
-
url = options.url || ENV["COHERE_URL"] || "https://api.cohere.com/v2/rerank"
|
24
|
-
model = options.model || "rerank-v3.5"
|
25
|
-
|
26
|
-
headers = {
|
27
|
-
"Content-Type" => "application/json"
|
28
|
-
}
|
29
|
-
headers["Authorization"] = "Bearer #{api_key}" if api_key
|
30
|
-
|
31
23
|
body = {
|
32
|
-
model: model,
|
24
|
+
model: options.model || "rerank-v3.5",
|
33
25
|
query: query.query,
|
34
26
|
documents: query.documents
|
35
|
-
}
|
36
|
-
|
27
|
+
}.tap do |b|
|
28
|
+
b[:top_n] = options.n if options.n
|
29
|
+
end
|
30
|
+
|
31
|
+
headers = {"Content-Type" => "application/json"}
|
32
|
+
headers["Authorization"] = "Bearer #{api_key}" if api_key
|
37
33
|
|
38
34
|
request = Rager::Http::Request.new(
|
39
35
|
verb: Rager::Http::Verb::Post,
|
40
|
-
url: url,
|
36
|
+
url: options.url || ENV["COHERE_URL"] || "https://api.cohere.com/v2/rerank",
|
41
37
|
headers: headers,
|
42
38
|
body: body.to_json
|
43
39
|
)
|
44
40
|
|
45
|
-
|
46
|
-
response = http_adapter.make_request(request)
|
41
|
+
response = Rager.config.http_adapter.make_request(request)
|
47
42
|
response_body = T.cast(T.must(response.body), String)
|
48
43
|
|
49
|
-
if response.status != 200
|
50
|
-
raise Rager::Errors::HttpError.new(
|
51
|
-
http_adapter,
|
52
|
-
response.status,
|
53
|
-
response_body
|
54
|
-
)
|
55
|
-
end
|
44
|
+
raise Rager::Errors::HttpError.new(Rager.config.http_adapter, response.status, response_body) if response.status != 200
|
56
45
|
|
57
46
|
parsed_response = JSON.parse(response_body)
|
58
|
-
|
59
47
|
results = parsed_response["results"] || []
|
60
|
-
results.map do |result|
|
61
|
-
index = result["index"]
|
62
|
-
score = result["relevance_score"]
|
63
48
|
|
49
|
+
results.map do |result|
|
64
50
|
Rager::Rerank::Result.new(
|
65
|
-
score:
|
66
|
-
index: index
|
51
|
+
score: result["relevance_score"],
|
52
|
+
index: result["index"]
|
67
53
|
)
|
68
54
|
end
|
69
55
|
end
|
data/lib/rager/rerank.rb
CHANGED
data/lib/rager/result.rb
CHANGED
@@ -2,6 +2,7 @@
|
|
2
2
|
# frozen_string_literal: true
|
3
3
|
|
4
4
|
require "json"
|
5
|
+
require "logger"
|
5
6
|
require "securerandom"
|
6
7
|
require "sorbet-runtime"
|
7
8
|
|
@@ -12,6 +13,9 @@ module Rager
|
|
12
13
|
sig { returns(String) }
|
13
14
|
attr_reader :id
|
14
15
|
|
16
|
+
sig { returns(String) }
|
17
|
+
attr_reader :context_id
|
18
|
+
|
15
19
|
sig { returns(Rager::Operation) }
|
16
20
|
attr_reader :operation
|
17
21
|
|
@@ -21,6 +25,9 @@ module Rager
|
|
21
25
|
sig { returns(Rager::Types::Input) }
|
22
26
|
attr_reader :input
|
23
27
|
|
28
|
+
sig { returns(T.nilable(Rager::Types::Output)) }
|
29
|
+
attr_reader :output
|
30
|
+
|
24
31
|
sig { returns(Integer) }
|
25
32
|
attr_reader :start_time
|
26
33
|
|
@@ -28,57 +35,62 @@ module Rager
|
|
28
35
|
attr_reader :end_time
|
29
36
|
|
30
37
|
sig { returns(T.nilable(String)) }
|
31
|
-
attr_reader :
|
38
|
+
attr_reader :name
|
32
39
|
|
33
|
-
sig { returns(T.nilable(String)) }
|
34
|
-
attr_reader :
|
40
|
+
sig { returns(T.nilable(T::Array[String])) }
|
41
|
+
attr_reader :iids
|
35
42
|
|
36
43
|
sig { returns(T.nilable(String)) }
|
37
|
-
attr_reader :
|
38
|
-
|
39
|
-
sig { returns(T.nilable(T::Array[String])) }
|
40
|
-
attr_reader :input_ids
|
44
|
+
attr_reader :error
|
41
45
|
|
42
46
|
sig do
|
43
47
|
params(
|
44
48
|
id: String,
|
49
|
+
context_id: String,
|
45
50
|
operation: Rager::Operation,
|
46
|
-
options: Rager::Options,
|
47
51
|
input: Rager::Types::Input,
|
48
|
-
output: Rager::Types::Output,
|
52
|
+
output: T.nilable(Rager::Types::Output),
|
53
|
+
options: Rager::Options,
|
49
54
|
start_time: Integer,
|
50
55
|
end_time: Integer,
|
51
|
-
|
52
|
-
|
53
|
-
|
56
|
+
name: T.nilable(String),
|
57
|
+
iids: T.nilable(T::Array[String]),
|
58
|
+
error: T.nilable(String)
|
54
59
|
).void
|
55
60
|
end
|
56
61
|
def initialize(
|
57
62
|
id:,
|
63
|
+
context_id:,
|
58
64
|
operation:,
|
59
|
-
options:,
|
60
65
|
input:,
|
61
66
|
output:,
|
67
|
+
options:,
|
62
68
|
start_time:,
|
63
69
|
end_time:,
|
64
|
-
|
65
|
-
|
66
|
-
|
70
|
+
name: nil,
|
71
|
+
iids: nil,
|
72
|
+
error: nil
|
67
73
|
)
|
68
74
|
@id = id
|
75
|
+
@context_id = context_id
|
69
76
|
@operation = operation
|
70
|
-
@options = options
|
71
77
|
@input = input
|
72
78
|
@output = output
|
79
|
+
@options = options
|
73
80
|
@start_time = start_time
|
74
81
|
@end_time = end_time
|
82
|
+
@name = T.let(name, T.nilable(String))
|
83
|
+
@iids = T.let(iids, T.nilable(T::Array[String]))
|
84
|
+
@error = T.let(error, T.nilable(String))
|
85
|
+
|
75
86
|
@stream = T.let(nil, T.nilable(Rager::Types::Stream))
|
76
87
|
@buffer = T.let([], Rager::Types::Buffer)
|
77
88
|
@consumed = T.let(false, T::Boolean)
|
78
|
-
|
79
|
-
|
80
|
-
|
81
|
-
|
89
|
+
end
|
90
|
+
|
91
|
+
sig { returns(T::Boolean) }
|
92
|
+
def success?
|
93
|
+
@error.nil?
|
82
94
|
end
|
83
95
|
|
84
96
|
sig { returns(T::Boolean) }
|
@@ -88,11 +100,9 @@ module Rager
|
|
88
100
|
|
89
101
|
sig { returns(Rager::Types::Output) }
|
90
102
|
def out
|
91
|
-
return @output unless stream?
|
103
|
+
return T.must(@output) unless stream?
|
92
104
|
return @buffer.each if @consumed
|
93
105
|
|
94
|
-
log
|
95
|
-
|
96
106
|
@stream = Enumerator.new do |yielder|
|
97
107
|
T.cast(@output, Rager::Types::Stream)
|
98
108
|
.each { |message_delta|
|
@@ -157,49 +167,41 @@ module Rager
|
|
157
167
|
end
|
158
168
|
end
|
159
169
|
|
160
|
-
sig { returns(Rager::Types::NonStreamOutput) }
|
170
|
+
sig { returns(T.nilable(Rager::Types::NonStreamOutput)) }
|
161
171
|
def serialize_output
|
162
|
-
|
163
|
-
|
164
|
-
|
165
|
-
"[STREAM]"
|
166
|
-
else
|
167
|
-
T.cast(@output, Rager::Types::NonStreamOutput)
|
168
|
-
end
|
172
|
+
return nil unless success?
|
173
|
+
return @consumed ? mat : "[STREAM]" if stream?
|
174
|
+
T.cast(@output, Rager::Types::NonStreamOutput)
|
169
175
|
end
|
170
176
|
|
171
177
|
sig { returns(T::Hash[String, T.untyped]) }
|
172
178
|
def to_h
|
173
179
|
{
|
174
180
|
id: @id,
|
181
|
+
context_id: @context_id,
|
175
182
|
operation: @operation.serialize,
|
176
|
-
options: @options.serialize_safe,
|
177
183
|
input: serialize_input,
|
178
|
-
output:
|
179
|
-
|
180
|
-
elsif stream?
|
181
|
-
"[STREAM]"
|
182
|
-
else
|
183
|
-
@output
|
184
|
-
end,
|
184
|
+
output: serialize_output,
|
185
|
+
options: @options.serialize_safe,
|
185
186
|
start_time: @start_time,
|
186
187
|
end_time: @end_time,
|
187
|
-
|
188
|
-
|
189
|
-
|
190
|
-
input_ids: @input_ids
|
188
|
+
name: @name,
|
189
|
+
iids: @iids,
|
190
|
+
error: @error
|
191
191
|
}
|
192
192
|
end
|
193
193
|
|
194
194
|
sig { void }
|
195
195
|
def log
|
196
|
-
return unless Rager.config.
|
196
|
+
return unless Rager.config.logger_type
|
197
197
|
|
198
198
|
json = to_h.to_json
|
199
199
|
|
200
|
-
|
200
|
+
logger = Rager.config.logger
|
201
|
+
|
202
|
+
case Rager.config.logger_type
|
201
203
|
when Rager::Logger::Stdout
|
202
|
-
|
204
|
+
success? ? logger.info(json) : logger.error(json)
|
203
205
|
when Rager::Logger::Remote
|
204
206
|
http_adapter = Rager.config.http_adapter
|
205
207
|
url = Rager.config.url
|
@@ -220,9 +222,7 @@ module Rager
|
|
220
222
|
|
221
223
|
response = http_adapter.make_request(request)
|
222
224
|
|
223
|
-
|
224
|
-
warn "Remote log failed: \\#{response.status} \\#{response.body}"
|
225
|
-
end
|
225
|
+
logger.warn "Remote log failed: \\#{response.status} \\#{response.body}" unless response.success?
|
226
226
|
end
|
227
227
|
end
|
228
228
|
end
|
data/lib/rager/search/options.rb
CHANGED
@@ -12,17 +12,6 @@ module Rager
|
|
12
12
|
const :provider, String, default: "brave"
|
13
13
|
const :n, T.nilable(Integer)
|
14
14
|
const :api_key, T.nilable(String)
|
15
|
-
|
16
|
-
sig { override.returns(T::Hash[String, T.untyped]) }
|
17
|
-
def serialize_safe
|
18
|
-
result = serialize
|
19
|
-
result["api_key"] = "[REDACTED]" if result.key?("api_key")
|
20
|
-
result
|
21
|
-
end
|
22
|
-
|
23
|
-
sig { override.void }
|
24
|
-
def validate
|
25
|
-
end
|
26
15
|
end
|
27
16
|
end
|
28
17
|
end
|
data/lib/rager/search/providers/brave.rb
CHANGED
@@ -21,54 +21,37 @@ module Rager
|
|
21
21
|
api_key = options.api_key || ENV["BRAVE_API_KEY"]
|
22
22
|
raise Rager::Errors::MissingCredentialsError.new("Brave", "BRAVE_API_KEY") if api_key.nil?
|
23
23
|
|
24
|
-
params = {"q" => query}
|
25
|
-
|
26
|
-
|
27
|
-
query_string = URI.encode_www_form(params)
|
28
|
-
url = "https://api.search.brave.com/res/v1/web/search?#{query_string}"
|
29
|
-
|
30
|
-
headers = {
|
31
|
-
"Accept" => "application/json",
|
32
|
-
"Accept-Encoding" => "gzip",
|
33
|
-
"x-subscription-token" => api_key
|
34
|
-
}
|
24
|
+
params = {"q" => query}.tap do |p|
|
25
|
+
p["count"] = options.n.to_s if options.n
|
26
|
+
end
|
35
27
|
|
36
28
|
request = Rager::Http::Request.new(
|
37
29
|
verb: Rager::Http::Verb::Get,
|
38
|
-
url:
|
39
|
-
headers:
|
30
|
+
url: "https://api.search.brave.com/res/v1/web/search?#{URI.encode_www_form(params)}",
|
31
|
+
headers: {
|
32
|
+
"Accept" => "application/json",
|
33
|
+
"Accept-Encoding" => "gzip",
|
34
|
+
"x-subscription-token" => api_key
|
35
|
+
}
|
40
36
|
)
|
41
37
|
|
42
|
-
|
43
|
-
response = http_adapter.make_request(request)
|
38
|
+
response = Rager.config.http_adapter.make_request(request)
|
44
39
|
response_body = T.cast(T.must(response.body), String)
|
45
40
|
|
46
|
-
if response.status != 200
|
47
|
-
raise Rager::Errors::HttpError.new(
|
48
|
-
http_adapter,
|
49
|
-
response.status,
|
50
|
-
response_body
|
51
|
-
)
|
52
|
-
end
|
41
|
+
raise Rager::Errors::HttpError.new(Rager.config.http_adapter, response.status, response_body) if response.status != 200
|
53
42
|
|
54
43
|
parsed_response = JSON.parse(response_body)
|
44
|
+
web_results = parsed_response.dig("web", "results") || []
|
55
45
|
|
56
|
-
|
57
|
-
|
58
|
-
if parsed_response["web"] && parsed_response["web"]["results"]
|
59
|
-
web_results = parsed_response["web"]["results"]
|
60
|
-
web_results.each do |result|
|
61
|
-
next unless result["title"] && result["url"] && result["description"]
|
46
|
+
web_results.filter_map do |result|
|
47
|
+
next unless result["title"] && result["url"] && result["description"]
|
62
48
|
|
63
|
-
|
64
|
-
|
65
|
-
|
66
|
-
|
67
|
-
|
68
|
-
end
|
49
|
+
Rager::Search::Result.new(
|
50
|
+
title: result["title"],
|
51
|
+
url: result["url"],
|
52
|
+
description: result["description"]
|
53
|
+
)
|
69
54
|
end
|
70
|
-
|
71
|
-
results
|
72
55
|
end
|
73
56
|
end
|
74
57
|
end
|
data/lib/rager/search.rb
CHANGED
data/lib/rager/template/options.rb
CHANGED
@@ -8,15 +8,6 @@ module Rager
|
|
8
8
|
extend T::Sig
|
9
9
|
|
10
10
|
const :provider, String, default: "erb"
|
11
|
-
|
12
|
-
sig { override.returns(T::Hash[String, T.untyped]) }
|
13
|
-
def serialize_safe
|
14
|
-
serialize.transform_keys(&:to_s)
|
15
|
-
end
|
16
|
-
|
17
|
-
sig { override.void }
|
18
|
-
def validate
|
19
|
-
end
|
20
11
|
end
|
21
12
|
end
|
22
13
|
end
|
data/lib/rager/template.rb
CHANGED
data/lib/rager/types.rb
CHANGED
data/lib/rager/utils/http.rb
ADDED
@@ -0,0 +1,39 @@
|
|
1
|
+
# typed: strict
|
2
|
+
# frozen_string_literal: true
|
3
|
+
|
4
|
+
require "sorbet-runtime"
|
5
|
+
require "uri"
|
6
|
+
|
7
|
+
module Rager
  module Utils
    # Small HTTP helpers built on top of the configured Rager HTTP adapter.
    module Http
      extend T::Sig

      # File name used when the URL path yields no usable basename.
      FALLBACK_FILENAME = "output"
      private_constant :FALLBACK_FILENAME

      # Basenames that do not name a regular file in the working directory.
      # File.basename returns "/" for a root path such as "https://host/", and
      # "." / ".." refer to directories, so writing to any of them would fail.
      UNUSABLE_FILENAMES = T.let(["", "/", ".", ".."].freeze, T::Array[String])
      private_constant :UNUSABLE_FILENAMES

      # Downloads the body of +url+ and writes it, in binary mode, to +path+.
      #
      # The request is built with Rager::Http::Request.new(url:, headers:) (no verb is
      # passed, so the Request default applies) and sent through
      # Rager.config.http_adapter.
      #
      # @param url [String] URL to fetch.
      # @param path [String, nil] destination file. When nil, the last segment of the
      #   URL path is used, falling back to "output" when the path is empty or its last
      #   segment is not a usable file name (e.g. "https://host/" or ".../..").
      # @param headers [Hash{String => String}, nil] extra request headers.
      # @return [String] the path the body was written to.
      # @raise [Rager::Errors::HttpError] when the response is not successful, and also
      #   (status 500, body "Download failed: <message>") for any other StandardError
      #   raised while parsing the URL, performing the request or writing the file.
      sig { params(url: String, path: T.nilable(String), headers: T.nilable(T::Hash[String, String])).returns(String) }
      def self.download_file(url, path = nil, headers = nil)
        path ||= default_filename(url)

        response = Rager.config.http_adapter.make_request(
          Rager::Http::Request.new(url: url, headers: headers || {})
        )

        unless response.success?
          raise Rager::Errors::HttpError.new(Rager.config.http_adapter, response.status, T.cast(response.body, T.nilable(String)))
        end

        File.binwrite(path, T.cast(T.must(response.body), String))
        path
      rescue Rager::Errors::HttpError
        raise
      rescue => e
        # Normalise every other failure into the gem's HTTP error type; Ruby records
        # the original exception as #cause of the new one.
        raise Rager::Errors::HttpError.new(Rager.config.http_adapter, 500, "Download failed: #{e.message}")
      end

      # Derives a local file name from the last segment of +url+'s path, or returns
      # the fallback name when that segment is missing or not a usable file name.
      sig { params(url: String).returns(String) }
      def self.default_filename(url)
        uri_path = URI.parse(url).path
        return FALLBACK_FILENAME if uri_path.nil? || uri_path.empty?

        filename = File.basename(uri_path)
        UNUSABLE_FILENAMES.include?(filename) ? FALLBACK_FILENAME : filename
      end
      private_class_method :default_filename
    end
  end
end
|
data/lib/rager/utils/replicate.rb
ADDED
@@ -0,0 +1,72 @@
|
|
1
|
+
# typed: strict
|
2
|
+
# frozen_string_literal: true
|
3
|
+
|
4
|
+
require "sorbet-runtime"
|
5
|
+
require "json"
|
6
|
+
|
7
|
+
module Rager
  module Utils
    # Helpers for retrieving the results of Replicate predictions.
    module Replicate
      extend T::Sig

      # Fetches a Replicate prediction and, once it has succeeded, downloads its output.
      #
      # Authenticates with the REPLICATE_API_KEY environment variable as a bearer token.
      #
      # @param prediction_url [String] URL of the prediction resource to GET.
      # @param key [String, nil] when the output is a Hash, the entry to download; if
      #   that entry is absent or nil, the first non-nil value is used instead.
      # @param path [String, nil] destination file, forwarded to
      #   Rager::Utils::Http.download_file.
      # @return [String, nil] the written file path, or nil when the prediction is still
      #   "starting"/"processing", reports a status not handled below, or an unexpected
      #   non-Rager error occurs.
      # @raise [Rager::Errors::MissingCredentialsError] when REPLICATE_API_KEY is unset.
      # @raise [Rager::Errors::HttpError] on a non-success response, a "failed" or
      #   "canceled" prediction, or a failed download.
      # @raise [Rager::Errors::ParseError] when the response is not valid JSON, or a
      #   succeeded prediction has no downloadable output.
      sig { params(prediction_url: String, key: T.nilable(String), path: T.nilable(String)).returns(T.nilable(String)) }
      def self.download_prediction(prediction_url, key = nil, path = nil)
        api_key = ENV["REPLICATE_API_KEY"]
        raise Rager::Errors::MissingCredentialsError.new("Replicate", "REPLICATE_API_KEY") if api_key.nil?

        response = Rager.config.http_adapter.make_request(
          Rager::Http::Request.new(
            url: prediction_url,
            headers: {"Authorization" => "Bearer #{api_key}", "Content-Type" => "application/json"}
          )
        )

        raise Rager::Errors::HttpError.new(Rager.config.http_adapter, response.status, T.cast(response.body, T.nilable(String))) unless response.success?

        body = T.cast(T.must(response.body), String)
        data = JSON.parse(body)

        case data["status"]
        when "starting", "processing"
          nil
        when "failed"
          error_msg = data["error"] || "Prediction failed"
          raise Rager::Errors::HttpError.new(Rager.config.http_adapter, 422, "Prediction failed: #{error_msg}")
        when "canceled"
          # Terminal state: returning nil would make poll_prediction keep waiting
          # for a prediction that will never finish.
          raise Rager::Errors::HttpError.new(Rager.config.http_adapter, 422, "Prediction canceled")
        when "succeeded"
          download_url = extract_download_url(data["output"], key)
          if download_url.nil?
            # Also terminal: nil would be read as "not ready yet" by poll_prediction.
            raise Rager::Errors::ParseError.new("Prediction succeeded without a downloadable output", body)
          end

          Rager::Utils::Http.download_file(download_url, path)
        end
      rescue JSON::ParserError => e
        raise Rager::Errors::ParseError.new("Failed to parse prediction response", e.message)
      rescue Rager::Errors::HttpError, Rager::Errors::ParseError, Rager::Errors::MissingCredentialsError
        raise
      rescue
        # NOTE(review): any other StandardError (e.g. from the HTTP adapter) is reported
        # as "not ready", presumably so poll_prediction retries transient failures.
        nil
      end

      # Picks the URL to download from a prediction's "output" value.
      #
      # Hash outputs use +key+ when it is present, otherwise the first non-nil value;
      # Array outputs use their first element; any other non-nil value is stringified.
      # Returns nil unless the chosen candidate is a non-empty String.
      sig { params(output: T.untyped, key: T.nilable(String)).returns(T.nilable(String)) }
      def self.extract_download_url(output, key)
        candidate =
          case output
          when nil then nil
          when Hash then (key && output[key]) || output.values.compact.first
          when Array then output.first
          else output.to_s
          end

        candidate.is_a?(String) && !candidate.empty? ? candidate : nil
      end
      private_class_method :extract_download_url

      # Repeatedly calls download_prediction until it returns a path.
      #
      # @param max_attempts [Integer] number of download_prediction calls before giving up.
      # @param sleep_interval [Integer, Float] seconds to wait between attempts (no wait
      #   follows the final attempt).
      # @return [String, nil] the downloaded file path.
      # @raise [Rager::Errors::HttpError] with status 408 when every attempt returned nil,
      #   plus anything download_prediction raises.
      sig { params(prediction_url: String, key: T.nilable(String), path: T.nilable(String), max_attempts: Integer, sleep_interval: T.any(Integer, Float)).returns(T.nilable(String)) }
      def self.poll_prediction(prediction_url, key: nil, path: nil, max_attempts: 30, sleep_interval: 10)
        max_attempts.times do |attempt|
          result = download_prediction(prediction_url, key, path)
          return result unless result.nil?

          # Waiting after the last attempt would only delay the timeout error.
          sleep(sleep_interval) if attempt < max_attempts - 1
        end

        raise Rager::Errors::HttpError.new(Rager.config.http_adapter, 408, "Prediction polling timed out after #{max_attempts} attempts")
      end
    end
  end
end
|
data/lib/rager/version.rb
CHANGED