rager 0.1.0

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (44) hide show
  1. checksums.yaml +7 -0
  2. data/LICENSE.md +21 -0
  3. data/README.md +23 -0
  4. data/lib/rager/chat/message.rb +13 -0
  5. data/lib/rager/chat/message_content.rb +14 -0
  6. data/lib/rager/chat/message_content_image_type.rb +16 -0
  7. data/lib/rager/chat/message_content_type.rb +16 -0
  8. data/lib/rager/chat/message_delta.rb +13 -0
  9. data/lib/rager/chat/message_role.rb +16 -0
  10. data/lib/rager/chat/options.rb +52 -0
  11. data/lib/rager/chat/providers/abstract.rb +25 -0
  12. data/lib/rager/chat/providers/openai.rb +196 -0
  13. data/lib/rager/chat/schema.rb +48 -0
  14. data/lib/rager/chat.rb +35 -0
  15. data/lib/rager/config.rb +30 -0
  16. data/lib/rager/context.rb +116 -0
  17. data/lib/rager/error.rb +7 -0
  18. data/lib/rager/errors/http_error.rb +19 -0
  19. data/lib/rager/errors/missing_credentials_error.rb +19 -0
  20. data/lib/rager/errors/options_error.rb +19 -0
  21. data/lib/rager/errors/parse_error.rb +19 -0
  22. data/lib/rager/errors/unknown_provider_error.rb +17 -0
  23. data/lib/rager/http/adapters/abstract.rb +20 -0
  24. data/lib/rager/http/adapters/async_http.rb +65 -0
  25. data/lib/rager/http/adapters/mock.rb +138 -0
  26. data/lib/rager/http/request.rb +15 -0
  27. data/lib/rager/http/response.rb +14 -0
  28. data/lib/rager/http/verb.rb +20 -0
  29. data/lib/rager/image_gen/options.rb +29 -0
  30. data/lib/rager/image_gen/providers/abstract.rb +25 -0
  31. data/lib/rager/image_gen/providers/replicate.rb +55 -0
  32. data/lib/rager/image_gen.rb +31 -0
  33. data/lib/rager/logger.rb +13 -0
  34. data/lib/rager/mesh_gen/options.rb +30 -0
  35. data/lib/rager/mesh_gen/providers/abstract.rb +25 -0
  36. data/lib/rager/mesh_gen/providers/replicate.rb +61 -0
  37. data/lib/rager/mesh_gen.rb +31 -0
  38. data/lib/rager/operation.rb +14 -0
  39. data/lib/rager/options.rb +21 -0
  40. data/lib/rager/result.rb +198 -0
  41. data/lib/rager/types.rb +45 -0
  42. data/lib/rager/version.rb +6 -0
  43. data/lib/rager.rb +35 -0
  44. metadata +123 -0
@@ -0,0 +1,19 @@
1
+ # typed: strict
2
+ # frozen_string_literal: true
3
+
4
+ require "sorbet-runtime"
5
+
6
module Rager
  module Errors
    # Raised when a provider is called but no credential could be found.
    class MissingCredentialsError < Rager::Error
      extend T::Sig

      # @param provider_name [String] provider name shown in the message
      # @param env_var [String, nil] environment variable that was consulted, if any;
      #   when given it is appended to the message
      sig { params(provider_name: String, env_var: T.nilable(String)).void }
      def initialize(provider_name, env_var = nil)
        lookup_hint = env_var ? " -- attempted lookup with #{env_var}" : ""
        super("Missing credentials for provider #{provider_name}#{lookup_hint}")
      end
    end
  end
end
@@ -0,0 +1,19 @@
1
+ # typed: strict
2
+ # frozen_string_literal: true
3
+
4
+ require "sorbet-runtime"
5
+
6
module Rager
  module Errors
    # Raised when an options object contains invalid keys.
    class OptionsError < Rager::Error
      extend T::Sig

      # @param invalid_keys [Array<String>] offending option keys, listed comma-separated
      # @param description [String, nil] optional explanation appended after " -- "
      sig { params(invalid_keys: T::Array[String], description: T.nilable(String)).void }
      def initialize(invalid_keys:, description: nil)
        message = "Invalid keys #{invalid_keys.join(", ")}"
        # Append only the description itself (no trailing punctuation).
        message += " -- #{description}" if description
        super(message)
      end
    end
  end
