rager 0.1.0
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- checksums.yaml +7 -0
- data/LICENSE.md +21 -0
- data/README.md +23 -0
- data/lib/rager/chat/message.rb +13 -0
- data/lib/rager/chat/message_content.rb +14 -0
- data/lib/rager/chat/message_content_image_type.rb +16 -0
- data/lib/rager/chat/message_content_type.rb +16 -0
- data/lib/rager/chat/message_delta.rb +13 -0
- data/lib/rager/chat/message_role.rb +16 -0
- data/lib/rager/chat/options.rb +52 -0
- data/lib/rager/chat/providers/abstract.rb +25 -0
- data/lib/rager/chat/providers/openai.rb +196 -0
- data/lib/rager/chat/schema.rb +48 -0
- data/lib/rager/chat.rb +35 -0
- data/lib/rager/config.rb +30 -0
- data/lib/rager/context.rb +116 -0
- data/lib/rager/error.rb +7 -0
- data/lib/rager/errors/http_error.rb +19 -0
- data/lib/rager/errors/missing_credentials_error.rb +19 -0
- data/lib/rager/errors/options_error.rb +19 -0
- data/lib/rager/errors/parse_error.rb +19 -0
- data/lib/rager/errors/unknown_provider_error.rb +17 -0
- data/lib/rager/http/adapters/abstract.rb +20 -0
- data/lib/rager/http/adapters/async_http.rb +65 -0
- data/lib/rager/http/adapters/mock.rb +138 -0
- data/lib/rager/http/request.rb +15 -0
- data/lib/rager/http/response.rb +14 -0
- data/lib/rager/http/verb.rb +20 -0
- data/lib/rager/image_gen/options.rb +29 -0
- data/lib/rager/image_gen/providers/abstract.rb +25 -0
- data/lib/rager/image_gen/providers/replicate.rb +55 -0
- data/lib/rager/image_gen.rb +31 -0
- data/lib/rager/logger.rb +13 -0
- data/lib/rager/mesh_gen/options.rb +30 -0
- data/lib/rager/mesh_gen/providers/abstract.rb +25 -0
- data/lib/rager/mesh_gen/providers/replicate.rb +61 -0
- data/lib/rager/mesh_gen.rb +31 -0
- data/lib/rager/operation.rb +14 -0
- data/lib/rager/options.rb +21 -0
- data/lib/rager/result.rb +198 -0
- data/lib/rager/types.rb +45 -0
- data/lib/rager/version.rb +6 -0
- data/lib/rager.rb +35 -0
- metadata +123 -0
@@ -0,0 +1,19 @@
|
|
1
|
+
# typed: strict
|
2
|
+
# frozen_string_literal: true
|
3
|
+
|
4
|
+
require "sorbet-runtime"
|
5
|
+
|
6
|
+
module Rager
  module Errors
    # Raised when a provider requires an API credential and none could be
    # found in the supplied options (or the environment variable consulted).
    class MissingCredentialsError < Rager::Error
      extend T::Sig

      # @param provider_name [String] display name of the provider, e.g. "Replicate"
      # @param env_var [String, nil] environment variable that was checked, if any
      sig { params(provider_name: String, env_var: T.nilable(String)).void }
      def initialize(provider_name, env_var = nil)
        segments = ["Missing credentials for provider #{provider_name}"]
        segments << "attempted lookup with #{env_var}" if env_var
        super(segments.join(" -- "))
      end
    end
  end
end
|
@@ -0,0 +1,19 @@
|
|
1
|
+
# typed: strict
|
2
|
+
# frozen_string_literal: true
|
3
|
+
|
4
|
+
require "sorbet-runtime"
|
5
|
+
|
6
|
+
module Rager
  module Errors
    # Raised when an options struct contains keys that are not valid for the
    # selected provider or model.
    class OptionsError < Rager::Error
      extend T::Sig

      # @param invalid_keys [Array<String>] the offending option keys
      # @param description [String, nil] optional extra context appended after " -- "
      sig { params(invalid_keys: T::Array[String], description: T.nilable(String)).void }
      def initialize(invalid_keys:, description: nil)
        message = "Invalid keys #{invalid_keys.join(", ")}"
        # Append the description verbatim; no trailing punctuation is added.
        message += " -- #{description}" if description
        super(message)
      end
    end
  end
