replicate-ruby 0.1.6 → 0.2.0
Sign up to get free protection for your applications and to get access to all the features.
- checksums.yaml +4 -4
- data/Gemfile.lock +8 -1
- data/README.md +36 -0
- data/data.zip +0 -0
- data/lib/replicate/client/model.rb +6 -4
- data/lib/replicate/client/prediction.rb +7 -4
- data/lib/replicate/client/training.rb +23 -0
- data/lib/replicate/client/upload.rb +27 -0
- data/lib/replicate/client.rb +13 -2
- data/lib/replicate/configurable.rb +8 -4
- data/lib/replicate/{connection.rb → endpoint.rb} +17 -16
- data/lib/replicate/record/prediction.rb +20 -0
- data/lib/replicate/record/training.rb +11 -0
- data/lib/replicate/record/upload.rb +11 -0
- data/lib/replicate/version.rb +1 -1
- data/lib/replicate.rb +2 -0
- metadata +37 -4
checksums.yaml
CHANGED
@@ -1,7 +1,7 @@
|
|
1
1
|
---
|
2
2
|
SHA256:
|
3
|
-
metadata.gz:
|
4
|
-
data.tar.gz:
|
3
|
+
metadata.gz: 82385c582239598bb4b62a3e133228b7e6fc417ea57a047c1489e5d94d787faf
|
4
|
+
data.tar.gz: 4ccecfdb33f1a4cf3218587c7ed4fa28475da1ca13369e59d5875b07e4f092f8
|
5
5
|
SHA512:
|
6
|
-
metadata.gz:
|
7
|
-
data.tar.gz:
|
6
|
+
metadata.gz: c8edef78b1309f5873997b33b1de5de82f2f8982ce2ec3ed32b0cdeca0e9038c77b4d56e6173441d106160e77fb1e0d2311024d5cae2c68a1fc91ef6c62cb725
|
7
|
+
data.tar.gz: 56958524475fbd6636e1eb3f409b4dad100fe8179a982a746a0ee2ea4356ffa66724269a9da0e226e374bbe34c2526aacf3a358e5847ed55429ebf06700f2383
|
data/Gemfile.lock
CHANGED
@@ -1,9 +1,10 @@
|
|
1
1
|
PATH
|
2
2
|
remote: .
|
3
3
|
specs:
|
4
|
-
replicate-ruby (0.1.
|
4
|
+
replicate-ruby (0.1.7)
|
5
5
|
addressable
|
6
6
|
faraday (>= 2.0)
|
7
|
+
faraday-multipart
|
7
8
|
faraday-retry
|
8
9
|
|
9
10
|
GEM
|
@@ -12,12 +13,15 @@ GEM
|
|
12
13
|
addressable (2.8.1)
|
13
14
|
public_suffix (>= 2.0.2, < 6.0)
|
14
15
|
ast (2.4.2)
|
16
|
+
byebug (11.1.3)
|
15
17
|
coderay (1.1.3)
|
16
18
|
crack (0.4.5)
|
17
19
|
rexml
|
18
20
|
faraday (2.7.2)
|
19
21
|
faraday-net_http (>= 2.0, < 3.1)
|
20
22
|
ruby2_keywords (>= 0.0.4)
|
23
|
+
faraday-multipart (1.0.4)
|
24
|
+
multipart-post (~> 2)
|
21
25
|
faraday-net_http (3.0.2)
|
22
26
|
faraday-retry (2.0.0)
|
23
27
|
faraday (~> 2.0)
|
@@ -25,6 +29,7 @@ GEM
|
|
25
29
|
json (2.6.2)
|
26
30
|
method_source (1.0.0)
|
27
31
|
minitest (5.16.3)
|
32
|
+
multipart-post (2.2.3)
|
28
33
|
parallel (1.22.1)
|
29
34
|
parser (3.1.2.1)
|
30
35
|
ast (~> 2.4.1)
|
@@ -57,9 +62,11 @@ GEM
|
|
57
62
|
hashdiff (>= 0.4.0, < 2.0.0)
|
58
63
|
|
59
64
|
PLATFORMS
|
65
|
+
arm64-darwin-22
|
60
66
|
x86_64-darwin-21
|
61
67
|
|
62
68
|
DEPENDENCIES
|
69
|
+
byebug
|
63
70
|
minitest (~> 5.0)
|
64
71
|
pry
|
65
72
|
rake (~> 13.0)
|
data/README.md
CHANGED
@@ -54,6 +54,42 @@ id = prediction.id # store prediction id in your backend
|
|
54
54
|
prediction = Replicate.client.retrieve_prediction(id) # retrieve prediction during webhook with id from backend
|
55
55
|
```
|
56
56
|
|
57
|
+
## Dreambooth
|
58
|
+
|
59
|
+
There is support for the [experimental dreambooth endpoint](https://replicate.com/blog/dreambooth-api).
|
60
|
+
|
61
|
+
First, upload your training dataset:
|
62
|
+
|
63
|
+
```
|
64
|
+
upload = Replicate.client.create_upload
|
65
|
+
upload.attach('tmp/data.zip') # replace with the path to your zip file
|
66
|
+
```
|
67
|
+
|
68
|
+
Then start training a new model using, for instance:
|
69
|
+
|
70
|
+
```
|
71
|
+
training = Replicate.client.create_training(
|
72
|
+
input: {
|
73
|
+
instance_prompt: "zwx style",
|
74
|
+
class_prompt: "style",
|
75
|
+
instance_data: upload.serving_url,
|
76
|
+
max_train_steps: 5000
|
77
|
+
},
|
78
|
+
model: 'yourusername/yourmodel'
|
79
|
+
)
|
80
|
+
```
|
81
|
+
|
82
|
+
As soon as the model has finished training, you can run predictions on it:
|
83
|
+
|
84
|
+
```
|
85
|
+
prediction = Replicate.client.create_prediction(
|
86
|
+
input: {
|
87
|
+
prompt: 'your prompt, zwx style'
|
88
|
+
},
|
89
|
+
version: training.version
|
90
|
+
)
|
91
|
+
```
|
92
|
+
|
57
93
|
## Development
|
58
94
|
|
59
95
|
After checking out the repo, run `bin/setup` to install dependencies. Then, run `rake test` to run the tests. You can also run `bin/console` for an interactive prompt that will allow you to experiment.
|
data/data.zip
ADDED
File without changes
|
@@ -9,20 +9,22 @@ module Replicate
|
|
9
9
|
def retrieve_model(model, version: :latest)
|
10
10
|
case version
|
11
11
|
when :latest
|
12
|
-
|
12
|
+
response = api_endpoint.get("models/#{model}")
|
13
|
+
Replicate::Record::Model.new(self, response)
|
13
14
|
when :all
|
14
|
-
response = get("models/#{model}/versions")
|
15
|
+
response = api_endpoint.get("models/#{model}/versions")
|
15
16
|
response["results"].map! { |result| Replicate::Record::ModelVersion.new(self, result) }
|
16
17
|
response
|
17
18
|
else
|
18
|
-
|
19
|
+
response = api_endpoint.get("models/#{model}/versions/#{version}")
|
20
|
+
Replicate::Record::ModelVersion.new(self, response)
|
19
21
|
end
|
20
22
|
end
|
21
23
|
|
22
24
|
# Get a collection of models
|
23
25
|
# @see https://replicate.com/docs/reference/http#get-collection
|
24
26
|
def retrieve_collection(slug)
|
25
|
-
get("collections/#{slug}")
|
27
|
+
api_endpoint.get("collections/#{slug}")
|
26
28
|
end
|
27
29
|
end
|
28
30
|
end
|
@@ -7,13 +7,14 @@ module Replicate
|
|
7
7
|
# Get a prediction
|
8
8
|
# @see https://replicate.com/docs/reference/http#get-prediction
|
9
9
|
def retrieve_prediction(id)
|
10
|
-
|
10
|
+
response = api_endpoint.get("predictions/#{id}")
|
11
|
+
Replicate::Record::Prediction.new(self, response)
|
11
12
|
end
|
12
13
|
|
13
14
|
# Get a list of predictions
|
14
15
|
# @see https://replicate.com/docs/reference/http#get-predictions
|
15
16
|
def list_predictions(cursor = nil)
|
16
|
-
response = get("predictions", cursor: cursor)
|
17
|
+
response = api_endpoint.get("predictions", cursor: cursor)
|
17
18
|
response["results"].map! { |result| Replicate::Record::Prediction.new(self, result) }
|
18
19
|
response
|
19
20
|
end
|
@@ -22,13 +23,15 @@ module Replicate
|
|
22
23
|
# @see https://replicate.com/docs/reference/http#create-prediction
|
23
24
|
def create_prediction(params)
|
24
25
|
params[:webhook_completed] ||= webhook_url
|
25
|
-
|
26
|
+
response = api_endpoint.post("predictions", params)
|
27
|
+
Replicate::Record::Prediction.new(self, response)
|
26
28
|
end
|
27
29
|
|
28
30
|
# Cancel a prediction
|
29
31
|
# @see https://replicate.com/docs/reference/http#cancel-prediction
|
30
32
|
def cancel_prediction(id)
|
31
|
-
|
33
|
+
response = api_endpoint.post("predictions/#{id}/cancel")
|
34
|
+
Replicate::Record::Prediction.new(self, response)
|
32
35
|
end
|
33
36
|
end
|
34
37
|
end
|
@@ -0,0 +1,23 @@
|
|
1
|
+
# frozen_string_literal: true
|
2
|
+
|
3
|
+
module Replicate
|
4
|
+
class Client
|
5
|
+
# Methods for the Training API
|
6
|
+
module Training
|
7
|
+
# Get a training
|
8
|
+
# @see https://replicate.com/blog/dreambooth-api
|
9
|
+
def retrieve_training(id)
|
10
|
+
response = dreambooth_endpoint.get("trainings/#{id}")
|
11
|
+
Replicate::Record::Training.new(self, response)
|
12
|
+
end
|
13
|
+
|
14
|
+
# Create a training
|
15
|
+
# @see https://replicate.com/blog/dreambooth-api
|
16
|
+
def create_training(params)
|
17
|
+
params[:webhook_completed] ||= webhook_url
|
18
|
+
response = dreambooth_endpoint.post("trainings", params)
|
19
|
+
Replicate::Record::Training.new(self, response)
|
20
|
+
end
|
21
|
+
end
|
22
|
+
end
|
23
|
+
end
|
@@ -0,0 +1,27 @@
|
|
1
|
+
# frozen_string_literal: true
|
2
|
+
|
3
|
+
module Replicate
|
4
|
+
class Client
|
5
|
+
# Methods for the Upload API
|
6
|
+
module Upload
|
7
|
+
# Create an upload
|
8
|
+
# @see https://replicate.com/blog/dreambooth-api
|
9
|
+
def create_upload
|
10
|
+
response = dreambooth_endpoint.post("upload/data.zip")
|
11
|
+
Replicate::Record::Upload.new(self, response)
|
12
|
+
end
|
13
|
+
|
14
|
+
# Upload the zip file at zip_path to the given upload endpoint URL
|
15
|
+
# @see https://replicate.com/blog/dreambooth-api
|
16
|
+
def update_upload(upload_endpoint_url, zip_path)
|
17
|
+
endpoint = Replicate::Endpoint.new(endpoint_url: upload_endpoint_url, api_token: nil)
|
18
|
+
endpoint.agent.put do |req|
|
19
|
+
req.headers["Content-Type"] = "application/zip"
|
20
|
+
req.headers["Content-Length"] = File.size(zip_path).to_s
|
21
|
+
req.headers["Transfer-Encoding"] = "chunked"
|
22
|
+
req.body = Faraday::UploadIO.new(zip_path, 'application/zip')
|
23
|
+
end
|
24
|
+
end
|
25
|
+
end
|
26
|
+
end
|
27
|
+
end
|
data/lib/replicate/client.rb
CHANGED
@@ -1,18 +1,21 @@
|
|
1
1
|
# frozen_string_literal: true
|
2
2
|
|
3
3
|
require "replicate/configurable"
|
4
|
-
require "replicate/
|
4
|
+
require "replicate/endpoint"
|
5
5
|
|
6
6
|
require "replicate/client/model"
|
7
7
|
require "replicate/client/prediction"
|
8
|
+
require "replicate/client/upload"
|
9
|
+
require "replicate/client/training"
|
8
10
|
|
9
11
|
module Replicate
|
10
12
|
class Client
|
11
13
|
include Replicate::Configurable
|
12
|
-
include Replicate::Connection
|
13
14
|
|
14
15
|
include Replicate::Client::Model
|
15
16
|
include Replicate::Client::Prediction
|
17
|
+
include Replicate::Client::Upload
|
18
|
+
include Replicate::Client::Training
|
16
19
|
|
17
20
|
def initialize(options = {})
|
18
21
|
# Use options passed in, but fall back to module defaults
|
@@ -21,5 +24,13 @@ module Replicate
|
|
21
24
|
instance_variable_set(:"@#{key}", value)
|
22
25
|
end
|
23
26
|
end
|
27
|
+
|
28
|
+
def api_endpoint
|
29
|
+
@api_endpoint ||= Replicate::Endpoint.new(endpoint_url: api_endpoint_url, api_token: api_token)
|
30
|
+
end
|
31
|
+
|
32
|
+
def dreambooth_endpoint
|
33
|
+
@dreambooth_endpoint ||= Replicate::Endpoint.new(endpoint_url: dreambooth_endpoint_url, api_token: api_token)
|
34
|
+
end
|
24
35
|
end
|
25
36
|
end
|
@@ -3,13 +3,13 @@
|
|
3
3
|
module Replicate
|
4
4
|
module Configurable
|
5
5
|
attr_accessor :api_token, :webhook_url
|
6
|
-
attr_writer :
|
6
|
+
attr_writer :api_endpoint_url, :dreambooth_endpoint_url
|
7
7
|
|
8
8
|
class << self
|
9
9
|
# List of configurable keys for {Replicate::Client}
|
10
10
|
# @return [Array] of option keys
|
11
11
|
def keys
|
12
|
-
@keys ||= %i[api_token
|
12
|
+
@keys ||= %i[api_token api_endpoint_url dreambooth_endpoint_url webhook_url]
|
13
13
|
end
|
14
14
|
end
|
15
15
|
|
@@ -19,8 +19,12 @@ module Replicate
|
|
19
19
|
end
|
20
20
|
|
21
21
|
# API endpoint methods
|
22
|
-
def
|
23
|
-
@
|
22
|
+
def api_endpoint_url
|
23
|
+
@api_endpoint_url ||= "https://api.replicate.com/v1"
|
24
|
+
end
|
25
|
+
|
26
|
+
def dreambooth_endpoint_url
|
27
|
+
@dreambooth_endpoint_url ||= "https://dreambooth-api-experimental.replicate.com/v1"
|
24
28
|
end
|
25
29
|
|
26
30
|
private
|
@@ -3,16 +3,19 @@
|
|
3
3
|
require "faraday"
|
4
4
|
require "faraday/net_http"
|
5
5
|
require "faraday/retry"
|
6
|
+
require "faraday/multipart"
|
6
7
|
require "addressable/uri"
|
7
8
|
|
8
9
|
module Replicate
|
9
10
|
# Network layer for API clients.
|
10
|
-
|
11
|
-
|
12
|
-
USER_AGENT = "Datatrans Ruby Gem"
|
11
|
+
class Endpoint
|
12
|
+
attr_reader :endpoint_url, :api_token, :content_type
|
13
13
|
|
14
|
-
|
15
|
-
|
14
|
+
def initialize(endpoint_url:, api_token:, content_type: 'application/json')
|
15
|
+
@endpoint_url = endpoint_url
|
16
|
+
@api_token = api_token
|
17
|
+
@content_type = content_type
|
18
|
+
end
|
16
19
|
|
17
20
|
# Make a HTTP GET request
|
18
21
|
#
|
@@ -72,11 +75,10 @@ module Replicate
|
|
72
75
|
#
|
73
76
|
# @return [Faraday::Connection]
|
74
77
|
def agent
|
75
|
-
@agent ||= Faraday.new(url:
|
78
|
+
@agent ||= Faraday.new(url: endpoint_url) do |conn|
|
76
79
|
conn.request :retry
|
77
|
-
conn.request :authorization, 'Token', api_token
|
78
|
-
conn.headers["Content-Type"] =
|
79
|
-
conn.headers["Accept"] = DEFAULT_MEDIA_TYPE
|
80
|
+
conn.request :authorization, 'Token', api_token if api_token
|
81
|
+
conn.headers["Content-Type"] = content_type
|
80
82
|
|
81
83
|
conn.adapter :net_http
|
82
84
|
end
|
@@ -89,12 +91,6 @@ module Replicate
|
|
89
91
|
@last_response if defined? @last_response
|
90
92
|
end
|
91
93
|
|
92
|
-
protected
|
93
|
-
|
94
|
-
def endpoint
|
95
|
-
api_endpoint
|
96
|
-
end
|
97
|
-
|
98
94
|
private
|
99
95
|
|
100
96
|
def request(method, path, data)
|
@@ -103,7 +99,12 @@ module Replicate
|
|
103
99
|
when 400
|
104
100
|
raise Error, "#{@last_response.status} #{@last_response.reason_phrase}: #{JSON.parse(@last_response.body)}"
|
105
101
|
else
|
106
|
-
|
102
|
+
case content_type
|
103
|
+
when 'application/json'
|
104
|
+
JSON.parse(@last_response.body)
|
105
|
+
else
|
106
|
+
@last_response.body
|
107
|
+
end
|
107
108
|
end
|
108
109
|
end
|
109
110
|
end
|
@@ -19,6 +19,26 @@ module Replicate
|
|
19
19
|
false
|
20
20
|
end
|
21
21
|
end
|
22
|
+
|
23
|
+
def starting?
|
24
|
+
status == "starting"
|
25
|
+
end
|
26
|
+
|
27
|
+
def processing?
|
28
|
+
status == "processing"
|
29
|
+
end
|
30
|
+
|
31
|
+
def succeeded?
|
32
|
+
status == "succeeded"
|
33
|
+
end
|
34
|
+
|
35
|
+
def failed?
|
36
|
+
status == "failed"
|
37
|
+
end
|
38
|
+
|
39
|
+
def canceled?
|
40
|
+
status == "canceled"
|
41
|
+
end
|
22
42
|
end
|
23
43
|
end
|
24
44
|
end
|
data/lib/replicate/version.rb
CHANGED
data/lib/replicate.rb
CHANGED
@@ -7,6 +7,8 @@ require "replicate/record/base"
|
|
7
7
|
require "replicate/record/model"
|
8
8
|
require "replicate/record/model_version"
|
9
9
|
require "replicate/record/prediction"
|
10
|
+
require "replicate/record/upload"
|
11
|
+
require "replicate/record/training"
|
10
12
|
|
11
13
|
module Replicate
|
12
14
|
class Error < StandardError; end
|
metadata
CHANGED
@@ -1,14 +1,14 @@
|
|
1
1
|
--- !ruby/object:Gem::Specification
|
2
2
|
name: replicate-ruby
|
3
3
|
version: !ruby/object:Gem::Version
|
4
|
-
version: 0.
|
4
|
+
version: 0.2.0
|
5
5
|
platform: ruby
|
6
6
|
authors:
|
7
7
|
- Dreaming Tulpa
|
8
8
|
autorequire:
|
9
9
|
bindir: exe
|
10
10
|
cert_chain: []
|
11
|
-
date:
|
11
|
+
date: 2023-03-08 00:00:00.000000000 Z
|
12
12
|
dependencies:
|
13
13
|
- !ruby/object:Gem::Dependency
|
14
14
|
name: faraday
|
@@ -38,6 +38,20 @@ dependencies:
|
|
38
38
|
- - ">="
|
39
39
|
- !ruby/object:Gem::Version
|
40
40
|
version: '0'
|
41
|
+
- !ruby/object:Gem::Dependency
|
42
|
+
name: faraday-multipart
|
43
|
+
requirement: !ruby/object:Gem::Requirement
|
44
|
+
requirements:
|
45
|
+
- - ">="
|
46
|
+
- !ruby/object:Gem::Version
|
47
|
+
version: '0'
|
48
|
+
type: :runtime
|
49
|
+
prerelease: false
|
50
|
+
version_requirements: !ruby/object:Gem::Requirement
|
51
|
+
requirements:
|
52
|
+
- - ">="
|
53
|
+
- !ruby/object:Gem::Version
|
54
|
+
version: '0'
|
41
55
|
- !ruby/object:Gem::Dependency
|
42
56
|
name: addressable
|
43
57
|
requirement: !ruby/object:Gem::Requirement
|
@@ -80,6 +94,20 @@ dependencies:
|
|
80
94
|
- - ">="
|
81
95
|
- !ruby/object:Gem::Version
|
82
96
|
version: '0'
|
97
|
+
- !ruby/object:Gem::Dependency
|
98
|
+
name: byebug
|
99
|
+
requirement: !ruby/object:Gem::Requirement
|
100
|
+
requirements:
|
101
|
+
- - ">="
|
102
|
+
- !ruby/object:Gem::Version
|
103
|
+
version: '0'
|
104
|
+
type: :development
|
105
|
+
prerelease: false
|
106
|
+
version_requirements: !ruby/object:Gem::Requirement
|
107
|
+
requirements:
|
108
|
+
- - ">="
|
109
|
+
- !ruby/object:Gem::Version
|
110
|
+
version: '0'
|
83
111
|
description:
|
84
112
|
email:
|
85
113
|
- hey@dreamingtulpa.com
|
@@ -93,16 +121,21 @@ files:
|
|
93
121
|
- Gemfile.lock
|
94
122
|
- README.md
|
95
123
|
- Rakefile
|
124
|
+
- data.zip
|
96
125
|
- lib/replicate.rb
|
97
126
|
- lib/replicate/client.rb
|
98
127
|
- lib/replicate/client/model.rb
|
99
128
|
- lib/replicate/client/prediction.rb
|
129
|
+
- lib/replicate/client/training.rb
|
130
|
+
- lib/replicate/client/upload.rb
|
100
131
|
- lib/replicate/configurable.rb
|
101
|
-
- lib/replicate/
|
132
|
+
- lib/replicate/endpoint.rb
|
102
133
|
- lib/replicate/record/base.rb
|
103
134
|
- lib/replicate/record/model.rb
|
104
135
|
- lib/replicate/record/model_version.rb
|
105
136
|
- lib/replicate/record/prediction.rb
|
137
|
+
- lib/replicate/record/training.rb
|
138
|
+
- lib/replicate/record/upload.rb
|
106
139
|
- lib/replicate/version.rb
|
107
140
|
- sig/replicate/ruby.rbs
|
108
141
|
homepage: https://github.com/dreamingtulpa/replicate-ruby
|
@@ -126,7 +159,7 @@ required_rubygems_version: !ruby/object:Gem::Requirement
|
|
126
159
|
- !ruby/object:Gem::Version
|
127
160
|
version: '0'
|
128
161
|
requirements: []
|
129
|
-
rubygems_version: 3.
|
162
|
+
rubygems_version: 3.4.6
|
130
163
|
signing_key:
|
131
164
|
specification_version: 4
|
132
165
|
summary: Ruby client for Replicate
|