end
@@ -0,0 +1,19 @@
1
+ # typed: strict
2
+ # frozen_string_literal: true
3
+
4
+ require "sorbet-runtime"
5
+
6
module Rager
  module Errors
    # Raised when a provider response cannot be parsed into the expected shape.
    class ParseError < Rager::Error
      extend T::Sig

      # @param description [String] what went wrong
      # @param body [String, nil] raw payload that failed to parse; appended after " -- "
      sig { params(description: String, body: T.nilable(String)).void }
      def initialize(description, body = nil)
        segments = [description]
        segments << body if body
        super(segments.join(" -- "))
      end
    end
  end
end
@@ -0,0 +1,17 @@
1
+ # typed: strict
2
+ # frozen_string_literal: true
3
+
4
+ require "sorbet-runtime"
5
+
6
module Rager
  module Errors
    # Raised when no provider implementation matches the requested key.
    class UnknownProviderError < Rager::Error
      extend T::Sig

      # @param operation [Rager::Operation] operation being dispatched; its serialized
      #   value is included in the message
      # @param key [String] provider key that did not match
      sig { params(operation: Rager::Operation, key: String).void }
      def initialize(operation, key)
        operation_name = operation.serialize
        super("Unknown provider #{key} for operation #{operation_name}")
      end
    end
  end
end
@@ -0,0 +1,20 @@
1
+ # typed: strict
2
+ # frozen_string_literal: true
3
+
4
+ require "sorbet-runtime"
5
+
6
module Rager
  module Http
    module Adapters
      # Interface for HTTP transports; concrete adapters override #make_request.
      class Abstract
        extend T::Sig
        extend T::Helpers

        abstract!

        # Performs +request+ and returns the resulting response.
        sig do
          abstract
            .params(request: Rager::Http::Request)
            .returns(Rager::Http::Response)
        end
        def make_request(request); end
      end
    end
  end
end
@@ -0,0 +1,65 @@
1
+ # typed: strict
2
+ # frozen_string_literal: true
3
+
4
+ require "sorbet-runtime"
5
+
6
module Rager
  module Http
    module Adapters
      # HTTP adapter backed by the `async-http` gem.
      #
      # Responses detected as streaming (chunked transfer encoding or a
      # `text/event-stream` content type) are returned as an Enumerator of UTF-8
      # chunks; all other bodies are read completely into a UTF-8 String.
      class AsyncHttp < Rager::Http::Adapters::Abstract
        extend T::Sig

        sig { void }
        def initialize
          # Required here so the dependency is only loaded when this adapter is built.
          require "async/http"

          @internet = T.let(Async::HTTP::Internet.new, Async::HTTP::Internet)
        end

        # Sends +request+ and wraps the result in a Rager::Http::Response.
        #
        # @param request [Rager::Http::Request]
        # @return [Rager::Http::Response] body is nil, a String, or an Enumerator of chunks
        sig {
          override.params(
            request: Rager::Http::Request
          ).returns(Rager::Http::Response)
        }
        def make_request(request)
          response = @internet.call(
            request.verb.serialize,
            request.url,
            request.headers.to_a,
            request.body
          )

          # NOTE(review): protocol-http indexes headers by lower-cased name, so the
          # mixed-case "Transfer-Encoding" lookup may never match — confirm before relying on it.
          body = if response.body.nil?
            nil
          elsif (response.headers["Transfer-Encoding"]&.downcase == "chunked") ||
              response.headers["content-type"]&.downcase&.include?("text/event-stream")
            body_enum(response)
          else
            # `response.body.read` yields only the next available chunk; `response.read`
            # drains the entire body. Copy it as UTF-8 to match the streaming path.
            full_body = response.read
            full_body && String.new(full_body, encoding: Encoding::UTF_8)
          end

          Response.new(
            status: response.status,
            headers: response.headers.to_h,
            body: body
          )
        end

        private

        # Lazily yields body chunks re-tagged as UTF-8, closing the response once
        # iteration finishes or is aborted.
        sig {
          params(
            response: Async::HTTP::Protocol::Response
          ).returns(T::Enumerator[String])
        }
        def body_enum(response)
          Enumerator.new do |yielder|
            response.body.each { |chunk| yielder << chunk.force_encoding("UTF-8") }
          ensure
            response.close
          end
        end
      end
    end
  end
