replicate-ruby 0.1.6 → 0.2.0

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
checksums.yaml CHANGED
@@ -1,7 +1,7 @@
1
1
  ---
2
2
  SHA256:
3
- metadata.gz: 1c51346f40acaa0f9e2f0c1ac97630fa2908d5e87d94de2d27196f2b33ec6135
4
- data.tar.gz: 7559f35badbfd904058c70f3561d45793b6f3290e2905500902b755d2c45aeb2
3
+ metadata.gz: 82385c582239598bb4b62a3e133228b7e6fc417ea57a047c1489e5d94d787faf
4
+ data.tar.gz: 4ccecfdb33f1a4cf3218587c7ed4fa28475da1ca13369e59d5875b07e4f092f8
5
5
  SHA512:
6
- metadata.gz: fb6c205ce5e7cd48decf8ea79067fe4eaa6942ee00c50b85a8efb82d549ae2deec969591b8ddee302b424d583d5e20f3d2c9f860e10bde8d51d633b6c9564036
7
- data.tar.gz: 6d4255676a72b959ddb9ba3280adf181e6e6bbdd9012ac0085e0173bedb63e97a3458f2e8201dd22854fda4f10967ecb075f428759b67b31fd9550ad689e0fba
6
+ metadata.gz: c8edef78b1309f5873997b33b1de5de82f2f8982ce2ec3ed32b0cdeca0e9038c77b4d56e6173441d106160e77fb1e0d2311024d5cae2c68a1fc91ef6c62cb725
7
+ data.tar.gz: 56958524475fbd6636e1eb3f409b4dad100fe8179a982a746a0ee2ea4356ffa66724269a9da0e226e374bbe34c2526aacf3a358e5847ed55429ebf06700f2383
data/Gemfile.lock CHANGED
@@ -1,9 +1,10 @@
1
1
  PATH
2
2
  remote: .
3
3
  specs:
4
- replicate-ruby (0.1.6)
4
+ replicate-ruby (0.2.0)
5
5
  addressable
6
6
  faraday (>= 2.0)
7
+ faraday-multipart
7
8
  faraday-retry
8
9
 
9
10
  GEM
@@ -12,12 +13,15 @@ GEM
12
13
  addressable (2.8.1)
13
14
  public_suffix (>= 2.0.2, < 6.0)
14
15
  ast (2.4.2)
16
+ byebug (11.1.3)
15
17
  coderay (1.1.3)
16
18
  crack (0.4.5)
17
19
  rexml
18
20
  faraday (2.7.2)
19
21
  faraday-net_http (>= 2.0, < 3.1)
20
22
  ruby2_keywords (>= 0.0.4)
23
+ faraday-multipart (1.0.4)
24
+ multipart-post (~> 2)
21
25
  faraday-net_http (3.0.2)
22
26
  faraday-retry (2.0.0)
23
27
  faraday (~> 2.0)
@@ -25,6 +29,7 @@ GEM
25
29
  json (2.6.2)
26
30
  method_source (1.0.0)
27
31
  minitest (5.16.3)
32
+ multipart-post (2.2.3)
28
33
  parallel (1.22.1)
29
34
  parser (3.1.2.1)
30
35
  ast (~> 2.4.1)
@@ -57,9 +62,11 @@ GEM
57
62
  hashdiff (>= 0.4.0, < 2.0.0)
58
63
 
59
64
  PLATFORMS
65
+ arm64-darwin-22
60
66
  x86_64-darwin-21
61
67
 
62
68
  DEPENDENCIES
69
+ byebug
63
70
  minitest (~> 5.0)
64
71
  pry
65
72
  rake (~> 13.0)
data/README.md CHANGED
@@ -54,6 +54,42 @@ id = prediction.id # store prediction id in your backend
54
54
  prediction = Replicate.client.retrieve_prediction(id) # retrieve prediction during webhook with id from backend
