replicate-ruby 0.1.0

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
checksums.yaml ADDED
@@ -0,0 +1,7 @@
1
+ ---
2
+ SHA256:
3
+ metadata.gz: 7c894d68b6c7c1dcc1bfdb5fbe0067a1485c3e918029f3d1fd5139ae8acbcc8b
4
+ data.tar.gz: 00de848d1fc16197959283ae65c136e1723eca8c15a2a0556bb08d84e9a06a2f
5
+ SHA512:
6
+ metadata.gz: b3c7f2edbba1a0e3c572e7f1f9e614336e8dd015f7683f094fda645f942c75b8abf9dac43719625761ad5b5ebbb00e65ef5f71741c1b1c0d12f2e7d3c72df75e
7
+ data.tar.gz: ddced9835e3e37acd8c132be0036519b63b4a5adca1332ee33aef2dd75a3830112df937492aa90e5337007db24252bacdae1cc404e063e11ac3e0608c8498566
data/.rubocop.yml ADDED
@@ -0,0 +1,13 @@
1
+ AllCops:
2
+ TargetRubyVersion: 3.1
3
+
4
+ Style/StringLiterals:
5
+ Enabled: true
6
+ EnforcedStyle: double_quotes
7
+
8
+ Style/StringLiteralsInInterpolation:
9
+ Enabled: true
10
+ EnforcedStyle: double_quotes
11
+
12
+ Layout/LineLength:
13
+ Max: 120
data/CHANGELOG.md ADDED
@@ -0,0 +1,5 @@
1
+ ## [Unreleased]
2
+
3
+ ## [0.1.0] - 2022-10-12
4
+
5
+ - Initial release
data/Gemfile ADDED
@@ -0,0 +1,12 @@
1
+ # frozen_string_literal: true
2
+
3
source "https://rubygems.org"

# Runtime and development dependencies are declared in replicate-ruby.gemspec.
gemspec