end
@@ -0,0 +1,138 @@
1
+ # typed: strict
2
+ # frozen_string_literal: true
3
+
4
+ require "fileutils"
5
+ require "json"
6
+ require "sorbet-runtime"
7
+
8
module Rager
  module Http
    module Adapters
      # Record/replay HTTP adapter intended for tests.
      #
      # Responses are stored in a JSON file keyed by `request.serialize.to_json`.
      # On a cache miss the request goes to a fallback adapter (AsyncHttp unless
      # one is supplied), the response is recorded, and the file is rewritten.
      # NOTE(review): the key includes request headers (e.g. Authorization), so
      # recorded fixtures may contain credentials.
      class Mock < Rager::Http::Adapters::Abstract
        extend T::Sig

        Cache = T.type_alias { T::Hash[String, T::Hash[String, T.untyped]] }

        # @param test_file_path [String] fixture file; it and its directory are created if absent
        # @param fallback_adapter [Rager::Http::Adapters::Abstract, nil] used on cache misses
        # @param chunk_delimiter [String, nil] accepted but not used by this class
        sig do
          params(
            test_file_path: String,
            fallback_adapter: T.nilable(Rager::Http::Adapters::Abstract),
            chunk_delimiter: T.nilable(String)
          ).void
        end
        def initialize(test_file_path, fallback_adapter = nil, chunk_delimiter = nil)
          @test_file_path = T.let(test_file_path, String)
          @fallback_adapter = T.let(
            fallback_adapter || Rager::Http::Adapters::AsyncHttp.new,
            Rager::Http::Adapters::Abstract
          )
          @cache = T.let(load_cache, Cache)
        end

        # Replays a recorded response when one exists, otherwise records a new one.
        sig { override.params(request: Rager::Http::Request).returns(Rager::Http::Response) }
        def make_request(request)
          cache_key = request.serialize.to_json
          recorded = @cache[cache_key]
          return build_response_from_cache(recorded) if recorded

          fetch_and_cache_response(request, cache_key)
        end

        # Sends +request+ through the fallback adapter, records the result under
        # +key+, persists the cache, and returns an equivalent response. Streaming
        # bodies are drained up front so they can be both stored and replayed.
        sig {
          params(
            request: Rager::Http::Request,
            key: String
          ).returns(Rager::Http::Response)
        }
        def fetch_and_cache_response(request, key)
          live = @fallback_adapter.make_request(request)
          streamed = live.body.is_a?(Enumerator)

          if streamed
            collected = T.cast(live.body, T::Enumerator[String]).to_a
            stored_body = collected
            replay_body = Enumerator.new do |yielder|
              collected.each { |piece| yielder << piece }
            end
          else
            stored_body = live.body
            replay_body = live.body
          end

          @cache[key] = {
            "status" => live.status,
            "headers" => live.headers,
            "body" => stored_body,
            "is_stream" => streamed
          }
          save_cache

          Rager::Http::Response.new(
            status: live.status,
            headers: live.headers,
            body: replay_body
          )
        end

        # Rebuilds a response from a recorded entry; stream bodies (stored as an
        # array of chunks) come back as an Enumerator.
        sig {
          params(
            entry: T::Hash[String, T.untyped]
          ).returns(Rager::Http::Response)
        }
        def build_response_from_cache(entry)
          stored = entry["body"]
          body = entry["is_stream"] ? T.cast(stored, T::Array[String]).to_enum : stored

          Rager::Http::Response.new(
            status: T.cast(entry["status"], Integer),
            headers: T.cast(entry["headers"], T::Hash[String, String]),
            body: body
          )
        end

        # Ensures the fixture's directory exists and seeds the file with "{}" if missing.
        sig { void }
        def create_file_if_not_exists
          FileUtils.mkdir_p(File.dirname(@test_file_path))
          return if File.exist?(@test_file_path)

          File.write(@test_file_path, "{}")
        end

        # Reads the fixture file (creating it first if needed) and parses it as JSON.
        sig { returns(Cache) }
        def load_cache
          create_file_if_not_exists
          JSON.parse(File.read(@test_file_path, encoding: "UTF-8"))
        end

        # Merges the in-memory cache over whatever is currently on disk and writes
        # the result back as UTF-8 JSON.
        sig { void }
        def save_cache
          on_disk = File.exist?(@test_file_path) ? JSON.parse(File.read(@test_file_path, encoding: "UTF-8")) : {}
          serialized = JSON.generate(on_disk.merge(@cache)).force_encoding("UTF-8")
          File.write(@test_file_path, serialized, encoding: "UTF-8")
        end
      end
    end
  end