end
|
@@ -0,0 +1,19 @@
|
|
1
|
+
# typed: strict
|
2
|
+
# frozen_string_literal: true
|
3
|
+
|
4
|
+
require "sorbet-runtime"
|
5
|
+
|
6
|
+
module Rager
  module Errors
    # Raised when a provider response cannot be parsed into the expected shape.
    class ParseError < Rager::Error
      extend T::Sig

      # @param description [String] what went wrong while parsing
      # @param body [String, nil] raw response body, appended after " -- " when present
      sig { params(description: String, body: T.nilable(String)).void }
      def initialize(description, body = nil)
        super(body ? "#{description} -- #{body}" : description)
      end
    end
  end
end
|
@@ -0,0 +1,17 @@
|
|
1
|
+
# typed: strict
|
2
|
+
# frozen_string_literal: true
|
3
|
+
|
4
|
+
require "sorbet-runtime"
|
5
|
+
|
6
|
+
module Rager
  module Errors
    # Raised when an operation is asked to use a provider name it does not know.
    class UnknownProviderError < Rager::Error
      extend T::Sig

      # @param operation [Rager::Operation] the operation being dispatched
      # @param key [String] the provider name that failed to resolve
      sig { params(operation: Rager::Operation, key: String).void }
      def initialize(operation, key)
        operation_name = operation.serialize
        super("Unknown provider #{key} for operation #{operation_name}")
      end
    end
  end
end
|
@@ -0,0 +1,20 @@
|
|
1
|
+
# typed: strict
|
2
|
+
# frozen_string_literal: true
|
3
|
+
|
4
|
+
require "sorbet-runtime"
|
5
|
+
|
6
|
+
module Rager
  module Http
    module Adapters
      # Interface that every HTTP transport used by Rager implements.
      class Abstract
        extend T::Sig
        extend T::Helpers

        abstract!

        # Performs +request+ and returns the resulting response.
        sig do
          abstract
            .params(request: Rager::Http::Request)
            .returns(Rager::Http::Response)
        end
        def make_request(request); end
      end
    end
  end
end
|
@@ -0,0 +1,65 @@
|
|
1
|
+
# typed: strict
|
2
|
+
# frozen_string_literal: true
|
3
|
+
|
4
|
+
require "sorbet-runtime"
|
5
|
+
|
6
|
+
module Rager
  module Http
    module Adapters
      # HTTP adapter backed by the async-http gem.
      #
      # Event-stream responses are returned as a lazy Enumerator of UTF-8
      # chunks (the response is closed when iteration ends). All other
      # responses are read completely into a UTF-8 String.
      class AsyncHttp < Rager::Http::Adapters::Abstract
        extend T::Sig

        sig { void }
        def initialize
          # Required here rather than at load time so async-http is only
          # needed by callers that actually instantiate this adapter.
          require "async/http"

          @internet = T.let(Async::HTTP::Internet.new, Async::HTTP::Internet)
        end

        # Sends +request+ and wraps the result in a Rager::Http::Response.
        sig {
          override.params(
            request: Rager::Http::Request
          ).returns(Rager::Http::Response)
        }
        def make_request(request)
          response = @internet.call(
            request.verb.serialize,
            request.url,
            request.headers.to_a,
            request.body
          )

          # NOTE(review): "Transfer-Encoding" is looked up in mixed case while
          # "content-type" is lower case; confirm how async-http normalizes
          # header keys before relying on the chunked branch. Enabling it would
          # turn buffered JSON replies into enumerators for callers that
          # expect a String.
          body = if response.body.nil?
            nil
          elsif (response.headers["Transfer-Encoding"]&.downcase == "chunked") ||
              response.headers["content-type"]&.downcase&.include?("text/event-stream")
            body_enum(response)
          else
            # Response#read drains every chunk and closes the body, whereas
            # response.body.read returns only the next chunk and would
            # truncate payloads delivered in more than one read.
            response.read&.force_encoding("UTF-8")
          end

          Response.new(
            status: response.status,
            headers: response.headers.to_h,
            body: body
          )
        end

        private

        # Lazily yields body chunks re-tagged as UTF-8 and closes the response
        # once iteration finishes or raises.
        sig {
          params(
            response: Async::HTTP::Protocol::Response
          ).returns(T::Enumerator[String])
        }
        def body_enum(response)
          Enumerator.new do |yielder|
            response.body.each { |chunk| yielder << chunk.force_encoding("UTF-8") }
          ensure
            response.close
          end
        end
      end
    end
  end
