rager 0.2.1 → 0.4.0

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -0,0 +1,97 @@
1
+ # typed: strict
2
+ # frozen_string_literal: true
3
+
4
+ require "sorbet-runtime"
5
+ require "json"
6
+
7
module Rager
  module Utils
    # Helpers for locating and downloading the output of a Replicate prediction.
    #
    # Status requests are authenticated with the REPLICATE_API_KEY environment
    # variable and sent through the supplied HTTP adapter, falling back to
    # Rager.config.http_adapter when none is given.
    module Replicate
      extend T::Sig

      # Best-effort download of a prediction's output.
      #
      # @param prediction_url [String] URL of the prediction resource to query
      # @param key [String, nil] preferred key when the prediction output is a Hash
      # @param http_adapter [Rager::Http::Adapters::Abstract, nil] adapter used for the status request
      # @return [String, nil] the result of Rager::Utils::Http.download_binary, or nil when no
      #   download URL is available or any StandardError is raised along the way
      sig { params(prediction_url: String, key: T.nilable(String), http_adapter: T.nilable(Rager::Http::Adapters::Abstract)).returns(T.nilable(String)) }
      def self.download_prediction(prediction_url, key = nil, http_adapter: nil)
        download_url = get_download_url(prediction_url, key, http_adapter: http_adapter)
        return nil if download_url.nil?

        Rager::Utils::Http.download_binary(download_url)
      rescue
        # Deliberately best-effort: callers of this method get nil on any failure.
        # Use get_download_url or poll_prediction when failures must be visible.
        nil
      end

      # Best-effort download of a prediction's output to a file.
      #
      # @param prediction_url [String] URL of the prediction resource to query
      # @param key [String, nil] preferred key when the prediction output is a Hash
      # @param path [String] destination handed to Rager::Utils::Http.download_file
      # @param http_adapter [Rager::Http::Adapters::Abstract, nil] adapter used for the status request
      # @return [String, nil] the result of Rager::Utils::Http.download_file, or nil when no
      #   download URL is available or any StandardError is raised along the way
      sig { params(prediction_url: String, key: T.nilable(String), path: String, http_adapter: T.nilable(Rager::Http::Adapters::Abstract)).returns(T.nilable(String)) }
      def self.download_prediction_to_file(prediction_url, key = nil, path:, http_adapter: nil)
        download_url = get_download_url(prediction_url, key, http_adapter: http_adapter)
        return nil if download_url.nil?

        Rager::Utils::Http.download_file(download_url, path)
      rescue
        # Deliberately best-effort, mirroring download_prediction.
        nil
      end

      # Queries the prediction and returns the URL of its output.
      #
      # @param prediction_url [String] URL of the prediction resource to query
      # @param key [String, nil] preferred key when the prediction output is a Hash
      # @param http_adapter [Rager::Http::Adapters::Abstract, nil] adapter used for the request
      # @return [String, nil] the output URL, or nil when the prediction has not succeeded
      #   yet, has no usable output, or an unexpected StandardError occurred
      # @raise [Rager::Errors::MissingCredentialsError] when REPLICATE_API_KEY is not set
      # @raise [Rager::Errors::HttpError] when the request is unsuccessful, or with status 422
      #   when the prediction reports the "failed" status
      # @raise [Rager::Errors::ParseError] when the response body is not valid JSON
      sig { params(prediction_url: String, key: T.nilable(String), http_adapter: T.nilable(Rager::Http::Adapters::Abstract)).returns(T.nilable(String)) }
      def self.get_download_url(prediction_url, key = nil, http_adapter: nil)
        data = fetch_prediction(prediction_url, http_adapter)
        resolve_download_url(data, key, http_adapter)
      rescue JSON::ParserError => e
        raise Rager::Errors::ParseError.new("Failed to parse prediction response", e.message)
      rescue Rager::Errors::HttpError, Rager::Errors::ParseError, Rager::Errors::MissingCredentialsError
        raise
      rescue
        # Anything else (e.g. a response that is not a JSON object) is treated as "no URL".
        nil
      end

      # Repeatedly checks the prediction until its output can be downloaded.
      #
      # Unsuccessful status requests, unparseable responses and download errors are
      # retried on the next attempt. A prediction that reports "failed" and a missing
      # API key are raised immediately instead of being polled until the timeout.
      #
      # @param prediction_url [String] URL of the prediction resource to query
      # @param key [String, nil] preferred key when the prediction output is a Hash
      # @param path [String, nil] when given, the output is downloaded to this path
      # @param max_attempts [Integer] number of attempts before giving up
      # @param sleep_interval [Integer] seconds to sleep between attempts
      # @param http_adapter [Rager::Http::Adapters::Abstract, nil] adapter used for status requests
      # @return [String, nil] the first non-nil download result
      # @raise [Rager::Errors::MissingCredentialsError] when REPLICATE_API_KEY is not set
      # @raise [Rager::Errors::HttpError] with status 422 when the prediction failed, or with
      #   status 408 when no result was obtained within max_attempts
      sig { params(prediction_url: String, key: T.nilable(String), path: T.nilable(String), max_attempts: Integer, sleep_interval: Integer, http_adapter: T.nilable(Rager::Http::Adapters::Abstract)).returns(T.nilable(String)) }
      def self.poll_prediction(prediction_url, key: nil, path: nil, max_attempts: 30, sleep_interval: 10, http_adapter: nil)
        max_attempts.times do |attempt|
          result = attempt_download(prediction_url, key, path, http_adapter)
          return result unless result.nil?

          # No point in waiting after the final attempt; the timeout is raised right away.
          sleep(sleep_interval) if attempt < max_attempts - 1
        end

        adapter = http_adapter || Rager.config.http_adapter
        raise Rager::Errors::HttpError.new(adapter, 408, "Prediction polling timed out after #{max_attempts} attempts")
      end

      # Requests the prediction resource and returns its parsed JSON body.
      # Raises MissingCredentialsError, HttpError (unsuccessful response) or JSON::ParserError.
      sig { params(prediction_url: String, http_adapter: T.nilable(Rager::Http::Adapters::Abstract)).returns(T.untyped) }
      def self.fetch_prediction(prediction_url, http_adapter)
        api_key = ENV["REPLICATE_API_KEY"]
        raise Rager::Errors::MissingCredentialsError.new("Replicate", "REPLICATE_API_KEY") if api_key.nil?

        adapter = http_adapter || Rager.config.http_adapter
        response = adapter.make_request(
          Rager::Http::Request.new(
            url: prediction_url,
            headers: {"Authorization" => "Bearer #{api_key}", "Content-Type" => "application/json"}
          )
        )

        raise Rager::Errors::HttpError.new(adapter, response.status, T.cast(response.body, T.nilable(String))) unless response.success?

        JSON.parse(T.cast(T.must(response.body), String))
      end

      # Turns parsed prediction data into an output URL.
      # Raises HttpError (422) for a "failed" prediction and returns nil for any status
      # other than "succeeded" or when the output holds no non-empty String.
      sig { params(data: T.untyped, key: T.nilable(String), http_adapter: T.nilable(Rager::Http::Adapters::Abstract)).returns(T.nilable(String)) }
      def self.resolve_download_url(data, key, http_adapter)
        status = data["status"]

        if status == "failed"
          error_msg = data["error"] || "Prediction failed"
          adapter = http_adapter || Rager.config.http_adapter
          raise Rager::Errors::HttpError.new(adapter, 422, "Prediction failed: #{error_msg}")
        end

        # NOTE(review): any other non-"succeeded" status (including a possible
        # "canceled") yields nil, so poll_prediction keeps polling it until the
        # timeout - confirm whether further statuses should be terminal.
        return nil unless status == "succeeded"

        output = data["output"]
        candidate =
          case output
          when Hash
            (key && output[key]) || output.values.compact.first
          when Array
            # An empty array yields nil here rather than the literal string "[]".
            output.compact.first
          else
            output.to_s
          end

        # Only a non-empty String satisfies the declared return type.
        return nil unless candidate.is_a?(String) && !candidate.empty?

        candidate
      end

      # Performs one polling attempt: status request, URL resolution, then download.
      # Returns nil when the attempt should be retried.
      sig { params(prediction_url: String, key: T.nilable(String), path: T.nilable(String), http_adapter: T.nilable(Rager::Http::Adapters::Abstract)).returns(T.nilable(String)) }
      def self.attempt_download(prediction_url, key, path, http_adapter)
        data =
          begin
            fetch_prediction(prediction_url, http_adapter)
          rescue Rager::Errors::MissingCredentialsError
            # Retrying cannot fix a missing API key.
            raise
          rescue
            # Unsuccessful or unparseable status response: retry on the next attempt.
            nil
          end
        return nil if data.nil?

        download_url =
          begin
            resolve_download_url(data, key, http_adapter)
          rescue Rager::Errors::HttpError
            # The prediction itself failed; polling further would only delay the error.
            raise
          rescue
            nil
          end
        return nil if download_url.nil?

        begin
          if path
            Rager::Utils::Http.download_file(download_url, path)
          else
            Rager::Utils::Http.download_binary(download_url)
          end
        rescue
          # Download errors are retried on the next attempt.
          nil
        end
      end

      private_class_method :fetch_prediction, :resolve_download_url, :attempt_download
    end
  end