end
@@ -0,0 +1,15 @@
1
+ # typed: strict
2
+ # frozen_string_literal: true
3
+
4
+ require "sorbet-runtime"
5
+
6
module Rager
  module Http
    # Immutable description of an outgoing HTTP request.
    #
    # Keep the field order stable: `serialize` output (used by the Mock adapter
    # as a cache key) follows declaration order.
    class Request < T::Struct
      # Target URL.
      const :url, String
      # HTTP method; GET unless specified.
      const :verb, Rager::Http::Verb, default: Rager::Http::Verb::Get
      # Header name => value.
      const :headers, T::Hash[String, String]
      # Optional raw payload.
      const :body, T.nilable(String)
    end
  end
end
@@ -0,0 +1,14 @@
1
+ # typed: strict
2
+ # frozen_string_literal: true
3
+
4
+ require "sorbet-runtime"
5
+
6
module Rager
  module Http
    # Immutable result returned by an HTTP adapter.
    class Response < T::Struct
      # Numeric HTTP status code.
      const :status, Integer
      # Header name => single value or list of values.
      const :headers, T::Hash[String, T.any(String, T::Array[String])]
      # Whole body as a String, a chunk Enumerator for streams, or nil.
      const :body, T.nilable(T.any(String, T::Enumerator[String]))
    end
  end
end
@@ -0,0 +1,20 @@
1
+ # typed: strict
2
+ # frozen_string_literal: true
3
+
4
+ require "sorbet-runtime"
5
+
6
module Rager
  module Http
    # HTTP request methods; each serializes to its upper-case method name.
    class Verb < T::Enum
      enums do
        Get = new("GET")
        Post = new("POST")
        Put = new("PUT")
        Patch = new("PATCH")
        Delete = new("DELETE")
        Head = new("HEAD")
        Options = new("OPTIONS")
      end
    end
  end
end
@@ -0,0 +1,29 @@
1
+ # typed: strict
2
+ # frozen_string_literal: true
3
+
4
+ require "sorbet-runtime"
5
+
6
module Rager
  module ImageGen
    # Settings for an image generation call.
    class Options < T::Struct
      extend T::Sig
      include Rager::Options

      const :provider, String, default: "replicate"
      const :model, String, default: "black-forest-labs/flux-schnell"
      const :api_key, T.nilable(String)
      const :seed, T.nilable(Integer)

      # Serialized options with any API key replaced by "[REDACTED]".
      sig { override.returns(T::Hash[String, T.untyped]) }
      def serialize_safe
        serialize.tap do |props|
          props["api_key"] = "[REDACTED]" if props.key?("api_key")
        end
      end

      # No constraints are currently enforced for these options.
      sig { override.void }
      def validate; end
    end
  end
end
@@ -0,0 +1,25 @@
1
+ # typed: strict
2
+ # frozen_string_literal: true
3
+
4
+ require "sorbet-runtime"
5
+
6
module Rager
  module ImageGen
    module Providers
      # Interface for image generation backends.
      class Abstract
        extend T::Sig
        extend T::Helpers

        abstract!

        # Produces an image for +prompt+ using +options+.
        sig {
          abstract
            .params(prompt: String, options: Rager::ImageGen::Options)
            .returns(Rager::Types::ImageGenOutput)
        }
        def image_gen(prompt, options); end
      end
    end
  end