end
|
@@ -0,0 +1,138 @@
|
|
1
|
+
# typed: strict
|
2
|
+
# frozen_string_literal: true
|
3
|
+
|
4
|
+
require "fileutils"
|
5
|
+
require "json"
|
6
|
+
require "sorbet-runtime"
|
7
|
+
|
8
|
+
module Rager
  module Http
    module Adapters
      # Record/replay HTTP adapter for tests.
      #
      # Each request is keyed by the JSON of Rager::Http::Request#serialize.
      # On a hit the stored response is replayed. On a miss the request goes
      # to the fallback adapter and the response is persisted to the JSON file
      # at +test_file_path+.
      #
      # NOTE(review): because the key includes request headers, credentials
      # such as an "Authorization" bearer token are written verbatim into the
      # recording file. Scrub recordings before committing them.
      class Mock < Rager::Http::Adapters::Abstract
        extend T::Sig

        # Request key (JSON string) => recorded response with the keys
        # "status", "headers", "body" and "is_stream".
        Cache = T.type_alias { T::Hash[String, T::Hash[String, T.untyped]] }

        # @param test_file_path [String] recording file; created with its parent directories if missing
        # @param fallback_adapter [Rager::Http::Adapters::Abstract, nil] used on cache misses;
        #   defaults to a new AsyncHttp adapter
        # @param chunk_delimiter [String, nil] accepted for API compatibility; not currently used
        sig do
          params(
            test_file_path: String,
            fallback_adapter: T.nilable(Rager::Http::Adapters::Abstract),
            chunk_delimiter: T.nilable(String)
          ).void
        end
        def initialize(
          test_file_path,
          fallback_adapter = nil,
          chunk_delimiter = nil
        )
          @test_file_path = T.let(test_file_path, String)
          @fallback_adapter = T.let(fallback_adapter || Rager::Http::Adapters::AsyncHttp.new, Rager::Http::Adapters::Abstract)
          @cache = T.let(load_cache, Cache)
        end

        # Replays the recorded response for +request+, or records a new one.
        sig { override.params(request: Rager::Http::Request).returns(Rager::Http::Response) }
        def make_request(request)
          key = request.serialize.to_json
          cached_entry = @cache[key]
          return build_response_from_cache(cached_entry) if cached_entry

          fetch_and_cache_response(request, key)
        end

        # Performs +request+ on the fallback adapter and records the result
        # under +key+ (in memory and on disk). A streamed body is drained
        # completely, then handed back as an Enumerator over the recorded chunks.
        sig {
          params(
            request: Rager::Http::Request,
            key: String
          ).returns(Rager::Http::Response)
        }
        def fetch_and_cache_response(request, key)
          response = @fallback_adapter.make_request(request)
          live_body = response.body

          serialized_response = T.let({
            "status" => response.status,
            "headers" => response.headers
          }, T::Hash[String, T.untyped])

          if live_body.is_a?(Enumerator)
            chunks = T.cast(live_body, T::Enumerator[String]).to_a
            serialized_response["body"] = chunks
            serialized_response["is_stream"] = true
            response_body = chunks.to_enum
          else
            serialized_response["body"] = live_body
            serialized_response["is_stream"] = false
            response_body = live_body
          end

          @cache[key] = serialized_response
          save_cache

          Rager::Http::Response.new(
            status: response.status,
            headers: response.headers,
            body: response_body
          )
        end

        # Rebuilds a Response from a recorded entry; stream bodies are replayed
        # as an Enumerator over the stored chunks.
        sig {
          params(
            entry: T::Hash[String, T.untyped]
          ).returns(Rager::Http::Response)
        }
        def build_response_from_cache(entry)
          body = entry["body"]
          body = T.cast(body, T::Array[String]).to_enum if entry["is_stream"]

          Rager::Http::Response.new(
            status: T.cast(entry["status"], Integer),
            headers: T.cast(entry["headers"], T::Hash[String, String]),
            body: body
          )
        end

        # Ensures the recording file and its directory exist, seeding it with "{}".
        sig { void }
        def create_file_if_not_exists
          FileUtils.mkdir_p(File.dirname(@test_file_path))
          File.write(@test_file_path, "{}") unless File.exist?(@test_file_path)
        end

        # Loads the recording file into memory, creating it first if needed.
        sig { returns(Cache) }
        def load_cache
          create_file_if_not_exists
          read_cache_file
        end

        # Writes the in-memory cache back to disk. Entries already on disk are
        # merged in first, and in-memory entries win on key conflicts.
        sig { void }
        def save_cache
          output = read_cache_file.merge(@cache)
          json_string = JSON.generate(output).force_encoding("UTF-8")
          File.write(@test_file_path, json_string, encoding: "UTF-8")
        end

        private

        # Parses the recording file. A missing or blank (e.g. freshly touched)
        # file is treated as an empty cache instead of raising JSON::ParserError.
        sig { returns(Cache) }
        def read_cache_file
          return {} unless File.exist?(@test_file_path)

          content = File.read(@test_file_path, encoding: "UTF-8")
          content.strip.empty? ? {} : JSON.parse(content)
        end
      end
    end
  end