# Tooling used by the Rakefile's default task (test suite + lint).
gem "minitest", "~> 5.0"
gem "rake", "~> 13.0"
gem "rubocop", "~> 1.21"
data/Gemfile.lock ADDED
@@ -0,0 +1,71 @@
1
+ PATH
2
+ remote: .
3
+ specs:
4
+ replicate-ruby (0.1.0)
5
+ addressable
6
+ faraday (>= 2.0)
7
+ faraday-retry
8
+
9
+ GEM
10
+ remote: https://rubygems.org/
11
+ specs:
12
+ addressable (2.8.1)
13
+ public_suffix (>= 2.0.2, < 6.0)
14
+ ast (2.4.2)
15
+ coderay (1.1.3)
16
+ crack (0.4.5)
17
+ rexml
18
+ faraday (2.6.0)
19
+ faraday-net_http (>= 2.0, < 3.1)
20
+ ruby2_keywords (>= 0.0.4)
21
+ faraday-net_http (3.0.1)
22
+ faraday-retry (2.0.0)
23
+ faraday (~> 2.0)
24
+ hashdiff (1.0.1)
25
+ json (2.6.2)
26
+ method_source (1.0.0)
27
+ minitest (5.16.3)
28
+ parallel (1.22.1)
29
+ parser (3.1.2.1)
30
+ ast (~> 2.4.1)
31
+ pry (0.14.1)
32
+ coderay (~> 1.1)
33
+ method_source (~> 1.0)
34
+ public_suffix (5.0.0)
35
+ rainbow (3.1.1)
36
+ rake (13.0.6)
37
+ regexp_parser (2.6.0)
38
+ rexml (3.2.5)
39
+ rubocop (1.36.0)
40
+ json (~> 2.3)
41
+ parallel (~> 1.10)
42
+ parser (>= 3.1.2.1)
43
+ rainbow (>= 2.2.2, < 4.0)
44
+ regexp_parser (>= 1.8, < 3.0)
45
+ rexml (>= 3.2.5, < 4.0)
46
+ rubocop-ast (>= 1.20.1, < 2.0)
47
+ ruby-progressbar (~> 1.7)
48
+ unicode-display_width (>= 1.4.0, < 3.0)
49
+ rubocop-ast (1.21.0)
50
+ parser (>= 3.1.1.0)
51
+ ruby-progressbar (1.11.0)
52
+ ruby2_keywords (0.0.5)
53
+ unicode-display_width (2.3.0)
54
+ webmock (3.18.1)
55
+ addressable (>= 2.8.0)
56
+ crack (>= 0.3.2)
57
+ hashdiff (>= 0.4.0, < 2.0.0)
58
+
59
+ PLATFORMS
60
+ x86_64-darwin-21
61
+
62
+ DEPENDENCIES
63
+ minitest (~> 5.0)
64
+ pry
65
+ rake (~> 13.0)
66
+ replicate-ruby!
67
+ rubocop (~> 1.21)
68
+ webmock
69
+
70
+ BUNDLED WITH
71
+ 2.3.7
data/README.md ADDED
@@ -0,0 +1,63 @@
1
+ # Replicate Ruby client
2
+
3
+ This is a Ruby client for [Replicate](https://replicate.com). It lets you fetch models and model versions, run predictions, and check on or cancel them from your Ruby code.
4
+
5
+ ## Installation
6
+
7
+ Add this line to your application's Gemfile:
8
+
9
+ ```ruby
10
+ gem 'replicate-ruby'
11
+ ```
12
+
13
+ ## Usage
14
+
15
+ Grab your API token from [replicate.com/account](https://replicate.com/account) and authenticate by configuring `api_token`:
16
+
17
+ ```ruby
18
+ Replicate.configure do |config|
19
+ config.api_token = "your_api_token"
20
+ end
21
+ ```
22
+
23
+ You can retrieve a model:
24
+
25
+ ```ruby
26
+ # Latest version
27
+ model = Replicate.client.retrieve_model("stability-ai/stable-diffusion")
28
+ version = model.latest_version
29
+
30
+ # List of versions
31
+ versions = Replicate.client.retrieve_model("stability-ai/stable-diffusion", version: :all)["results"]
32
+
33
+ # Specific version
34
+ version = Replicate.client.retrieve_model("stability-ai/stable-diffusion", version: "<id>")
35
+ ```
36
+
37
+ And then run predictions on it:
38
+
39
+ ```ruby
40
+ prediction = version.predict(prompt: "a handsome teddy bear")
41
+
42
+ # Optionally you can submit a webhook url for replicate to send a POST request once a prediction has completed
43
+ prediction = version.predict({ prompt: "a handsome teddy bear" }, "https://webhook.url/path")
44
+
45
+ # Or manually refetch predictions
46
+ prediction = prediction.refetch
47
+
48
+ # or cancel a running prediction
49
+ prediction = prediction.cancel
50
+
51
+ # and if a prediction returns with status succeeded, you can retrieve the output
52
+ output = prediction.output
53
+ ```
54
+
55
+ ## Development
56
+
57
+ After checking out the repo, run `bin/setup` to install dependencies. Then, run `rake test` to run the tests. You can also run `bin/console` for an interactive prompt that will allow you to experiment.
58
+
59
+ To install this gem onto your local machine, run `bundle exec rake install`. To release a new version, update the version number in `version.rb`, and then run `bundle exec rake release`, which will create a git tag for the version, push git commits and the created tag, and push the `.gem` file to [rubygems.org](https://rubygems.org).
60
+
61
+ ## Contributing
62
+
63
+ Bug reports and pull requests are welcome on GitHub at https://github.com/danielpuglisi/replicate-ruby. This project is intended to be a safe, welcoming space for collaboration, and contributors are expected to adhere to the [code of conduct](https://github.com/danielpuglisi/replicate-ruby/blob/master/CODE_OF_CONDUCT.md).
data/Rakefile ADDED
@@ -0,0 +1,16 @@
1
+ # frozen_string_literal: true
2
+
3
require "bundler/gem_tasks"
require "rake/testtask"
require "rubocop/rake_task"

# `rake test`: run every test/**/*_test.rb file with test/ and lib/ on the load path.
Rake::TestTask.new(:test) do |suite|
  suite.libs.push("test", "lib")
  suite.test_files = FileList["test/**/*_test.rb"]
end

# `rake rubocop`: lint the project.
RuboCop::RakeTask.new

# Plain `rake` runs the tests first, then RuboCop.
task default: %i[test rubocop]
@@ -0,0 +1,29 @@
1
+ # frozen_string_literal: true
2
+
3
module Replicate
  class Client
    # Methods for the Model API
    module Model
      # Get a model, all versions of a model, or one specific model version.
      #
      # @see https://replicate.com/docs/reference/http#get-model
      # @param model [String] model identifier, e.g. "stability-ai/stable-diffusion"
      # @param version [Symbol, String] :latest (default) fetches the model record,
      #   :all lists every version; any other value is treated as a version id.
      #   The strings "latest" and "all" are accepted as aliases of the symbols.
      # @return [Replicate::Record::Model] when version is :latest
      # @return [Hash] when version is :all — the API response, with its "results"
      #   entries wrapped in {Replicate::Record::ModelVersion}
      # @return [Replicate::Record::ModelVersion] for a specific version id
      def retrieve_model(model, version: :latest)
        case version
        when :latest, "latest"
          Replicate::Record::Model.new(self, get("models/#{model}"))
        when :all, "all"
          response = get("models/#{model}/versions")
          # Wrap each raw version hash in place so callers keep the rest of the
          # response (e.g. pagination fields) untouched.
          response["results"].map! { |result| Replicate::Record::ModelVersion.new(self, result) }
          response
        else
          Replicate::Record::ModelVersion.new(self, get("models/#{model}/versions/#{version}"))
        end
      end

      # Get a collection of models
      #
      # @see https://replicate.com/docs/reference/http#get-collection
      # @param slug [String] collection slug
      # @return [Hash] the parsed API response (not wrapped in a record)
      def retrieve_collection(slug)
        get("collections/#{slug}")
      end
    end
  end
end
@@ -0,0 +1,34 @@
1
+ # frozen_string_literal: true
2
+
3
module Replicate
  class Client
    # Methods for the Prediction API
    module Prediction
      # Get a prediction
      #
      # @see https://replicate.com/docs/reference/http#get-prediction
      # @param id [String] prediction id
      # @return [Replicate::Record::Prediction]
      def retrieve_prediction(id)
        Replicate::Record::Prediction.new(self, get("predictions/#{id}"))
      end

      # Get a list of predictions
      #
      # @see https://replicate.com/docs/reference/http#get-predictions
      # @param cursor [String, nil] pagination cursor; omitted from the query when nil
      # @return [Hash] the API response, with its "results" entries wrapped in
      #   {Replicate::Record::Prediction}
      def list_predictions(cursor = nil)
        # Only send the cursor when one was given, instead of an empty `cursor` param.
        query = cursor ? { cursor: cursor } : {}
        response = get("predictions", query)
        response["results"].map! { |result| Replicate::Record::Prediction.new(self, result) }
        response
      end

      # Create a prediction
      #
      # @see https://replicate.com/docs/reference/http#create-prediction
      # @param params [Hash] request body (e.g. :version, :input, :webhook_completed)
      # @return [Replicate::Record::Prediction]
      def create_prediction(params)
        Replicate::Record::Prediction.new(self, post("predictions", params))
      end

      # Cancel a prediction
      #
      # @see https://replicate.com/docs/reference/http#cancel-prediction
      # @param id [String] prediction id
      # @return [Replicate::Record::Prediction]
      def cancel_prediction(id)
        Replicate::Record::Prediction.new(self, post("predictions/#{id}/cancel"))
      end
    end
  end
end
@@ -0,0 +1,25 @@
1
+ # frozen_string_literal: true
2
+
3
+ require "replicate/configurable"
4
+ require "replicate/connection"
5
+
6
+ require "replicate/client/model"
7
+ require "replicate/client/prediction"
8
+
9
module Replicate
  # HTTP client for the Replicate API.
  #
  # Combines configuration ({Replicate::Configurable}), the network layer
  # ({Replicate::Connection}) and the endpoint groups for models and predictions.
  class Client
    include Replicate::Configurable
    include Replicate::Connection

    include Replicate::Client::Model
    include Replicate::Client::Prediction

    # @param options [Hash] per-client settings keyed by the symbols in
    #   {Replicate::Configurable.keys}. Any key not present falls back to the
    #   instance variable of the same name on the {Replicate} module.
    def initialize(options = {})
      Replicate::Configurable.keys.each do |setting|
        module_default = Replicate.instance_variable_get(:"@#{setting}")
        instance_variable_set(:"@#{setting}", options.fetch(setting, module_default))
      end
    end
  end
end
@@ -0,0 +1,32 @@
1
+ # frozen_string_literal: true
2
+
3
module Replicate
  # Configuration mixin shared by the {Replicate} module and {Replicate::Client}.
  module Configurable
    attr_accessor :api_token
    attr_writer :api_endpoint

    class << self
      # List of configurable keys for {Replicate::Client}
      #
      # @return [Array<Symbol>] of option keys
      def keys
        @keys ||= %i[api_token api_endpoint].freeze
      end
    end

    # Set configuration options using a block
    #
    # @yieldparam config [self] the object being configured
    # @return the block's return value
    def configure
      yield self
    end

    # API endpoint the client talks to.
    #
    # @return [String] the configured endpoint, or "https://api.replicate.com/v1"
    #   when none was set (the default is memoized on first read)
    def api_endpoint
      @api_endpoint ||= "https://api.replicate.com/v1"
    end

    private

    # Current settings as a Hash keyed by {Configurable.keys}.
    #
    # @return [Hash{Symbol => Object}]
    def options
      Replicate::Configurable.keys.to_h { |key| [key, send(key)] }
    end
  end
end
@@ -0,0 +1,110 @@
1
+ # frozen_string_literal: true
2
+
3
require "json"
require "set"

require "faraday"
require "faraday/net_http"
require "faraday/retry"
require "addressable/uri"
7
+
8
module Replicate
  # Network layer for API clients.
  #
  # Mixed into {Replicate::Client}. The including object must respond to
  # `api_endpoint` and `api_token` (both provided by {Replicate::Configurable}).
  module Connection
    DEFAULT_MEDIA_TYPE = "application/json"
    USER_AGENT = "Replicate Ruby Gem"

    # Header keys that can be passed in options hash to {#get},{#head}
    # NOTE(review): not currently consulted by #request; kept for compatibility.
    CONVENIENCE_HEADERS = Set.new(%i[accept content_type])

    # Make a HTTP GET request
    #
    # @param url [String] The path, relative to {#api_endpoint}
    # @param options [Hash] Query params for request
    # @return [Hash, Array, nil] Parsed JSON response body (nil when the body is empty)
    # @raise [Replicate::Error] if the response status is not 2xx
    def get(url, options = {})
      request :get, url, options
    end

    # Make a HTTP POST request
    #
    # @param url [String] The path, relative to {#api_endpoint}
    # @param options [Hash] Request body, sent JSON-encoded
    # @return [Hash, Array, nil] Parsed JSON response body (nil when the body is empty)
    # @raise [Replicate::Error] if the response status is not 2xx
    def post(url, options = {})
      request :post, url, options.to_json
    end

    # Make a HTTP PUT request
    #
    # @param url [String] The path, relative to {#api_endpoint}
    # @param options [Hash] Request body, sent JSON-encoded
    # @return [Hash, Array, nil] Parsed JSON response body (nil when the body is empty)
    # @raise [Replicate::Error] if the response status is not 2xx
    def put(url, options = {})
      request :put, url, options.to_json
    end

    # Make a HTTP PATCH request
    #
    # @param url [String] The path, relative to {#api_endpoint}
    # @param options [Hash] Request body, sent JSON-encoded
    # @return [Hash, Array, nil] Parsed JSON response body (nil when the body is empty)
    # @raise [Replicate::Error] if the response status is not 2xx
    def patch(url, options = {})
      request :patch, url, options.to_json
    end

    # Make a HTTP DELETE request
    #
    # @param url [String] The path, relative to {#api_endpoint}
    # @param options [Hash] Query params for request
    # @return [Hash, Array, nil] Parsed JSON response body (nil when the body is empty)
    # @raise [Replicate::Error] if the response status is not 2xx
    def delete(url, options = {})
      request :delete, url, options
    end

    # Make a HTTP HEAD request
    #
    # @param url [String] The path, relative to {#api_endpoint}
    # @param options [Hash] Query params for request
    # @return [nil] HEAD responses carry no body
    # @raise [Replicate::Error] if the response status is not 2xx
    def head(url, options = {})
      request :head, url, options
    end

    # Faraday connection used for every request to the Replicate API.
    #
    # Built once per client: it captures `api_token` and `endpoint` at first use.
    #
    # @return [Faraday::Connection]
    def agent
      @agent ||= Faraday.new(url: endpoint) do |conn|
        conn.request :retry
        conn.request :authorization, "Token", api_token
        conn.headers["Content-Type"] = DEFAULT_MEDIA_TYPE
        conn.headers["Accept"] = DEFAULT_MEDIA_TYPE
        conn.headers["User-Agent"] = USER_AGENT

        conn.adapter :net_http
      end
    end

    # Response for last HTTP request
    #
    # @return [Faraday::Response, nil]
    def last_response
      @last_response if defined? @last_response
    end

    protected

    def endpoint
      api_endpoint
    end

    private

    # Perform a request and return the parsed body; raise on non-2xx statuses.
    def request(method, path, data)
      @last_response = agent.public_send(method, Addressable::URI.parse(path.to_s).normalize.to_s, data)
      status = @last_response.status
      return parse_body(@last_response.body) if (200..299).cover?(status)

      raise Error, "#{status} #{@last_response.reason_phrase}: #{error_detail(@last_response.body)}"
    end

    # Parse a JSON body, treating an empty/blank body (e.g. 204, HEAD) as nil.
    def parse_body(body)
      return nil if body.to_s.strip.empty?

      JSON.parse(body)
    end

    # Best-effort error detail: parsed JSON when possible, otherwise the raw body
    # (error pages from proxies are often HTML).
    def error_detail(body)
      parse_body(body)
    rescue JSON::ParserError
      body
    end
  end
end
@@ -0,0 +1,37 @@
1
+ # frozen_string_literal: true
2
+
3
module Replicate
  module Record
    # Generic API record: every key of the params hash becomes an instance
    # variable, readable through a same-named method (see #method_missing).
    class Base
      # @param client [Replicate::Client] client used for follow-up requests
      # @param params [Hash, Base] attributes; a record of the same class is
      #   copied via {#instance_variables_hash}
      def initialize(client, params)
        @client = client
        self.assign_attributes = params
      end

      # Copy every key/value of params onto this record as instance variables.
      #
      # @param params [Hash, Base] a Hash, or a record of the same class
      #   (copying a record also copies its client)
      def assign_attributes=(params)
        params = params.instance_variables_hash if params.is_a?(self.class)
        params.each do |key, value|
          instance_variable_set("@#{key}", value)
        end
      end

      # Expose each assigned attribute as a reader method.
      def method_missing(method_name, *args, &block)
        if attribute?(method_name)
          instance_variable_get "@#{method_name}"
        else
          super
        end
      end

      # Keep respond_to? consistent with #method_missing.
      def respond_to_missing?(method_name, include_private = false)
        attribute?(method_name) || super
      end

      # @return [Hash{String => Object}] instance variables keyed by name without "@"
      def instance_variables_hash
        instance_variables.to_h { |name| [name.to_s.delete_prefix("@"), instance_variable_get(name)] }
      end

      # Readable representation that omits the client.
      #
      # Builds a fresh string instead of appending to an interpolated literal,
      # which is frozen under frozen_string_literal on Ruby < 3.0; also avoids
      # Hash#except, which only exists from Ruby 3.0.
      def inspect
        fields = instance_variables_hash
                 .reject { |attr, _| attr == "client" }
                 .map { |attr, value| "#{attr}: #{value.inspect}" }
        "#<#{self.class.name}:#{object_id} #{fields.join(", ")}>"
      end

      private

      # Compare against the existing ivar list rather than calling
      # instance_variable_defined?, which raises for names like "foo=" or "bar?".
      def attribute?(method_name)
        instance_variables.include?(:"@#{method_name}")
      end
    end
  end
end
@@ -0,0 +1,12 @@
1
+ # frozen_string_literal: true
2
+
3
module Replicate
  module Record
    # A Replicate model, built from a parsed get-model API response.
    class Model < Base
      # @param client [Replicate::Client]
      # @param params [Hash] parsed API response. A non-nil "latest_version" is
      #   wrapped in a {Replicate::Record::ModelVersion}; a nil one (e.g. a model
      #   with no versions yet) is kept as nil instead of crashing.
      def initialize(client, params)
        latest = params["latest_version"]
        # Build a new hash rather than mutating the caller's.
        params = params.merge("latest_version" => Replicate::Record::ModelVersion.new(client, latest)) if latest
        super(client, params)
      end
    end
  end
end
@@ -0,0 +1,15 @@
1
+ # frozen_string_literal: true
2
+
3
module Replicate
  module Record
    # A specific version of a model; predictions are run against a version.
    class ModelVersion < Base
      # Start a prediction for this version.
      #
      # @param input [Hash] model inputs, sent as the request's "input" field
      # @param webhook_completed [String, nil] optional URL to be POSTed to when the
      #   prediction completes; left out of the request body when nil rather than
      #   being sent as JSON null
      # @return [Replicate::Record::Prediction]
      def predict(input, webhook_completed = nil)
        params = { version: id, input: input }
        params[:webhook_completed] = webhook_completed if webhook_completed
        client.create_prediction(params)
      end
    end
  end
end
@@ -0,0 +1,15 @@
1
+ # frozen_string_literal: true
2
+
3
module Replicate
  module Record
    # A single prediction on Replicate.
    class Prediction < Base
      # Reload this prediction's attributes from the API.
      #
      # @return [self] so `prediction = prediction.refetch` keeps the same object
      def refetch
        self.assign_attributes = client.retrieve_prediction(id)
        self
      end

      # Ask the API to cancel this prediction and update attributes from its response.
      #
      # @return [self]
      def cancel
        self.assign_attributes = client.cancel_prediction(id)
        self
      end
    end
  end
end
@@ -0,0 +1,5 @@
1
+ # frozen_string_literal: true
2
+
3
# Top-level namespace of the replicate-ruby gem.
module Replicate
  # Gem version string; bump it before running `bundle exec rake release`.
  VERSION = "0.1.0"
end
data/lib/replicate.rb ADDED
@@ -0,0 +1,22 @@
1
+ # frozen_string_literal: true
2
+
3
+ require "replicate/version"
4
+ require "replicate/client"
5
+
6
+ require "replicate/record/base"
7
+ require "replicate/record/model"
8
+ require "replicate/record/model_version"
9
+ require "replicate/record/prediction"
10
+
11
module Replicate
  # Error raised by this gem when an API request fails.
  class Error < StandardError; end

  class << self
    include Replicate::Configurable

    # Shared client built from the module-level configuration.
    #
    # The client is memoized, but rebuilt whenever the configured options
    # change, so calling `Replicate.configure` after `Replicate.client` (for
    # example to set `api_token`) still takes effect.
    #
    # @return [Replicate::Client]
    def client
      current = options
      unless defined?(@client) && @client_options == current
        @client_options = current
        @client = Replicate::Client.new(current)
      end
      @client
    end
  end
end
@@ -0,0 +1,4 @@
1
# Type signatures for the replicate-ruby gem.
module Replicate
  VERSION: String

  class Error < StandardError
  end
  # See the writing guide of rbs: https://github.com/ruby/rbs#guides
end
metadata ADDED
@@ -0,0 +1,133 @@
1
+ --- !ruby/object:Gem::Specification
2
+ name: replicate-ruby
3
+ version: !ruby/object:Gem::Version
4
+ version: 0.1.0
5
+ platform: ruby
6
+ authors:
7
+ - Daniel Puglisi
8
+ autorequire:
9
+ bindir: exe
10
+ cert_chain: []
11
+ date: 2022-10-13 00:00:00.000000000 Z
12
+ dependencies:
13
+ - !ruby/object:Gem::Dependency
14
+ name: faraday
15
+ requirement: !ruby/object:Gem::Requirement
16
+ requirements:
17
+ - - ">="
18
+ - !ruby/object:Gem::Version
19
+ version: '2.0'
20
+ type: :runtime
21
+ prerelease: false
22
+ version_requirements: !ruby/object:Gem::Requirement
23
+ requirements:
24
+ - - ">="
25
+ - !ruby/object:Gem::Version
26
+ version: '2.0'
27
+ - !ruby/object:Gem::Dependency
28
+ name: faraday-retry
29
+ requirement: !ruby/object:Gem::Requirement
30
+ requirements:
31
+ - - ">="
32
+ - !ruby/object:Gem::Version
33
+ version: '0'
34
+ type: :runtime
35
+ prerelease: false
36
+ version_requirements: !ruby/object:Gem::Requirement
37
+ requirements:
38
+ - - ">="
39
+ - !ruby/object:Gem::Version
40
+ version: '0'
41
+ - !ruby/object:Gem::Dependency
42
+ name: addressable
43
+ requirement: !ruby/object:Gem::Requirement
44
+ requirements:
45
+ - - ">="
46
+ - !ruby/object:Gem::Version
47
+ version: '0'
48
+ type: :runtime
49
+ prerelease: false
50
+ version_requirements: !ruby/object:Gem::Requirement
51
+ requirements:
52
+ - - ">="
53
+ - !ruby/object:Gem::Version
54
+ version: '0'
55
+ - !ruby/object:Gem::Dependency
56
+ name: pry
57
+ requirement: !ruby/object:Gem::Requirement
58
+ requirements:
59
+ - - ">="
60
+ - !ruby/object:Gem::Version
61
+ version: '0'
62
+ type: :development
63
+ prerelease: false
64
+ version_requirements: !ruby/object:Gem::Requirement
65
+ requirements:
66
+ - - ">="
67
+ - !ruby/object:Gem::Version
68
+ version: '0'
69
+ - !ruby/object:Gem::Dependency
70
+ name: webmock
71
+ requirement: !ruby/object:Gem::Requirement
72
+ requirements:
73
+ - - ">="
74
+ - !ruby/object:Gem::Version
75
+ version: '0'
76
+ type: :development
77
+ prerelease: false
78
+ version_requirements: !ruby/object:Gem::Requirement
79
+ requirements:
80
+ - - ">="
81
+ - !ruby/object:Gem::Version
82
+ version: '0'
83
+ description:
84
+ email:
85
+ - daniel@codegestalt.com
86
+ executables: []
87
+ extensions: []
88
+ extra_rdoc_files: []
89
+ files:
90
+ - ".rubocop.yml"
91
+ - CHANGELOG.md
92
+ - Gemfile
93
+ - Gemfile.lock
94
+ - README.md
95
+ - Rakefile
96
+ - lib/replicate.rb
97
+ - lib/replicate/client.rb
98
+ - lib/replicate/client/model.rb
99
+ - lib/replicate/client/prediction.rb
100
+ - lib/replicate/configurable.rb
101
+ - lib/replicate/connection.rb
102
+ - lib/replicate/record/base.rb
103
+ - lib/replicate/record/model.rb
104
+ - lib/replicate/record/model_version.rb
105
+ - lib/replicate/record/prediction.rb
106
+ - lib/replicate/version.rb
107
+ - sig/replicate/ruby.rbs
108
+ homepage: https://github.com/danielpuglisi/replicate-ruby
109
+ licenses: []
110
+ metadata:
111
+ homepage_uri: https://github.com/danielpuglisi/replicate-ruby
112
+ source_code_uri: https://github.com/danielpuglisi/replicate-ruby
113
+ changelog_uri: https://github.com/danielpuglisi/replicate-ruby/blob/master/CHANGELOG.md
114
+ post_install_message:
115
+ rdoc_options: []
116
+ require_paths:
117
+ - lib
118
+ required_ruby_version: !ruby/object:Gem::Requirement
119
+ requirements:
120
+ - - ">="
121
+ - !ruby/object:Gem::Version
122
+ version: 2.6.0
123
+ required_rubygems_version: !ruby/object:Gem::Requirement
124
+ requirements:
125
+ - - ">="
126
+ - !ruby/object:Gem::Version
127
+ version: '0'
128
+ requirements: []
129
+ rubygems_version: 3.3.7
130
+ signing_key:
131
+ specification_version: 4
132
+ summary: Ruby client for Replicate
133
+ test_files: []