end
@@ -0,0 +1,55 @@
1
+ # typed: strict
2
+ # frozen_string_literal: true
3
+
4
+ require "sorbet-runtime"
5
+ require "json"
6
+
7
module Rager
  module ImageGen
    module Providers
      # Image generation through Replicate's per-model predictions endpoint.
      class Replicate < Rager::ImageGen::Providers::Abstract
        extend T::Sig

        # Creates a prediction for +prompt+ and returns its image output.
        #
        # @param prompt [String] sent as `input.prompt`
        # @param options [Rager::ImageGen::Options] model, optional API key, optional seed
        # @return the first element of `output` when it is an array, otherwise `output` itself
        # @raise [Rager::Errors::MissingCredentialsError] when no API key is available
        # @raise [Rager::Errors::HttpError] when the status is not 200 or 201
        # @raise [Rager::Errors::ParseError] when the body is not JSON, has no `output`
        #   key, or the output is null/empty
        sig { override.params(prompt: String, options: Rager::ImageGen::Options).returns(Rager::Types::ImageGenOutput) }
        def image_gen(prompt, options)
          url = "https://api.replicate.com/v1/models/#{options.model}/predictions"
          api_key = options.api_key || ENV["REPLICATE_API_KEY"]
          raise Rager::Errors::MissingCredentialsError.new("Replicate", "REPLICATE_API_KEY") if api_key.nil?

          headers = {
            "Authorization" => "Bearer #{api_key}",
            "Content-Type" => "application/json",
            # Ask Replicate to hold the request open until the prediction completes or times out.
            "Prefer" => "wait"
          }

          body = {
            input: {
              prompt: prompt
            }
          }
          body[:input][:seed] = options.seed unless options.seed.nil?

          request = Rager::Http::Request.new(
            url: url,
            verb: Rager::Http::Verb::Post,
            headers: headers,
            body: body.to_json
          )

          http_adapter = Rager.config.http_adapter
          response = http_adapter.make_request(request)
          response_body = T.cast(T.must(response.body), String)

          raise Rager::Errors::HttpError.new(http_adapter, response.status, response_body) unless [200, 201].include?(response.status)

          begin
            output = JSON.parse(response_body).fetch("output")
          rescue JSON::ParserError, KeyError => e
            raise Rager::Errors::ParseError.new(e.message, response_body)
          end

          # `output` can be null (prediction not finished when the wait window ended),
          # an array of URLs, or a single URL string depending on the model. Calling
          # `.first` unconditionally would raise NoMethodError for null or string outputs.
          image = output.is_a?(Array) ? output.first : output
          raise Rager::Errors::ParseError.new("Replicate response did not include an image output", response_body) if image.nil?

          image
        end
      end
    end
  end
end
@@ -0,0 +1,31 @@
1
+ # typed: strict
2
+ # frozen_string_literal: true
3
+
4
+ require "sorbet-runtime"
5
+
6
module Rager
  module ImageGen
    extend T::Sig

    # Generates an image for +prompt+ with the provider named in +options+.
    sig do
      params(
        prompt: String,
        options: Rager::ImageGen::Options
      ).returns(Rager::Types::ImageGenOutput)
    end
    def self.image_gen(prompt, options = Rager::ImageGen::Options.new)
      get_provider(options.provider).image_gen(prompt, options)
    end

    # Maps a provider key (case-insensitive) to a provider instance.
    #
    # @raise [Rager::Errors::UnknownProviderError] for unrecognized keys
    sig { params(key: String).returns(Rager::ImageGen::Providers::Abstract) }
    def self.get_provider(key)
      if key.downcase == "replicate"
        Rager::ImageGen::Providers::Replicate.new
      else
        raise Rager::Errors::UnknownProviderError.new(Rager::Operation::ImageGen, key)
      end
    end
  end
end
@@ -0,0 +1,13 @@
1
+ # typed: strict
2
+ # frozen_string_literal: true
3
+
4
+ require "sorbet-runtime"
5
+
6
module Rager
  # Logger kinds, serialized as "stdout" and "remote"; how each kind is used is
  # defined elsewhere in the gem.
  class Logger < T::Enum
    enums do
      Stdout = new("stdout")
      Remote = new("remote")
    end
  end
end
@@ -0,0 +1,30 @@
1
+ # typed: strict
2
+ # frozen_string_literal: true
3
+
4
+ require "sorbet-runtime"
5
+
6
module Rager
  module MeshGen
    # Settings for a mesh generation call.
    class Options < T::Struct
      extend T::Sig
      include Rager::Options

      const :provider, String, default: "replicate"
      const :model, String, default: "firoz/trellis"
      const :version, String, default: "4876f2a8da1c544772dffa32e8889da4a1bab3a1f5c1937bfcfccb99ae347251"
      const :api_key, T.nilable(String)
      const :seed, T.nilable(Integer)

      # Serialized options with any API key replaced by "[REDACTED]".
      sig { override.returns(T::Hash[String, T.untyped]) }
      def serialize_safe
        serialize.tap do |props|
          props["api_key"] = "[REDACTED]" if props.key?("api_key")
        end
      end

      # No constraints are currently enforced for these options.
      sig { override.void }
      def validate; end
    end
  end
end
@@ -0,0 +1,25 @@
1
+ # typed: strict
2
+ # frozen_string_literal: true
3
+
4
+ require "sorbet-runtime"
5
+
6
module Rager
  module MeshGen
    module Providers
      # Interface for mesh generation backends.
      class Abstract
        extend T::Sig
        extend T::Helpers

        abstract!

        # Produces a mesh from the image at +image_url+ using +options+.
        sig {
          abstract
            .params(image_url: String, options: Rager::MeshGen::Options)
            .returns(Rager::Types::MeshGenOutput)
        }
        def mesh_gen(image_url, options); end
      end
    end
  end
