replicate-ruby 0.1.6 → 0.2.0
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- checksums.yaml +4 -4
- data/Gemfile.lock +8 -1
- data/README.md +36 -0
- data/data.zip +0 -0
- data/lib/replicate/client/model.rb +6 -4
- data/lib/replicate/client/prediction.rb +7 -4
- data/lib/replicate/client/training.rb +23 -0
- data/lib/replicate/client/upload.rb +27 -0
- data/lib/replicate/client.rb +13 -2
- data/lib/replicate/configurable.rb +8 -4
- data/lib/replicate/{connection.rb → endpoint.rb} +17 -16
- data/lib/replicate/record/prediction.rb +20 -0
- data/lib/replicate/record/training.rb +11 -0
- data/lib/replicate/record/upload.rb +11 -0
- data/lib/replicate/version.rb +1 -1
- data/lib/replicate.rb +2 -0
- metadata +37 -4
checksums.yaml
CHANGED
@@ -1,7 +1,7 @@
|
|
1
1
|
---
|
2
2
|
SHA256:
|
3
|
-
metadata.gz:
|
4
|
-
data.tar.gz:
|
3
|
+
metadata.gz: 82385c582239598bb4b62a3e133228b7e6fc417ea57a047c1489e5d94d787faf
|
4
|
+
data.tar.gz: 4ccecfdb33f1a4cf3218587c7ed4fa28475da1ca13369e59d5875b07e4f092f8
|
5
5
|
SHA512:
|
6
|
-
metadata.gz:
|
7
|
-
data.tar.gz:
|
6
|
+
metadata.gz: c8edef78b1309f5873997b33b1de5de82f2f8982ce2ec3ed32b0cdeca0e9038c77b4d56e6173441d106160e77fb1e0d2311024d5cae2c68a1fc91ef6c62cb725
|
7
|
+
data.tar.gz: 56958524475fbd6636e1eb3f409b4dad100fe8179a982a746a0ee2ea4356ffa66724269a9da0e226e374bbe34c2526aacf3a358e5847ed55429ebf06700f2383
|
data/Gemfile.lock
CHANGED
@@ -1,9 +1,10 @@
|
|
1
1
|
PATH
|
2
2
|
remote: .
|
3
3
|
specs:
|
4
|
-
replicate-ruby (0.1.
|
4
|
+
replicate-ruby (0.1.7)
|
5
5
|
addressable
|
6
6
|
faraday (>= 2.0)
|
7
|
+
faraday-multipart
|
7
8
|
faraday-retry
|
8
9
|
|
9
10
|
GEM
|
@@ -12,12 +13,15 @@ GEM
|
|
12
13
|
addressable (2.8.1)
|
13
14
|
public_suffix (>= 2.0.2, < 6.0)
|
14
15
|
ast (2.4.2)
|
16
|
+
byebug (11.1.3)
|
15
17
|
coderay (1.1.3)
|
16
18
|
crack (0.4.5)
|
17
19
|
rexml
|
18
20
|
faraday (2.7.2)
|
19
21
|
faraday-net_http (>= 2.0, < 3.1)
|
20
22
|
ruby2_keywords (>= 0.0.4)
|
23
|
+
faraday-multipart (1.0.4)
|
24
|
+
multipart-post (~> 2)
|
21
25
|
faraday-net_http (3.0.2)
|
22
26
|
faraday-retry (2.0.0)
|
23
27
|
faraday (~> 2.0)
|
@@ -25,6 +29,7 @@ GEM
|
|
25
29
|
json (2.6.2)
|
26
30
|
method_source (1.0.0)
|
27
31
|
minitest (5.16.3)
|
32
|
+
multipart-post (2.2.3)
|
28
33
|
parallel (1.22.1)
|
29
34
|
parser (3.1.2.1)
|
30
35
|
ast (~> 2.4.1)
|
@@ -57,9 +62,11 @@ GEM
|
|
57
62
|
hashdiff (>= 0.4.0, < 2.0.0)
|
58
63
|
|
59
64
|
PLATFORMS
|
65
|
+
arm64-darwin-22
|
60
66
|
x86_64-darwin-21
|
61
67
|
|
62
68
|
DEPENDENCIES
|
69
|
+
byebug
|
63
70
|
minitest (~> 5.0)
|
64
71
|
pry
|
65
72
|
rake (~> 13.0)
|
data/README.md
CHANGED
@@ -54,6 +54,42 @@ id = prediction.id # store prediction id in your backend
|
|
54
54
|
prediction = Replicate.client.retrieve_prediction(id) # retrieve prediction during webhook with id from backend
|
55
55
|
```
|
56
56
|
|
57
|
+
## Dreambooth
|
58
|
+
|
59
|
+
This gem supports Replicate's [experimental Dreambooth API](https://replicate.com/blog/dreambooth-api).
|
60
|
+
|
61
|
+
First, upload your training dataset:
|
62
|
+
|
63
|
+
```
|
64
|
+
upload = Replicate.client.create_upload
|
65
|
+
upload.attach('tmp/data.zip') # replace with the path to your zip file
|
66
|
+
```
|
67
|
+
|
68
|
+
Then start training a new model, for instance:
|
69
|
+
|
70
|
+
```
|
71
|
+
training = Replicate.client.create_training(
|
72
|
+
input: {
|
73
|
+
instance_prompt: "zwx style",
|
74
|
+
class_prompt: "style",
|
75
|
+
instance_data: upload.serving_url,
|
76
|
+
max_train_steps: 5000
|
77
|
+
},
|
78
|
+
model: 'yourusername/yourmodel'
|
79
|
+
)
|
80
|
+
```
|
81
|
+
|
82
|
+
Once the model has finished training, run predictions against it by passing the training's version:
|
83
|
+
|
84
|
+
```
|
85
|
+
prediction = Replicate.client.create_prediction(
|
86
|
+
input: {
|
87
|
+
prompt: 'your prompt, zwx style'
|
88
|
+
},
|
89
|
+
version: training.version
|
90
|
+
)
|
91
|
+
```
|
92
|
+
|
57
93
|
## Development
|
58
94
|
|
59
95
|
After checking out the repo, run `bin/setup` to install dependencies. Then, run `rake test` to run the tests. You can also run `bin/console` for an interactive prompt that will allow you to experiment.
|
data/data.zip
ADDED
File without changes
|
@@ -9,20 +9,22 @@ module Replicate
|
|
9
9
|
def retrieve_model(model, version: :latest)
  # :latest -> the model record itself; :all -> the paginated version listing;
  # any other value is treated as a specific version id.
  if version == :latest
    model_data = api_endpoint.get("models/#{model}")
    return Replicate::Record::Model.new(self, model_data)
  end

  if version == :all
    page = api_endpoint.get("models/#{model}/versions")
    # Wrap each raw result in place; the page's other keys are returned untouched.
    page["results"].map! { |entry| Replicate::Record::ModelVersion.new(self, entry) }
    return page
  end

  version_data = api_endpoint.get("models/#{model}/versions/#{version}")
  Replicate::Record::ModelVersion.new(self, version_data)
end
|
21
23
|
|
22
24
|
# Fetch a collection of models by its slug.
# @see https://replicate.com/docs/reference/http#get-collection
#
# @param slug [String] collection slug appended to "collections/"
# @return [Object] the value returned by +api_endpoint.get+, unwrapped
def retrieve_collection(slug)
  collection_path = "collections/#{slug}"
  api_endpoint.get(collection_path)
end
|
27
29
|
end
|
28
30
|
end
|
@@ -7,13 +7,14 @@ module Replicate
|
|
7
7
|
# Fetch a single prediction by id.
# @see https://replicate.com/docs/reference/http#get-prediction
#
# @param id [String] prediction id
# @return [Replicate::Record::Prediction] wraps the raw API response
def retrieve_prediction(id)
  Replicate::Record::Prediction.new(self, api_endpoint.get("predictions/#{id}"))
end
|
12
13
|
|
13
14
|
# Fetch one page of predictions.
# @see https://replicate.com/docs/reference/http#get-predictions
#
# @param cursor [String, nil] pagination cursor; nil requests the first page
# @return [Hash] the response page, with every "results" entry wrapped in a
#   Replicate::Record::Prediction (other keys are left as returned)
def list_predictions(cursor = nil)
  api_endpoint.get("predictions", cursor: cursor).tap do |page|
    page["results"].map! { |raw| Replicate::Record::Prediction.new(self, raw) }
  end
end
|
@@ -22,13 +23,15 @@ module Replicate
|
|
22
23
|
# @see https://replicate.com/docs/reference/http#create-prediction
|
23
24
|
def create_prediction(params)
  # Build the request body from a shallow copy so the caller's hash is not
  # modified. A missing or nil :webhook_completed falls back to webhook_url.
  payload = params.dup
  payload[:webhook_completed] ||= webhook_url
  response = api_endpoint.post("predictions", payload)
  Replicate::Record::Prediction.new(self, response)
end
|
27
29
|
|
28
30
|
# Cancel a prediction
# @see https://replicate.com/docs/reference/http#cancel-prediction
#
# @param id [String] id of the prediction to cancel
# @return [Replicate::Record::Prediction] wraps the response of the cancel call
def cancel_prediction(id)
  response = api_endpoint.post("predictions/#{id}/cancel")
  Replicate::Record::Prediction.new(self, response)
end
|
33
36
|
end
|
34
37
|
end
|
@@ -0,0 +1,23 @@
|
|
1
|
+
# frozen_string_literal: true

module Replicate
  class Client
    # Methods for the (experimental) Dreambooth Training API
    module Training
      # Get a training
      # @see https://replicate.com/blog/dreambooth-api
      #
      # @param id [String] training id
      # @return [Replicate::Record::Training] wraps the dreambooth endpoint's response
      def retrieve_training(id)
        response = dreambooth_endpoint.get("trainings/#{id}")
        Replicate::Record::Training.new(self, response)
      end

      # Create a training
      # @see https://replicate.com/blog/dreambooth-api
      #
      # @param params [Hash] request body; +:webhook_completed+ falls back to the
      #   configured +webhook_url+ when missing or nil. The caller's hash is not modified.
      # @return [Replicate::Record::Training] wraps the dreambooth endpoint's response
      def create_training(params)
        # Work on a shallow copy so filling in the webhook is not a side effect on the caller.
        payload = params.dup
        payload[:webhook_completed] ||= webhook_url
        response = dreambooth_endpoint.post("trainings", payload)
        Replicate::Record::Training.new(self, response)
      end
    end
  end
end
|
@@ -0,0 +1,27 @@
|
|
1
|
+
# frozen_string_literal: true

module Replicate
  class Client
    # Methods for uploading Dreambooth training data
    module Upload
      # Create an upload slot for a training dataset named "data.zip".
      # @see https://replicate.com/blog/dreambooth-api
      #
      # @return [Replicate::Record::Upload] wraps the dreambooth endpoint's response
      def create_upload
        response = dreambooth_endpoint.post("upload/data.zip")
        Replicate::Record::Upload.new(self, response)
      end

      # Upload a local zip file to a previously created upload URL.
      # @see https://replicate.com/blog/dreambooth-api
      #
      # @param upload_endpoint_url [String] URL the file is PUT to (no API token is sent)
      # @param zip_path [String] path to the local zip archive
      # @return the result of the connection's +put+ call
      def update_upload(upload_endpoint_url, zip_path)
        endpoint = Replicate::Endpoint.new(endpoint_url: upload_endpoint_url, api_token: nil)
        body = Faraday::UploadIO.new(zip_path, 'application/zip')
        begin
          endpoint.agent.put do |req|
            req.headers["Content-Type"] = "application/zip"
            # The size is known up front, so send a plain Content-Length body.
            # HTTP/1.1 forbids combining Content-Length with
            # Transfer-Encoding: chunked, so no chunked header is set.
            req.headers["Content-Length"] = File.size(zip_path).to_s
            req.body = body
          end
        ensure
          # UploadIO opened zip_path; release the handle even if the request raises.
          body.close
        end
      end
    end
  end
end
|
data/lib/replicate/client.rb
CHANGED
@@ -1,18 +1,21 @@
|
|
1
1
|
# frozen_string_literal: true
|
2
2
|
|
3
3
|
require "replicate/configurable"
|
4
|
-
require "replicate/
|
4
|
+
require "replicate/endpoint"
|
5
5
|
|
6
6
|
require "replicate/client/model"
|
7
7
|
require "replicate/client/prediction"
|
8
|
+
require "replicate/client/upload"
|
9
|
+
require "replicate/client/training"
|
8
10
|
|
9
11
|
module Replicate
|
10
12
|
class Client
|
11
13
|
include Replicate::Configurable
|
12
|
-
include Replicate::Connection
|
13
14
|
|
14
15
|
include Replicate::Client::Model
|
15
16
|
include Replicate::Client::Prediction
|
17
|
+
include Replicate::Client::Upload
|
18
|
+
include Replicate::Client::Training
|
16
19
|
|
17
20
|
def initialize(options = {})
|
18
21
|
# Use options passed in, but fall back to module defaults
|
@@ -21,5 +24,13 @@ module Replicate
|
|
21
24
|
instance_variable_set(:"@#{key}", value)
|
22
25
|
end
|
23
26
|
end
|
27
|
+
|
28
|
+
# Memoized endpoint for the main Replicate API, authenticated with api_token.
# @return [Replicate::Endpoint]
def api_endpoint
  return @api_endpoint if @api_endpoint

  @api_endpoint = Replicate::Endpoint.new(endpoint_url: api_endpoint_url, api_token: api_token)
end

# Memoized endpoint for the experimental Dreambooth API, using the same api_token.
# @return [Replicate::Endpoint]
def dreambooth_endpoint
  @dreambooth_endpoint || (
    @dreambooth_endpoint = Replicate::Endpoint.new(endpoint_url: dreambooth_endpoint_url, api_token: api_token)
  )
end
|
24
35
|
end
|
25
36
|
end
|
@@ -3,13 +3,13 @@
|
|
3
3
|
module Replicate
|
4
4
|
module Configurable
|
5
5
|
attr_accessor :api_token, :webhook_url
|
6
|
-
attr_writer :
|
6
|
+
attr_writer :api_endpoint_url, :dreambooth_endpoint_url
|
7
7
|
|
8
8
|
class << self
  # List of configurable keys for {Replicate::Client}
  # @return [Array<Symbol>] option keys, memoized after the first call
  def keys
    @keys ||= %i[api_token api_endpoint_url dreambooth_endpoint_url webhook_url]
  end
end
|
15
15
|
|
@@ -19,8 +19,12 @@ module Replicate
|
|
19
19
|
end
|
20
20
|
|
21
21
|
# Endpoint URL readers: a URL assigned through the writer wins; otherwise the
# Replicate default is stored and returned.

# @return [String] base URL of the main Replicate API
def api_endpoint_url
  return @api_endpoint_url if @api_endpoint_url

  @api_endpoint_url = "https://api.replicate.com/v1"
end

# @return [String] base URL of the experimental Dreambooth API
def dreambooth_endpoint_url
  return @dreambooth_endpoint_url if @dreambooth_endpoint_url

  @dreambooth_endpoint_url = "https://dreambooth-api-experimental.replicate.com/v1"
end
|
25
29
|
|
26
30
|
private
|
@@ -3,16 +3,19 @@
|
|
3
3
|
require "faraday"
|
4
4
|
require "faraday/net_http"
|
5
5
|
require "faraday/retry"
|
6
|
+
require "faraday/multipart"
|
6
7
|
require "addressable/uri"
|
7
8
|
|
8
9
|
module Replicate
|
9
10
|
# Network layer for API clients.
|
10
|
-
|
11
|
-
|
12
|
-
USER_AGENT = "Datatrans Ruby Gem"
|
11
|
+
class Endpoint
|
12
|
+
attr_reader :endpoint_url, :api_token, :content_type
|
13
13
|
|
14
|
-
|
15
|
-
|
14
|
+
# @param endpoint_url [String] base URL for every request made through this endpoint
# @param api_token [String, nil] sent as a "Token" authorization header only when present
# @param content_type [String] request Content-Type; responses are JSON-parsed only
#   when this is 'application/json'
def initialize(endpoint_url:, api_token:, content_type: 'application/json')
  @endpoint_url, @api_token, @content_type = endpoint_url, api_token, content_type
end
|
16
19
|
|
17
20
|
# Make a HTTP GET request
|
18
21
|
#
|
@@ -72,11 +75,10 @@ module Replicate
|
|
72
75
|
#
|
73
76
|
# @return [Faraday::Connection] memoized connection rooted at endpoint_url
def agent
  @agent ||= Faraday.new(url: endpoint_url) do |conn|
    conn.request :retry
    # Endpoints built with api_token: nil (e.g. upload URLs) send no auth header.
    conn.request :authorization, 'Token', api_token if api_token
    conn.headers["Content-Type"] = content_type

    conn.adapter :net_http
  end
end
|
@@ -89,12 +91,6 @@ module Replicate
|
|
89
91
|
@last_response if defined? @last_response
|
90
92
|
end
|
91
93
|
|
92
|
-
protected
|
93
|
-
|
94
|
-
def endpoint
|
95
|
-
api_endpoint
|
96
|
-
end
|
97
|
-
|
98
94
|
private
|
99
95
|
|
100
96
|
def request(method, path, data)
|
@@ -103,7 +99,12 @@ module Replicate
|
|
103
99
|
when 400
|
104
100
|
raise Error, "#{@last_response.status} #{@last_response.reason_phrase}: #{JSON.parse(@last_response.body)}"
|
105
101
|
else
|
106
|
-
|
102
|
+
case content_type
|
103
|
+
when 'application/json'
|
104
|
+
JSON.parse(@last_response.body)
|
105
|
+
else
|
106
|
+
@last_response.body
|
107
|
+
end
|
107
108
|
end
|
108
109
|
end
|
109
110
|
end
|
@@ -19,6 +19,26 @@ module Replicate
|
|
19
19
|
false
|
20
20
|
end
|
21
21
|
end
|
22
|
+
|
23
|
+
# Status predicates: each compares the record's +status+ value with one of the
# status strings "starting", "processing", "succeeded", "failed", "canceled".

# @return [Boolean] whether status is "starting"
def starting?
  status == "starting"
end

# @return [Boolean] whether status is "processing"
def processing?
  status == "processing"
end

# @return [Boolean] whether status is "succeeded"
def succeeded?
  status == "succeeded"
end

# @return [Boolean] whether status is "failed"
def failed?
  status == "failed"
end

# @return [Boolean] whether status is "canceled"
def canceled?
  status == "canceled"
end
|
22
42
|
end
|
23
43
|
end
|
24
44
|
end
|
data/lib/replicate/version.rb
CHANGED
data/lib/replicate.rb
CHANGED
@@ -7,6 +7,8 @@ require "replicate/record/base"
|
|
7
7
|
require "replicate/record/model"
|
8
8
|
require "replicate/record/model_version"
|
9
9
|
require "replicate/record/prediction"
|
10
|
+
require "replicate/record/upload"
|
11
|
+
require "replicate/record/training"
|
10
12
|
|
11
13
|
module Replicate
|
12
14
|
class Error < StandardError; end
|
metadata
CHANGED
@@ -1,14 +1,14 @@
|
|
1
1
|
--- !ruby/object:Gem::Specification
|
2
2
|
name: replicate-ruby
|
3
3
|
version: !ruby/object:Gem::Version
|
4
|
-
version: 0.
|
4
|
+
version: 0.2.0
|
5
5
|
platform: ruby
|
6
6
|
authors:
|
7
7
|
- Dreaming Tulpa
|
8
8
|
autorequire:
|
9
9
|
bindir: exe
|
10
10
|
cert_chain: []
|
11
|
-
date:
|
11
|
+
date: 2023-03-08 00:00:00.000000000 Z
|
12
12
|
dependencies:
|
13
13
|
- !ruby/object:Gem::Dependency
|
14
14
|
name: faraday
|
@@ -38,6 +38,20 @@ dependencies:
|
|
38
38
|
- - ">="
|
39
39
|
- !ruby/object:Gem::Version
|
40
40
|
version: '0'
|
41
|
+
- !ruby/object:Gem::Dependency
|
42
|
+
name: faraday-multipart
|
43
|
+
requirement: !ruby/object:Gem::Requirement
|
44
|
+
requirements:
|
45
|
+
- - ">="
|
46
|
+
- !ruby/object:Gem::Version
|
47
|
+
version: '0'
|
48
|
+
type: :runtime
|
49
|
+
prerelease: false
|
50
|
+
version_requirements: !ruby/object:Gem::Requirement
|
51
|
+
requirements:
|
52
|
+
- - ">="
|
53
|
+
- !ruby/object:Gem::Version
|
54
|
+
version: '0'
|
41
55
|
- !ruby/object:Gem::Dependency
|
42
56
|
name: addressable
|
43
57
|
requirement: !ruby/object:Gem::Requirement
|
@@ -80,6 +94,20 @@ dependencies:
|
|
80
94
|
- - ">="
|
81
95
|
- !ruby/object:Gem::Version
|
82
96
|
version: '0'
|
97
|
+
- !ruby/object:Gem::Dependency
|
98
|
+
name: byebug
|
99
|
+
requirement: !ruby/object:Gem::Requirement
|
100
|
+
requirements:
|
101
|
+
- - ">="
|
102
|
+
- !ruby/object:Gem::Version
|
103
|
+
version: '0'
|
104
|
+
type: :development
|
105
|
+
prerelease: false
|
106
|
+
version_requirements: !ruby/object:Gem::Requirement
|
107
|
+
requirements:
|
108
|
+
- - ">="
|
109
|
+
- !ruby/object:Gem::Version
|
110
|
+
version: '0'
|
83
111
|
description:
|
84
112
|
email:
|
85
113
|
- hey@dreamingtulpa.com
|
@@ -93,16 +121,21 @@ files:
|
|
93
121
|
- Gemfile.lock
|
94
122
|
- README.md
|
95
123
|
- Rakefile
|
124
|
+
- data.zip
|
96
125
|
- lib/replicate.rb
|
97
126
|
- lib/replicate/client.rb
|
98
127
|
- lib/replicate/client/model.rb
|
99
128
|
- lib/replicate/client/prediction.rb
|
129
|
+
- lib/replicate/client/training.rb
|
130
|
+
- lib/replicate/client/upload.rb
|
100
131
|
- lib/replicate/configurable.rb
|
101
|
-
- lib/replicate/
|
132
|
+
- lib/replicate/endpoint.rb
|
102
133
|
- lib/replicate/record/base.rb
|
103
134
|
- lib/replicate/record/model.rb
|
104
135
|
- lib/replicate/record/model_version.rb
|
105
136
|
- lib/replicate/record/prediction.rb
|
137
|
+
- lib/replicate/record/training.rb
|
138
|
+
- lib/replicate/record/upload.rb
|
106
139
|
- lib/replicate/version.rb
|
107
140
|
- sig/replicate/ruby.rbs
|
108
141
|
homepage: https://github.com/dreamingtulpa/replicate-ruby
|
@@ -126,7 +159,7 @@ required_rubygems_version: !ruby/object:Gem::Requirement
|
|
126
159
|
- !ruby/object:Gem::Version
|
127
160
|
version: '0'
|
128
161
|
requirements: []
|
129
|
-
rubygems_version: 3.
|
162
|
+
rubygems_version: 3.4.6
|
130
163
|
signing_key:
|
131
164
|
specification_version: 4
|
132
165
|
summary: Ruby client for Replicate
|