end
|
@@ -0,0 +1,15 @@
|
|
1
|
+
# typed: strict
|
2
|
+
# frozen_string_literal: true
|
3
|
+
|
4
|
+
require "sorbet-runtime"
|
5
|
+
|
6
|
+
module Rager
  module Http
    # Immutable description of an outgoing HTTP request.
    #
    # Keep the prop order stable: the Mock adapter keys its cache on
    # `request.serialize.to_json`, so reordering props may change the keys
    # of existing recordings.
    class Request < T::Struct
      const(:url, String)
      const(:verb, Rager::Http::Verb, default: Rager::Http::Verb::Get)
      const(:headers, T::Hash[String, String])
      const(:body, T.nilable(String))
    end
  end
end
|
@@ -0,0 +1,14 @@
|
|
1
|
+
# typed: strict
|
2
|
+
# frozen_string_literal: true
|
3
|
+
|
4
|
+
require "sorbet-runtime"
|
5
|
+
|
6
|
+
module Rager
  module Http
    # Result of an HTTP call. +body+ is a String for buffered responses, an
    # Enumerator of String chunks for streamed ones, or nil when absent.
    class Response < T::Struct
      const(:status, Integer)
      const(:headers, T::Hash[String, T.any(String, T::Array[String])])
      const(:body, T.nilable(T.any(String, T::Enumerator[String])))
    end
  end
end
|
@@ -0,0 +1,20 @@
|
|
1
|
+
# typed: strict
|
2
|
+
# frozen_string_literal: true
|
3
|
+
|
4
|
+
require "sorbet-runtime"
|
5
|
+
|
6
|
+
module Rager
  module Http
    # HTTP request methods. Each value serializes to the upper-case method
    # token, which the AsyncHttp adapter passes straight to the client.
    class Verb < T::Enum
      enums do
        Get = new("GET")
        Post = new("POST")
        Put = new("PUT")
        Patch = new("PATCH")
        Delete = new("DELETE")
        Head = new("HEAD")
        Options = new("OPTIONS")
      end
    end
  end