end
data/lib/rager/version.rb CHANGED
@@ -2,5 +2,5 @@
2
2
  # frozen_string_literal: true
3
3
 
4
4
  module Rager
5
- VERSION = "0.2.1"
5
+ VERSION = "0.4.0"
6
6
  end
metadata CHANGED
@@ -1,7 +1,7 @@
1
1
  --- !ruby/object:Gem::Specification
2
2
  name: rager
3
3
  version: !ruby/object:Gem::Version
4
- version: 0.2.1
4
+ version: 0.4.0
5
5
  platform: ruby
6
6
  authors:
7
7
  - mvkvc
@@ -106,6 +106,7 @@ files:
106
106
  - lib/rager/http/verb.rb
107
107
  - lib/rager/image_gen.rb
108
108
  - lib/rager/image_gen/options.rb
109
+ - lib/rager/image_gen/output_format.rb
109
110
  - lib/rager/image_gen/providers/abstract.rb
110
111
  - lib/rager/image_gen/providers/replicate.rb
111
112
  - lib/rager/logger.rb
@@ -133,8 +134,10 @@ files:
133
134
  - lib/rager/template/providers/abstract.rb
134
135
  - lib/rager/template/providers/erb.rb
135
136
  - lib/rager/types.rb
137
+ - lib/rager/utils/http.rb
138
+ - lib/rager/utils/replicate.rb
136
139
  - lib/rager/version.rb
137
- homepage: https://github.com/mvkvc/rager_ruby
140
+ homepage: https://github.com/mvkvc/rager_rb
138
141
  licenses:
139
142
  - MIT
140
143
  metadata: {}
@@ -145,7 +148,7 @@ required_ruby_version: !ruby/object:Gem::Requirement
145
148
  requirements:
146
149
  - - ">="
147
150
  - !ruby/object:Gem::Version
148
- version: '3.1'
151
+ version: '3.2'
149
152
  required_rubygems_version: !ruby/object:Gem::Requirement
150
153
  requirements:
151
154
  - - ">="