55
55
  ```
56
56
 
57
+ ## Dreambooth
58
+
59
+ There is support for the [experimental dreambooth endpoint](https://replicate.com/blog/dreambooth-api).
60
+
61
+ First, create an upload and attach your training dataset (a zip file) to it:
62
+
63
+ ```
64
+ upload = Replicate.client.create_upload
65
+ upload.attach('tmp/data.zip') # replace with the path to your zip file
66
+ ```
67
+
68
+ Then start training a new model, for instance:
69
+
70
+ ```
71
+ training = Replicate.client.create_training(
72
+ input: {
73
+ instance_prompt: "zwx style",
74
+ class_prompt: "style",
75
+ instance_data: upload.serving_url,
76
+ max_train_steps: 5000
77
+ },
78
+ model: 'yourusername/yourmodel'
79
+ )
80
+ ```
81
+
82
+ As soon as the model has finished training, you can run predictions on it:
83
+
84
+ ```
85
+ prediction = Replicate.client.create_prediction(
86
+ input: {
87
+ prompt: 'your prompt, zwx style'
88
+ },
89
+ version: training.version
90
+ )
91
+ ```
92
+
57
93
  ## Development
58
94
 
59
95
  After checking out the repo, run `bin/setup` to install dependencies. Then, run `rake test` to run the tests. You can also run `bin/console` for an interactive prompt that will allow you to experiment.
data/data.zip ADDED
File without changes
@@ -9,20 +9,22 @@ module Replicate
9
9
  def retrieve_model(model, version: :latest)
10
10
  case version
11
11
  when :latest
12
- Replicate::Record::Model.new(self, get("models/#{model}"))
12
+ response = api_endpoint.get("models/#{model}")
13
+ Replicate::Record::Model.new(self, response)
13
14
  when :all
14
- response = get("models/#{model}/versions")
15
+ response = api_endpoint.get("models/#{model}/versions")
15
16
  response["results"].map! { |result| Replicate::Record::ModelVersion.new(self, result) }
16
17
  response
17
18
  else
18
- Replicate::Record::ModelVersion.new(self, get("models/#{model}/versions/#{version}"))
19
+ response = api_endpoint.get("models/#{model}/versions/#{version}")
20
+ Replicate::Record::ModelVersion.new(self, response)
19
21
  end
20
22
  end
21
23
 
22
24
  # Get a collection of models
23
25
  # @see https://replicate.com/docs/reference/http#get-collection
24
26
  def retrieve_collection(slug)
25
- get("collections/#{slug}")
27
+ api_endpoint.get("collections/#{slug}")
26
28
  end
27
29
  end
28
30
  end
@@ -7,13 +7,14 @@ module Replicate
7
7
  # Get a prediction
8
8
  # @see https://replicate.com/docs/reference/http#get-prediction
9
9
  def retrieve_prediction(id)
10
- Replicate::Record::Prediction.new(self, get("predictions/#{id}"))
10
+ response = api_endpoint.get("predictions/#{id}")
11
+ Replicate::Record::Prediction.new(self, response)
11
12
  end
12
13
 
13
14
  # Get a list of predictions
14
15
  # @see https://replicate.com/docs/reference/http#get-predictions
15
16
  def list_predictions(cursor = nil)
16
- response = get("predictions", cursor: cursor)
17
+ response = api_endpoint.get("predictions", cursor: cursor)
17
18
  response["results"].map! { |result| Replicate::Record::Prediction.new(self, result) }
18
19
  response
19
20
  end
@@ -22,13 +23,15 @@ module Replicate
22
23
  # @see https://replicate.com/docs/reference/http#create-prediction
23
24
  def create_prediction(params)
24
25
  params[:webhook_completed] ||= webhook_url
25
- Replicate::Record::Prediction.new(self, post("predictions", params))
26
+ response = api_endpoint.post("predictions", params)
27
+ Replicate::Record::Prediction.new(self, response)
26
28
  end
27
29
 
28
30
  # Cancel a prediction
29
31
  # @see https://replicate.com/docs/reference/http#cancel-prediction
30
32
  def cancel_prediction(id)
31
- Replicate::Record::Prediction.new(self, post("predictions/#{id}/cancel"))
33
+ response = api_endpont.post("predictions/#{id}/cancel")
34
+ Replicate::Record::Prediction.new(self, response)
32
35
  end
33
36
  end
34
37
  end
@@ -0,0 +1,23 @@
1
+ # frozen_string_literal: true
2
+
3
+ module Replicate
4
+ class Client
5
+ # Methods for the Training API (experimental Dreambooth endpoint)
6
+ module Training
7
+ # Get a training by its id
8
+ # @see https://replicate.com/blog/dreambooth-api
9
+ def retrieve_training(id)
10
+ response = dreambooth_endpoint.get("trainings/#{id}")
11
+ Replicate::Record::Training.new(self, response)
12
+ end
13
+
14
+ # Create a training (sets params[:webhook_completed] in the caller's hash when absent)
15
+ # @see https://replicate.com/blog/dreambooth-api
16
+ def create_training(params)
17
+ params[:webhook_completed] ||= webhook_url
18
+ response = dreambooth_endpoint.post("trainings", params)
19
+ Replicate::Record::Training.new(self, response)
20
+ end
21
+ end
22
+ end
23
+ end
@@ -0,0 +1,30 @@
1
+ # frozen_string_literal: true
2
+
3
+ module Replicate
4
+ class Client
5
+ # Methods for the Upload API (experimental Dreambooth endpoint)
6
+ module Upload
7
+ # Create an upload
8
+ # @see https://replicate.com/blog/dreambooth-api
9
+ def create_upload
10
+ response = dreambooth_endpoint.post("upload/data.zip")
11
+ Replicate::Record::Upload.new(self, response)
12
+ end
13
+
14
+ # PUT the zip file at +zip_path+ to +upload_endpoint_url+ (Record::Upload#attach passes its upload_url).
15
+ # The file is streamed with an explicit Content-Length and closed once the request returns.
16
+ # @see https://replicate.com/blog/dreambooth-api
17
+ def update_upload(upload_endpoint_url, zip_path)
18
+ endpoint = Replicate::Endpoint.new(endpoint_url: upload_endpoint_url, api_token: nil)
19
+ File.open(zip_path, "rb") do |file|
20
+ endpoint.agent.put do |req|
21
+ req.headers["Content-Type"] = "application/zip"
22
+ # Content-Length must not be combined with Transfer-Encoding: chunked (RFC 9112), so only the length is sent.
23
+ req.headers["Content-Length"] = file.size.to_s
24
+ req.body = file
25
+ end
26
+ end
27
+ end
28
+ end
29
+ end
30
+ end
@@ -1,18 +1,21 @@
1
1
  # frozen_string_literal: true
2
2
 
3
3
  require "replicate/configurable"
4
- require "replicate/connection"
4
+ require "replicate/endpoint"
5
5
 
6
6
  require "replicate/client/model"
7
7
  require "replicate/client/prediction"
8
+ require "replicate/client/upload"
9
+ require "replicate/client/training"
8
10
 
9
11
  module Replicate
10
12
  class Client
11
13
  include Replicate::Configurable
12
- include Replicate::Connection
13
14
 
14
15
  include Replicate::Client::Model
15
16
  include Replicate::Client::Prediction
17
+ include Replicate::Client::Upload
18
+ include Replicate::Client::Training
16
19
 
17
20
  def initialize(options = {})
18
21
  # Use options passed in, but fall back to module defaults
@@ -21,5 +24,13 @@ module Replicate
21
24
  instance_variable_set(:"@#{key}", value)
22
25
  end
23
26
  end
27
+
28
+ def api_endpoint
29
+ @api_endpoint ||= Replicate::Endpoint.new(endpoint_url: api_endpoint_url, api_token: api_token)
30
+ end
31
+
32
+ def dreambooth_endpoint
33
+ @dreambooth_endpoint ||= Replicate::Endpoint.new(endpoint_url: dreambooth_endpoint_url, api_token: api_token)
34
+ end
24
35
  end
25
36
  end
@@ -3,13 +3,13 @@
3
3
  module Replicate
4
4
  module Configurable
5
5
  attr_accessor :api_token, :webhook_url
6
- attr_writer :api_endpoint
6
+ attr_writer :api_endpoint_url, :dreambooth_endpoint_url
7
7
 
8
8
  class << self
9
9
  # List of configurable keys for {Replicate::Client}
10
10
  # @return [Array] of option keys
11
11
  def keys
12
- @keys ||= %i[api_token api_endpoint webhook_url]
12
+ @keys ||= %i[api_token api_endpoint_url dreambooth_endpoint_url webhook_url]
13
13
  end
14
14
  end
15
15
 
@@ -19,8 +19,12 @@ module Replicate
19
19
  end
20
20
 
21
21
  # API endpoint methods
22
- def api_endpoint
23
- @api_endpoint ||= "https://api.replicate.com/v1"
22
+ def api_endpoint_url
23
+ @api_endpoint_url ||= "https://api.replicate.com/v1"
24
+ end
25
+
26
+ def dreambooth_endpoint_url
27
+ @dreambooth_endpoint_url ||= "https://dreambooth-api-experimental.replicate.com/v1"
24
28
  end
25
29
 
26
30
  private
@@ -3,16 +3,19 @@
3
3
  require "faraday"
4
4
  require "faraday/net_http"
5
5
  require "faraday/retry"
6
+ require "faraday/multipart"
6
7
  require "addressable/uri"
7
8
 
8
9
  module Replicate
9
10
  # Network layer for API clients.
10
- module Connection
11
- DEFAULT_MEDIA_TYPE = "application/json"
12
- USER_AGENT = "Datatrans Ruby Gem"
11
+ class Endpoint
12
+ attr_reader :endpoint_url, :api_token, :content_type
13
13
 
14
- # Header keys that can be passed in options hash to {#get},{#head}
15
- CONVENIENCE_HEADERS = Set.new(%i[accept content_type])
14
+ def initialize(endpoint_url:, api_token:, content_type: 'application/json')
15
+ @endpoint_url = endpoint_url
16
+ @api_token = api_token
17
+ @content_type = content_type
18
+ end
16
19
 
17
20
  # Make a HTTP GET request
18
21
  #
@@ -72,11 +75,10 @@ module Replicate
72
75
  #
73
76
  # @return [Faraday::Connection]
74
77
  def agent
75
- @agent ||= Faraday.new(url: endpoint) do |conn|
78
+ @agent ||= Faraday.new(url: endpoint_url) do |conn|
76
79
  conn.request :retry
77
- conn.request :authorization, 'Token', api_token
78
- conn.headers["Content-Type"] = DEFAULT_MEDIA_TYPE
79
- conn.headers["Accept"] = DEFAULT_MEDIA_TYPE
80
+ conn.request :authorization, 'Token', api_token if api_token
81
+ conn.headers["Content-Type"] = content_type
80
82
 
81
83
  conn.adapter :net_http
82
84
  end
@@ -89,12 +91,6 @@ module Replicate
89
91
  @last_response if defined? @last_response
90
92
  end
91
93
 
92
- protected
93
-
94
- def endpoint
95
- api_endpoint
96
- end
97
-
98
94
  private
99
95
 
100
96
  def request(method, path, data)
@@ -103,7 +99,12 @@ module Replicate
103
99
  when 400
104
100
  raise Error, "#{@last_response.status} #{@last_response.reason_phrase}: #{JSON.parse(@last_response.body)}"
105
101
  else
106
- JSON.parse(@last_response.body)
102
+ case content_type
103
+ when 'application/json'
104
+ JSON.parse(@last_response.body)
105
+ else
106
+ @last_response.body
107
+ end
107
108
  end
108
109
  end
109
110
  end
@@ -19,6 +19,26 @@ module Replicate
19
19
  false
20
20
  end
21
21
  end
22
+
23
+ def starting?
24
+ status == "starting"
25
+ end
26
+
27
+ def processing?
28
+ status == "processing"
29
+ end
30
+
31
+ def succeeded?
32
+ status == "succeeded"
33
+ end
34
+
35
+ def failed?
36
+ status == "failed"
37
+ end
38
+
39
+ def canceled?
40
+ status == "canceled"
41
+ end
22
42
  end
23
43
  end
24
44
  end
@@ -0,0 +1,11 @@
1
+ # frozen_string_literal: true
2
+
3
+ module Replicate
4
+ module Record
5
+ class Training < Base
6
+ def refetch
7
+ @data = client.retrieve_training(id).data
8
+ end
9
+ end
10
+ end
11
+ end
@@ -0,0 +1,11 @@
1
+ # frozen_string_literal: true
2
+
3
+ module Replicate
4
+ module Record
5
+ class Upload < Base
6
+ def attach(path)
7
+ client.update_upload(upload_url, path)
8
+ end
9
+ end
10
+ end
11
+ end
@@ -1,5 +1,5 @@
1
1
  # frozen_string_literal: true
2
2
 
3
3
  module Replicate
4
- VERSION = "0.1.6"
4
+ VERSION = "0.2.0"
5
5
  end
data/lib/replicate.rb CHANGED
@@ -7,6 +7,8 @@ require "replicate/record/base"
7
7
  require "replicate/record/model"
8
8
  require "replicate/record/model_version"
9
9
  require "replicate/record/prediction"
10
+ require "replicate/record/upload"
11
+ require "replicate/record/training"
10
12
 
11
13
  module Replicate
12
14
  class Error < StandardError; end
metadata CHANGED
@@ -1,14 +1,14 @@
1
1
  --- !ruby/object:Gem::Specification
2
2
  name: replicate-ruby
3
3
  version: !ruby/object:Gem::Version
4
- version: 0.1.6
4
+ version: 0.2.0
5
5
  platform: ruby
6
6
  authors:
7
7
  - Dreaming Tulpa
8
8
  autorequire:
9
9
  bindir: exe
10
10
  cert_chain: []
11
- date: 2022-12-19 00:00:00.000000000 Z
11
+ date: 2023-03-08 00:00:00.000000000 Z
12
12
  dependencies:
13
13
  - !ruby/object:Gem::Dependency
14
14
  name: faraday
@@ -38,6 +38,20 @@ dependencies:
38
38
  - - ">="
39
39
  - !ruby/object:Gem::Version
40
40
  version: '0'
41
+ - !ruby/object:Gem::Dependency
42
+ name: faraday-multipart
43
+ requirement: !ruby/object:Gem::Requirement
44
+ requirements:
45
+ - - ">="
46
+ - !ruby/object:Gem::Version
47
+ version: '0'
48
+ type: :runtime
49
+ prerelease: false
50
+ version_requirements: !ruby/object:Gem::Requirement
51
+ requirements:
52
+ - - ">="
53
+ - !ruby/object:Gem::Version
54
+ version: '0'
41
55
  - !ruby/object:Gem::Dependency
42
56
  name: addressable
43
57
  requirement: !ruby/object:Gem::Requirement
@@ -80,6 +94,20 @@ dependencies:
80
94
  - - ">="
81
95
  - !ruby/object:Gem::Version
82
96
  version: '0'
97
+ - !ruby/object:Gem::Dependency
98
+ name: byebug
99
+ requirement: !ruby/object:Gem::Requirement
100
+ requirements:
101
+ - - ">="
102
+ - !ruby/object:Gem::Version
103
+ version: '0'
104
+ type: :development
105
+ prerelease: false
106
+ version_requirements: !ruby/object:Gem::Requirement
107
+ requirements:
108
+ - - ">="
109
+ - !ruby/object:Gem::Version
110
+ version: '0'
83
111
  description:
84
112
  email:
85
113
  - hey@dreamingtulpa.com
@@ -93,16 +121,21 @@ files:
93
121
  - Gemfile.lock
94
122
  - README.md
95
123
  - Rakefile
124
+ - data.zip
96
125
  - lib/replicate.rb
97
126
  - lib/replicate/client.rb
98
127
  - lib/replicate/client/model.rb
99
128
  - lib/replicate/client/prediction.rb
129
+ - lib/replicate/client/training.rb
130
+ - lib/replicate/client/upload.rb
100
131
  - lib/replicate/configurable.rb
101
- - lib/replicate/connection.rb
132
+ - lib/replicate/endpoint.rb
102
133
  - lib/replicate/record/base.rb
103
134
  - lib/replicate/record/model.rb
104
135
  - lib/replicate/record/model_version.rb
105
136
  - lib/replicate/record/prediction.rb
137
+ - lib/replicate/record/training.rb
138
+ - lib/replicate/record/upload.rb
106
139
  - lib/replicate/version.rb
107
140
  - sig/replicate/ruby.rbs
108
141
  homepage: https://github.com/dreamingtulpa/replicate-ruby
@@ -126,7 +159,7 @@ required_rubygems_version: !ruby/object:Gem::Requirement
126
159
  - !ruby/object:Gem::Version
127
160
  version: '0'
128
161
  requirements: []
129
- rubygems_version: 3.3.26
162
+ rubygems_version: 3.4.6
130
163
  signing_key:
131
164
  specification_version: 4
132
165
  summary: Ruby client for Replicate