end
|
@@ -0,0 +1,29 @@
|
|
1
|
+
# typed: strict
|
2
|
+
# frozen_string_literal: true
|
3
|
+
|
4
|
+
require "sorbet-runtime"
|
5
|
+
|
6
|
+
module Rager
  module ImageGen
    # Options accepted by Rager::ImageGen.image_gen.
    class Options < T::Struct
      extend T::Sig
      include Rager::Options

      const :provider, String, default: "replicate"
      const :model, String, default: "black-forest-labs/flux-schnell"
      const :api_key, T.nilable(String)
      const :seed, T.nilable(Integer)

      # Serialized options with the API key masked, safe to log.
      sig { override.returns(T::Hash[String, T.untyped]) }
      def serialize_safe
        serialize.tap do |hash|
          hash["api_key"] = "[REDACTED]" if hash.key?("api_key")
        end
      end

      # No constraints are enforced for image generation options.
      sig { override.void }
      def validate; end
    end
  end
end
|
@@ -0,0 +1,25 @@
|
|
1
|
+
# typed: strict
|
2
|
+
# frozen_string_literal: true
|
3
|
+
|
4
|
+
require "sorbet-runtime"
|
5
|
+
|
6
|
+
module Rager
  module ImageGen
    module Providers
      # Interface for image-generation providers.
      class Abstract
        extend T::Sig
        extend T::Helpers

        abstract!

        # Generates an image for +prompt+ according to +options+.
        sig { abstract.params(prompt: String, options: Rager::ImageGen::Options).returns(Rager::Types::ImageGenOutput) }
        def image_gen(prompt, options); end
      end
    end
  end
end
|
@@ -0,0 +1,55 @@
|
|
1
|
+
# typed: strict
|
2
|
+
# frozen_string_literal: true
|
3
|
+
|
4
|
+
require "sorbet-runtime"
|
5
|
+
require "json"
|
6
|
+
|
7
|
+
module Rager
  module ImageGen
    module Providers
      # Image generation backed by Replicate's per-model predictions endpoint.
      class Replicate < Rager::ImageGen::Providers::Abstract
        extend T::Sig

        # Runs +options.model+ on Replicate for +prompt+ and returns the first output entry.
        #
        # @raise [Rager::Errors::MissingCredentialsError] if neither options.api_key nor REPLICATE_API_KEY is set
        # @raise [Rager::Errors::HttpError] if the response status is not 200 or 201
        # @raise [Rager::Errors::ParseError] if the body is not JSON, has no "output" key, or the output is null/empty
        sig { override.params(prompt: String, options: Rager::ImageGen::Options).returns(Rager::Types::ImageGenOutput) }
        def image_gen(prompt, options)
          url = "https://api.replicate.com/v1/models/#{options.model}/predictions"
          api_key = options.api_key || ENV["REPLICATE_API_KEY"]
          raise Rager::Errors::MissingCredentialsError.new("Replicate", "REPLICATE_API_KEY") if api_key.nil?

          headers = {
            "Authorization" => "Bearer #{api_key}",
            "Content-Type" => "application/json",
            # Ask Replicate to hold the connection open until the prediction completes.
            "Prefer" => "wait"
          }

          body = {
            input: {
              prompt: prompt
            }
          }
          body[:input][:seed] = options.seed unless options.seed.nil?

          request = Rager::Http::Request.new(
            url: url,
            verb: Rager::Http::Verb::Post,
            headers: headers,
            body: body.to_json
          )

          http_adapter = Rager.config.http_adapter
          response = http_adapter.make_request(request)
          response_body = T.cast(T.must(response.body), String)

          raise Rager::Errors::HttpError.new(http_adapter, response.status, response_body) unless [200, 201].include?(response.status)

          output = begin
            JSON.parse(response_body).fetch("output")
          rescue JSON::ParserError, KeyError => e
            raise Rager::Errors::ParseError.new(e.message, response_body)
          end

          # "output" may be null when the prediction had not finished by the
          # time the server stopped waiting. Some models also return a single
          # value rather than a list. Neither case should escape as a
          # NoMethodError.
          first_output = output.is_a?(Array) ? output.first : output
          raise Rager::Errors::ParseError.new("Replicate response contained no output", response_body) if first_output.nil?

          first_output
        end
      end
    end
  end