end
@@ -0,0 +1,61 @@
1
+ # typed: strict
2
+ # frozen_string_literal: true
3
+
4
+ require "sorbet-runtime"
5
+ require "json"
6
+
7
module Rager
  module MeshGen
    module Providers
      # Mesh generation through Replicate's versioned predictions endpoint.
      class Replicate < Rager::MeshGen::Providers::Abstract
        extend T::Sig

        # Starts a prediction for the image at +image_url+.
        #
        # @param image_url [String] sent as the single entry of `input.images`
        # @param options [Rager::MeshGen::Options] model version, optional API key, optional seed
        # @return `output.model_file` when present, otherwise the prediction's `urls.get` URL
        # @raise [Rager::Errors::MissingCredentialsError] when no API key is available
        # @raise [Rager::Errors::HttpError] when the status is not 200, 201 or 202
        # @raise [Rager::Errors::ParseError] when the body is not JSON or lacks `urls.get`
        sig { override.params(image_url: String, options: Rager::MeshGen::Options).returns(Rager::Types::MeshGenOutput) }
        def mesh_gen(image_url, options)
          api_key = options.api_key || ENV["REPLICATE_API_KEY"]
          raise Rager::Errors::MissingCredentialsError.new("Replicate", "REPLICATE_API_KEY") if api_key.nil?

          headers = {
            "Authorization" => "Bearer #{api_key}",
            "Content-Type" => "application/json",
            "Prefer" => "wait"
          }

          body = {
            version: options.version,
            input: {
              images: [image_url],
              texture_size: 2048,
              mesh_simplify: 0.9,
              generate_model: true,
              save_gaussian_ply: true,
              ss_sampling_steps: 38
            }
          }
          body[:input][:seed] = options.seed unless options.seed.nil?

          request = Rager::Http::Request.new(
            url: "https://api.replicate.com/v1/predictions",
            verb: Rager::Http::Verb::Post,
            headers: headers,
            body: body.to_json
          )

          # Resolve the adapter once so the one reported in HttpError is the one that
          # actually made the request (matches the ImageGen Replicate provider).
          http_adapter = Rager.config.http_adapter
          response = http_adapter.make_request(request)
          response_body = T.cast(T.must(response.body), String)

          unless [200, 201, 202].include?(response.status)
            raise Rager::Errors::HttpError.new(http_adapter, response.status, response_body)
          end

          begin
            json = JSON.parse(response_body)
            # Use the finished model file when available; otherwise hand back the polling URL.
            json.dig("output", "model_file") || json.fetch("urls").fetch("get")
          rescue JSON::ParserError, KeyError => e
            raise Rager::Errors::ParseError.new(e.message, response_body)
          end
        end
      end
    end
  end
end
@@ -0,0 +1,31 @@
1
+ # typed: strict
2
+ # frozen_string_literal: true
3
+
4
+ require "sorbet-runtime"
5
+
6
module Rager
  module MeshGen
    extend T::Sig

    # Generates a mesh from +image_url+ with the provider named in +options+.
    sig do
      params(
        image_url: String,
        options: Rager::MeshGen::Options
      ).returns(Rager::Types::MeshGenOutput)
    end
    def self.mesh_gen(image_url, options = Rager::MeshGen::Options.new)
      get_provider(options.provider).mesh_gen(image_url, options)
    end

    # Maps a provider key (case-insensitive) to a provider instance.
    #
    # @raise [Rager::Errors::UnknownProviderError] for unrecognized keys
    sig { params(key: String).returns(Rager::MeshGen::Providers::Abstract) }
    def self.get_provider(key)
      if key.downcase == "replicate"
        Rager::MeshGen::Providers::Replicate.new
      else
        raise Rager::Errors::UnknownProviderError.new(Rager::Operation::MeshGen, key)
      end
    end
  end
end
@@ -0,0 +1,14 @@
1
+ # typed: strict
2
+ # frozen_string_literal: true
3
+
4
+ require "sorbet-runtime"
5
+
6
module Rager
  # Operations the library dispatches; the serialized value (e.g. "image_gen")
  # is what appears in UnknownProviderError messages.
  class Operation < T::Enum
    enums do
      Chat = new("chat")
      ImageGen = new("image_gen")
      MeshGen = new("mesh_gen")
    end
  end
end