replicate-ruby 0.1.7 → 0.2.0

Sign up to get free protection for your applications and to get access to all the features.
checksums.yaml CHANGED
@@ -1,7 +1,7 @@
1
1
  ---
2
2
  SHA256:
3
- metadata.gz: 6e58708bf432115af17aef8f67789838f4044526a851785a04dd3f206d17f727
4
- data.tar.gz: f1422c952891c61df45c4a4bd69e27f13b724185edc883df8dec00addec7679a
3
+ metadata.gz: 82385c582239598bb4b62a3e133228b7e6fc417ea57a047c1489e5d94d787faf
4
+ data.tar.gz: 4ccecfdb33f1a4cf3218587c7ed4fa28475da1ca13369e59d5875b07e4f092f8
5
5
  SHA512:
6
- metadata.gz: 234415f451e806a5d2c800cf11da2aa626b131f4c8f5d3df501c317607a5637c267bc757d75115bb8fd5c6378500bd6c38c749c722f7a38ebad4f6bd8b7b5f7d
7
- data.tar.gz: ea46776b902acd5cec2283fd8bd390127244771c0250e05623303f1d5284e1aa982240eecd5227b2b2f243d01ccc8d8c099442a33c28f114287b97e9765e26ba
6
+ metadata.gz: c8edef78b1309f5873997b33b1de5de82f2f8982ce2ec3ed32b0cdeca0e9038c77b4d56e6173441d106160e77fb1e0d2311024d5cae2c68a1fc91ef6c62cb725
7
+ data.tar.gz: 56958524475fbd6636e1eb3f409b4dad100fe8179a982a746a0ee2ea4356ffa66724269a9da0e226e374bbe34c2526aacf3a358e5847ed55429ebf06700f2383
data/Gemfile.lock CHANGED
@@ -4,6 +4,7 @@ PATH
4
4
  replicate-ruby (0.1.7)
5
5
  addressable
6
6
  faraday (>= 2.0)
7
+ faraday-multipart
7
8
  faraday-retry
8
9
 
9
10
  GEM
@@ -12,12 +13,15 @@ GEM
12
13
  addressable (2.8.1)
13
14
  public_suffix (>= 2.0.2, < 6.0)
14
15
  ast (2.4.2)
16
+ byebug (11.1.3)
15
17
  coderay (1.1.3)
16
18
  crack (0.4.5)
17
19
  rexml
18
20
  faraday (2.7.2)
19
21
  faraday-net_http (>= 2.0, < 3.1)
20
22
  ruby2_keywords (>= 0.0.4)
23
+ faraday-multipart (1.0.4)
24
+ multipart-post (~> 2)
21
25
  faraday-net_http (3.0.2)
22
26
  faraday-retry (2.0.0)
23
27
  faraday (~> 2.0)
@@ -25,6 +29,7 @@ GEM
25
29
  json (2.6.2)
26
30
  method_source (1.0.0)
27
31
  minitest (5.16.3)
32
+ multipart-post (2.2.3)
28
33
  parallel (1.22.1)
29
34
  parser (3.1.2.1)
30
35
  ast (~> 2.4.1)
@@ -57,9 +62,11 @@ GEM
57
62
  hashdiff (>= 0.4.0, < 2.0.0)
58
63
 
59
64
  PLATFORMS
65
+ arm64-darwin-22
60
66
  x86_64-darwin-21
61
67
 
62
68
  DEPENDENCIES
69
+ byebug
63
70
  minitest (~> 5.0)
64
71
  pry
65
72
  rake (~> 13.0)
data/README.md CHANGED
@@ -54,6 +54,42 @@ id = prediction.id # store prediction id in your backend
54
54
  prediction = Replicate.client.retrieve_prediction(id) # retrieve prediction during webhook with id from backend
55
55
  ```
56
56
 
57
+ ## Dreambooth
58
+
59
+ The client supports Replicate's [experimental Dreambooth API](https://replicate.com/blog/dreambooth-api).
60
+
61
+ First, create an upload and attach your training dataset as a zip file:
62
+
63
+ ```
64
+ upload = Replicate.client.create_upload
65
+ upload.attach('tmp/data.zip') # replace with the path to your zip file
66
+ ```
67
+
68
+ Then start training a new model, for instance:
69
+
70
+ ```
71
+ training = Replicate.client.create_training(
72
+ input: {
73
+ instance_prompt: "zwx style",
74
+ class_prompt: "style",
75
+ instance_data: upload.serving_url,
76
+ max_train_steps: 5000
77
+ },
78
+ model: 'yourusername/yourmodel'
79
+ )
80
+ ```
81
+
82
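+ As soon as the model has finished training, you can run predictions on it. Call `training.refetch` first, because `training` still holds the response returned when the training was created: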
83
+
84
+ ```
85
+ prediction = Replicate.client.create_prediction(
86
+ input: {
87
+ prompt: 'your prompt, zwx style'
88
+ },
89
+ version: training.version
90
+ )
91
+ ```
92
+
57
93
  ## Development
58
94
 
59
95
  After checking out the repo, run `bin/setup` to install dependencies. Then, run `rake test` to run the tests. You can also run `bin/console` for an interactive prompt that will allow you to experiment.
data/data.zip ADDED
File without changes
@@ -9,20 +9,22 @@ module Replicate
9
9
  def retrieve_model(model, version: :latest)
10
10
  case version
11
11
  when :latest
12
- Replicate::Record::Model.new(self, get("models/#{model}"))
12
+ response = api_endpoint.get("models/#{model}")
13
+ Replicate::Record::Model.new(self, response)
13
14
  when :all
14
- response = get("models/#{model}/versions")
15
+ response = api_endpoint.get("models/#{model}/versions")
15
16
  response["results"].map! { |result| Replicate::Record::ModelVersion.new(self, result) }
16
17
  response
17
18
  else
18
- Replicate::Record::ModelVersion.new(self, get("models/#{model}/versions/#{version}"))
19
+ response = api_endpoint.get("models/#{model}/versions/#{version}")
20
+ Replicate::Record::ModelVersion.new(self, response)
19
21
  end
20
22
  end
21
23
 
22
24
  # Get a collection of models
23
25
  # @see https://replicate.com/docs/reference/http#get-collection
24
26
  def retrieve_collection(slug)
25
- get("collections/#{slug}")
27
+ api_endpoint.get("collections/#{slug}")
26
28
  end
27
29
  end
28
30
  end
@@ -7,13 +7,14 @@ module Replicate
7
7
  # Get a prediction
8
8
  # @see https://replicate.com/docs/reference/http#get-prediction
9
9
  def retrieve_prediction(id)
10
- Replicate::Record::Prediction.new(self, get("predictions/#{id}"))
10
+ response = api_endpoint.get("predictions/#{id}")
11
+ Replicate::Record::Prediction.new(self, response)
11
12
  end
12
13
 
13
14
  # Get a list of predictions
14
15
  # @see https://replicate.com/docs/reference/http#get-predictions
15
16
  def list_predictions(cursor = nil)
16
- response = get("predictions", cursor: cursor)
17
+ response = api_endpoint.get("predictions", cursor: cursor)
17
18
  response["results"].map! { |result| Replicate::Record::Prediction.new(self, result) }
18
19
  response
19
20
  end
@@ -22,13 +23,15 @@ module Replicate
22
23
  # @see https://replicate.com/docs/reference/http#create-prediction
23
24
  def create_prediction(params)
24
25
  params[:webhook_completed] ||= webhook_url
25
- Replicate::Record::Prediction.new(self, post("predictions", params))
26
+ response = api_endpoint.post("predictions", params)
27
+ Replicate::Record::Prediction.new(self, response)
26
28
  end
27
29
 
28
30
  # Cancel a prediction
29
31
  # @see https://replicate.com/docs/reference/http#cancel-prediction
30
32
  def cancel_prediction(id)
31
- Replicate::Record::Prediction.new(self, post("predictions/#{id}/cancel"))
33
+ response = api_endpont.post("predictions/#{id}/cancel")
34
+ Replicate::Record::Prediction.new(self, response)
32
35
  end
33
36
  end
34
37
  end
@@ -0,0 +1,23 @@
1
+ # frozen_string_literal: true
2
+
3
+ module Replicate
4
+ class Client
5
+ # Methods for the experimental Dreambooth Training API
6
+ module Training
7
+ # Get a training
8
+ # @see https://replicate.com/blog/dreambooth-api
9
+ def retrieve_training(id)
10
+ response = dreambooth_endpoint.get("trainings/#{id}")
11
+ Replicate::Record::Training.new(self, response)
12
+ end
13
+
14
+ # Create a training
15
+ # @see https://replicate.com/blog/dreambooth-api
16
+ def create_training(params)
17
+ params[:webhook_completed] ||= webhook_url
18
+ response = dreambooth_endpoint.post("trainings", params)
19
+ Replicate::Record::Training.new(self, response)
20
+ end
21
+ end
22
+ end
23
+ end
@@ -0,0 +1,27 @@
1
+ # frozen_string_literal: true
2
+
3
+ module Replicate
4
+ class Client
5
+ # Methods for the experimental Dreambooth Upload API
6
+ module Upload
7
+ # Create an upload
8
+ # @see https://replicate.com/blog/dreambooth-api
9
+ def create_upload
10
+ response = dreambooth_endpoint.post("upload/data.zip")
11
+ Replicate::Record::Upload.new(self, response)
12
+ end
13
+
14
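+ # Upload the zip file at zip_path to upload_endpoint_url with an HTTP PUT. NOTE(review): this sets both Content-Length and Transfer-Encoding: chunked, which HTTP/1.1 forbids in one message; confirm which one the upload URL expects.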
15
+ # @see https://replicate.com/blog/dreambooth-api
16
+ def update_upload(upload_endpoint_url, zip_path)
17
+ endpoint = Replicate::Endpoint.new(endpoint_url: upload_endpoint_url, api_token: nil)
18
+ endpoint.agent.put do |req|
19
+ req.headers["Content-Type"] = "application/zip"
20
+ req.headers["Content-Length"] = File.size(zip_path).to_s
21
+ req.headers["Transfer-Encoding"] = "chunked"
22
+ req.body = Faraday::UploadIO.new(zip_path, 'application/zip')
23
+ end
24
+ end
25
+ end
26
+ end
27
+ end
@@ -1,18 +1,21 @@
1
1
  # frozen_string_literal: true
2
2
 
3
3
  require "replicate/configurable"
4
- require "replicate/connection"
4
+ require "replicate/endpoint"
5
5
 
6
6
  require "replicate/client/model"
7
7
  require "replicate/client/prediction"
8
+ require "replicate/client/upload"
9
+ require "replicate/client/training"
8
10
 
9
11
  module Replicate
10
12
  class Client
11
13
  include Replicate::Configurable
12
- include Replicate::Connection
13
14
 
14
15
  include Replicate::Client::Model
15
16
  include Replicate::Client::Prediction
17
+ include Replicate::Client::Upload
18
+ include Replicate::Client::Training
16
19
 
17
20
  def initialize(options = {})
18
21
  # Use options passed in, but fall back to module defaults
@@ -21,5 +24,13 @@ module Replicate
21
24
  instance_variable_set(:"@#{key}", value)
22
25
  end
23
26
  end
27
+
28
+ def api_endpoint
29
+ @api_endpoint ||= Replicate::Endpoint.new(endpoint_url: api_endpoint_url, api_token: api_token)
30
+ end
31
+
32
+ def dreambooth_endpoint
33
+ @dreambooth_endpoint ||= Replicate::Endpoint.new(endpoint_url: dreambooth_endpoint_url, api_token: api_token)
34
+ end
24
35
  end
25
36
  end
@@ -3,13 +3,13 @@
3
3
  module Replicate
4
4
  module Configurable
5
5
  attr_accessor :api_token, :webhook_url
6
- attr_writer :api_endpoint
6
+ attr_writer :api_endpoint_url, :dreambooth_endpoint_url
7
7
 
8
8
  class << self
9
9
  # List of configurable keys for {Replicate::Client}
10
10
  # @return [Array] of option keys
11
11
  def keys
12
- @keys ||= %i[api_token api_endpoint webhook_url]
12
+ @keys ||= %i[api_token api_endpoint_url dreambooth_endpoint_url webhook_url]
13
13
  end
14
14
  end
15
15
 
@@ -19,8 +19,12 @@ module Replicate
19
19
  end
20
20
 
21
21
  # API endpoint methods
22
- def api_endpoint
23
- @api_endpoint ||= "https://api.replicate.com/v1"
22
+ def api_endpoint_url
23
+ @api_endpoint_url ||= "https://api.replicate.com/v1"
24
+ end
25
+
26
+ def dreambooth_endpoint_url
27
+ @dreambooth_endpoint_url ||= "https://dreambooth-api-experimental.replicate.com/v1"
24
28
  end
25
29
 
26
30
  private
@@ -3,16 +3,19 @@
3
3
  require "faraday"
4
4
  require "faraday/net_http"
5
5
  require "faraday/retry"
6
+ require "faraday/multipart"
6
7
  require "addressable/uri"
7
8
 
8
9
  module Replicate
9
10
  # Network layer for API clients.
10
- module Connection
11
- DEFAULT_MEDIA_TYPE = "application/json"
12
- USER_AGENT = "Datatrans Ruby Gem"
11
+ class Endpoint
12
+ attr_reader :endpoint_url, :api_token, :content_type
13
13
 
14
- # Header keys that can be passed in options hash to {#get},{#head}
15
- CONVENIENCE_HEADERS = Set.new(%i[accept content_type])
14
+ def initialize(endpoint_url:, api_token:, content_type: 'application/json')
15
+ @endpoint_url = endpoint_url
16
+ @api_token = api_token
17
+ @content_type = content_type
18
+ end
16
19
 
17
20
  # Make a HTTP GET request
18
21
  #
@@ -72,11 +75,10 @@ module Replicate
72
75
  #
73
76
  # @return [Faraday::Connection]
74
77
  def agent
75
- @agent ||= Faraday.new(url: endpoint) do |conn|
78
+ @agent ||= Faraday.new(url: endpoint_url) do |conn|
76
79
  conn.request :retry
77
- conn.request :authorization, 'Token', api_token
78
- conn.headers["Content-Type"] = DEFAULT_MEDIA_TYPE
79
- conn.headers["Accept"] = DEFAULT_MEDIA_TYPE
80
+ conn.request :authorization, 'Token', api_token if api_token
81
+ conn.headers["Content-Type"] = content_type
80
82
 
81
83
  conn.adapter :net_http
82
84
  end
@@ -89,12 +91,6 @@ module Replicate
89
91
  @last_response if defined? @last_response
90
92
  end
91
93
 
92
- protected
93
-
94
- def endpoint
95
- api_endpoint
96
- end
97
-
98
94
  private
99
95
 
100
96
  def request(method, path, data)
@@ -103,7 +99,12 @@ module Replicate
103
99
  when 400
104
100
  raise Error, "#{@last_response.status} #{@last_response.reason_phrase}: #{JSON.parse(@last_response.body)}"
105
101
  else
106
- JSON.parse(@last_response.body)
102
+ case content_type
103
+ when 'application/json'
104
+ JSON.parse(@last_response.body)
105
+ else
106
+ @last_response.body
107
+ end
107
108
  end
108
109
  end
109
110
  end
@@ -0,0 +1,11 @@
1
+ # frozen_string_literal: true
2
+
3
+ module Replicate
4
+ module Record
5
+ class Training < Base
6
+ def refetch
7
+ @data = client.retrieve_training(id).data
8
+ end
9
+ end
10
+ end
11
+ end
@@ -0,0 +1,11 @@
1
+ # frozen_string_literal: true
2
+
3
+ module Replicate
4
+ module Record
5
+ class Upload < Base
6
+ def attach(path)
7
+ client.update_upload(upload_url, path)
8
+ end
9
+ end
10
+ end
11
+ end
@@ -1,5 +1,5 @@
1
1
  # frozen_string_literal: true
2
2
 
3
3
  module Replicate
4
- VERSION = "0.1.7"
4
+ VERSION = "0.2.0"
5
5
  end
data/lib/replicate.rb CHANGED
@@ -7,6 +7,8 @@ require "replicate/record/base"
7
7
  require "replicate/record/model"
8
8
  require "replicate/record/model_version"
9
9
  require "replicate/record/prediction"
10
+ require "replicate/record/upload"
11
+ require "replicate/record/training"
10
12
 
11
13
  module Replicate
12
14
  class Error < StandardError; end
metadata CHANGED
@@ -1,14 +1,14 @@
1
1
  --- !ruby/object:Gem::Specification
2
2
  name: replicate-ruby
3
3
  version: !ruby/object:Gem::Version
4
- version: 0.1.7
4
+ version: 0.2.0
5
5
  platform: ruby
6
6
  authors:
7
7
  - Dreaming Tulpa
8
8
  autorequire:
9
9
  bindir: exe
10
10
  cert_chain: []
11
- date: 2022-12-19 00:00:00.000000000 Z
11
+ date: 2023-03-08 00:00:00.000000000 Z
12
12
  dependencies:
13
13
  - !ruby/object:Gem::Dependency
14
14
  name: faraday
@@ -38,6 +38,20 @@ dependencies:
38
38
  - - ">="
39
39
  - !ruby/object:Gem::Version
40
40
  version: '0'
41
+ - !ruby/object:Gem::Dependency
42
+ name: faraday-multipart
43
+ requirement: !ruby/object:Gem::Requirement
44
+ requirements:
45
+ - - ">="
46
+ - !ruby/object:Gem::Version
47
+ version: '0'
48
+ type: :runtime
49
+ prerelease: false
50
+ version_requirements: !ruby/object:Gem::Requirement
51
+ requirements:
52
+ - - ">="
53
+ - !ruby/object:Gem::Version
54
+ version: '0'
41
55
  - !ruby/object:Gem::Dependency
42
56
  name: addressable
43
57
  requirement: !ruby/object:Gem::Requirement
@@ -80,6 +94,20 @@ dependencies:
80
94
  - - ">="
81
95
  - !ruby/object:Gem::Version
82
96
  version: '0'
97
+ - !ruby/object:Gem::Dependency
98
+ name: byebug
99
+ requirement: !ruby/object:Gem::Requirement
100
+ requirements:
101
+ - - ">="
102
+ - !ruby/object:Gem::Version
103
+ version: '0'
104
+ type: :development
105
+ prerelease: false
106
+ version_requirements: !ruby/object:Gem::Requirement
107
+ requirements:
108
+ - - ">="
109
+ - !ruby/object:Gem::Version
110
+ version: '0'
83
111
  description:
84
112
  email:
85
113
  - hey@dreamingtulpa.com
@@ -93,16 +121,21 @@ files:
93
121
  - Gemfile.lock
94
122
  - README.md
95
123
  - Rakefile
124
+ - data.zip
96
125
  - lib/replicate.rb
97
126
  - lib/replicate/client.rb
98
127
  - lib/replicate/client/model.rb
99
128
  - lib/replicate/client/prediction.rb
129
+ - lib/replicate/client/training.rb
130
+ - lib/replicate/client/upload.rb
100
131
  - lib/replicate/configurable.rb
101
- - lib/replicate/connection.rb
132
+ - lib/replicate/endpoint.rb
102
133
  - lib/replicate/record/base.rb
103
134
  - lib/replicate/record/model.rb
104
135
  - lib/replicate/record/model_version.rb
105
136
  - lib/replicate/record/prediction.rb
137
+ - lib/replicate/record/training.rb
138
+ - lib/replicate/record/upload.rb
106
139
  - lib/replicate/version.rb
107
140
  - sig/replicate/ruby.rbs
108
141
  homepage: https://github.com/dreamingtulpa/replicate-ruby
@@ -126,7 +159,7 @@ required_rubygems_version: !ruby/object:Gem::Requirement
126
159
  - !ruby/object:Gem::Version
127
160
  version: '0'
128
161
  requirements: []
129
- rubygems_version: 3.3.26
162
+ rubygems_version: 3.4.6
130
163
  signing_key:
131
164
  specification_version: 4
132
165
  summary: Ruby client for Replicate