end
|
@@ -0,0 +1,31 @@
|
|
1
|
+
# typed: strict
|
2
|
+
# frozen_string_literal: true
|
3
|
+
|
4
|
+
require "sorbet-runtime"
|
5
|
+
|
6
|
+
module Rager
  module ImageGen
    extend T::Sig

    # Generates an image for +prompt+ with the provider named in +options+.
    sig do
      params(
        prompt: String,
        options: Rager::ImageGen::Options
      ).returns(Rager::Types::ImageGenOutput)
    end
    def self.image_gen(prompt, options = Rager::ImageGen::Options.new)
      get_provider(options.provider).image_gen(prompt, options)
    end

    # Maps a provider name (case-insensitive) to a provider instance.
    #
    # @raise [Rager::Errors::UnknownProviderError] for unrecognized names
    sig { params(key: String).returns(Rager::ImageGen::Providers::Abstract) }
    def self.get_provider(key)
      return Rager::ImageGen::Providers::Replicate.new if key.downcase == "replicate"

      raise Rager::Errors::UnknownProviderError.new(Rager::Operation::ImageGen, key)
    end
  end
end
|
data/lib/rager/mesh_gen/options.rb
ADDED
@@ -0,0 +1,30 @@
|
|
1
|
+
# typed: strict
|
2
|
+
# frozen_string_literal: true
|
3
|
+
|
4
|
+
require "sorbet-runtime"
|
5
|
+
|
6
|
+
module Rager
  module MeshGen
    # Options accepted by Rager::MeshGen.mesh_gen.
    #
    # +version+ is sent as the Replicate model version by the default provider.
    class Options < T::Struct
      extend T::Sig
      include Rager::Options

      const :provider, String, default: "replicate"
      const :model, String, default: "firoz/trellis"
      const :version, String, default: "4876f2a8da1c544772dffa32e8889da4a1bab3a1f5c1937bfcfccb99ae347251"
      const :api_key, T.nilable(String)
      const :seed, T.nilable(Integer)

      # Serialized options with the API key masked, safe to log.
      sig { override.returns(T::Hash[String, T.untyped]) }
      def serialize_safe
        serialize.tap do |hash|
          hash["api_key"] = "[REDACTED]" if hash.key?("api_key")
        end
      end

      # No constraints are enforced for mesh generation options.
      sig { override.void }
      def validate; end
    end
  end
end
|
@@ -0,0 +1,25 @@
|
|
1
|
+
# typed: strict
|
2
|
+
# frozen_string_literal: true
|
3
|
+
|
4
|
+
require "sorbet-runtime"
|
5
|
+
|
6
|
+
module Rager
  module MeshGen
    module Providers
      # Interface for providers that build a 3D mesh from an image.
      class Abstract
        extend T::Sig
        extend T::Helpers

        abstract!

        # Generates a mesh from the image at +image_url+ according to +options+.
        sig { abstract.params(image_url: String, options: Rager::MeshGen::Options).returns(Rager::Types::MeshGenOutput) }
        def mesh_gen(image_url, options); end
      end
    end
  end
end
|
@@ -0,0 +1,61 @@
|
|
1
|
+
# typed: strict
|
2
|
+
# frozen_string_literal: true
|
3
|
+
|
4
|
+
require "sorbet-runtime"
|
5
|
+
require "json"
|
6
|
+
|
7
|
+
module Rager
  module MeshGen
    module Providers
      # Mesh generation backed by Replicate's versioned predictions endpoint.
      class Replicate < Rager::MeshGen::Providers::Abstract
        extend T::Sig

        # Submits +image_url+ to the model version in +options.version+.
        #
        # Returns output["model_file"] when the prediction finished within the
        # wait window. Otherwise returns the prediction's "urls"/"get" URL,
        # which can be polled later.
        #
        # @raise [Rager::Errors::MissingCredentialsError] if neither options.api_key nor REPLICATE_API_KEY is set
        # @raise [Rager::Errors::HttpError] if the response status is not 200, 201 or 202
        # @raise [Rager::Errors::ParseError] if the body is not JSON or lacks both a model file and urls.get
        sig { override.params(image_url: String, options: Rager::MeshGen::Options).returns(Rager::Types::MeshGenOutput) }
        def mesh_gen(image_url, options)
          api_key = options.api_key || ENV["REPLICATE_API_KEY"]
          raise Rager::Errors::MissingCredentialsError.new("Replicate", "REPLICATE_API_KEY") if api_key.nil?

          headers = {
            "Authorization" => "Bearer #{api_key}",
            "Content-Type" => "application/json",
            # Ask Replicate to hold the connection open until the prediction completes.
            "Prefer" => "wait"
          }

          body = {
            version: options.version,
            input: {
              images: [image_url],
              texture_size: 2048,
              mesh_simplify: 0.9,
              generate_model: true,
              save_gaussian_ply: true,
              ss_sampling_steps: 38
            }
          }
          body[:input][:seed] = options.seed unless options.seed.nil?

          request = Rager::Http::Request.new(
            url: "https://api.replicate.com/v1/predictions",
            verb: Rager::Http::Verb::Post,
            headers: headers,
            body: body.to_json
          )

          # Resolve the adapter once so the request and any HttpError refer to
          # the same instance, matching the image-generation provider.
          http_adapter = Rager.config.http_adapter
          response = http_adapter.make_request(request)
          response_body = T.cast(T.must(response.body), String)

          unless [200, 201, 202].include?(response.status)
            raise Rager::Errors::HttpError.new(http_adapter, response.status, response_body)
          end

          begin
            json = JSON.parse(response_body)
            output = json["output"]
            # "output" is null while the prediction is still running. Only
            # index into it when it is a Hash so other shapes fall through to
            # urls.get instead of raising TypeError.
            model_file = output["model_file"] if output.is_a?(Hash)
            model_file || json.fetch("urls").fetch("get")
          rescue JSON::ParserError, KeyError => e
            raise Rager::Errors::ParseError.new(e.message, response_body)
          end
        end
      end
    end
  end
end
|
@@ -0,0 +1,31 @@
|
|
1
|
+
# typed: strict
|
2
|
+
# frozen_string_literal: true
|
3
|
+
|
4
|
+
require "sorbet-runtime"
|
5
|
+
|
6
|
+
module Rager
  module MeshGen
    extend T::Sig

    # Generates a mesh from +image_url+ with the provider named in +options+.
    sig do
      params(
        image_url: String,
        options: Rager::MeshGen::Options
      ).returns(Rager::Types::MeshGenOutput)
    end
    def self.mesh_gen(image_url, options = Rager::MeshGen::Options.new)
      get_provider(options.provider).mesh_gen(image_url, options)
    end

    # Maps a provider name (case-insensitive) to a provider instance.
    #
    # @raise [Rager::Errors::UnknownProviderError] for unrecognized names
    sig { params(key: String).returns(Rager::MeshGen::Providers::Abstract) }
    def self.get_provider(key)
      return Rager::MeshGen::Providers::Replicate.new if key.downcase == "replicate"

      raise Rager::Errors::UnknownProviderError.new(Rager::Operation::MeshGen, key)
    end
  end
end
|