replicate-client 0.1.0
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- checksums.yaml +7 -0
- data/.rubocop.yml +42 -0
- data/.ruby-version +1 -0
- data/Gemfile +9 -0
- data/Gemfile.lock +61 -0
- data/README.md +202 -0
- data/Rakefile +16 -0
- data/lib/replicate-client/client.rb +116 -0
- data/lib/replicate-client/deployment.rb +215 -0
- data/lib/replicate-client/hardware.rb +53 -0
- data/lib/replicate-client/model/version.rb +132 -0
- data/lib/replicate-client/model.rb +335 -0
- data/lib/replicate-client/prediction.rb +287 -0
- data/lib/replicate-client/training.rb +250 -0
- data/lib/replicate-client/version.rb +5 -0
- data/lib/replicate-client.rb +78 -0
- metadata +76 -0
checksums.yaml
ADDED
@@ -0,0 +1,7 @@
|
|
1
|
+
---
|
2
|
+
SHA256:
|
3
|
+
metadata.gz: 7d40274f18b3fca8f69b7a0e7614765150352ec1d828dedc6d4046605339833c
|
4
|
+
data.tar.gz: 991a137f2a20658ef3c85d95fbc230ac4528d63b29d1c01d6efdd07f1dcd1675
|
5
|
+
SHA512:
|
6
|
+
metadata.gz: e89be0757da9ea93891a36126fefc72b5d26a5d2942477c7dd0c5a28978d31faf202b84743d6196c8d597e0703527e68538cfe483afc4d0c1526fb714e0c925d
|
7
|
+
data.tar.gz: d30afd20e196df7e169053d284ad9ab1d435876274763926e518e78661f2b3a156eaf9ffbc4b9ccd84955c7b22bc3abc979ca647f4c91ddf9404311cf0e3aad5
|
data/.rubocop.yml
ADDED
@@ -0,0 +1,42 @@
|
|
1
|
+
AllCops:
|
2
|
+
TargetRubyVersion: 3.3.9
|
3
|
+
NewCops: enable
|
4
|
+
SuggestExtensions: false
|
5
|
+
|
6
|
+
Style/Documentation:
|
7
|
+
Enabled: false
|
8
|
+
|
9
|
+
Bundler/OrderedGems:
|
10
|
+
Enabled: false
|
11
|
+
|
12
|
+
Style/StringLiterals:
|
13
|
+
EnforcedStyle: double_quotes
|
14
|
+
|
15
|
+
Style/FrozenStringLiteralComment:
|
16
|
+
SafeAutoCorrect: true
|
17
|
+
EnforcedStyle: always_true
|
18
|
+
|
19
|
+
Metrics/MethodLength:
|
20
|
+
Enabled: false
|
21
|
+
|
22
|
+
Style/AccessorGrouping:
|
23
|
+
Enabled: false
|
24
|
+
|
25
|
+
Metrics/AbcSize:
|
26
|
+
Enabled: false
|
27
|
+
|
28
|
+
Metrics/CyclomaticComplexity:
|
29
|
+
Enabled: false
|
30
|
+
|
31
|
+
Metrics/ParameterLists:
|
32
|
+
Enabled: false
|
33
|
+
|
34
|
+
Metrics/ClassLength:
|
35
|
+
Enabled: false
|
36
|
+
|
37
|
+
Style/HashSyntax:
|
38
|
+
Enabled: false
|
39
|
+
|
40
|
+
Naming/FileName:
|
41
|
+
Exclude:
|
42
|
+
- "lib/replicate-client.rb"
|
data/.ruby-version
ADDED
@@ -0,0 +1 @@
|
|
1
|
+
3.3.0
|
data/Gemfile
ADDED
data/Gemfile.lock
ADDED
@@ -0,0 +1,61 @@
|
|
1
|
+
PATH
|
2
|
+
remote: .
|
3
|
+
specs:
|
4
|
+
replicate-client (0.1.0)
|
5
|
+
faraday (>= 1)
|
6
|
+
|
7
|
+
GEM
|
8
|
+
remote: https://rubygems.org/
|
9
|
+
specs:
|
10
|
+
ast (2.4.2)
|
11
|
+
faraday (2.10.1)
|
12
|
+
faraday-net_http (>= 2.0, < 3.2)
|
13
|
+
logger
|
14
|
+
faraday-net_http (3.1.1)
|
15
|
+
net-http
|
16
|
+
json (2.7.2)
|
17
|
+
language_server-protocol (3.17.0.3)
|
18
|
+
logger (1.6.0)
|
19
|
+
minitest (5.23.1)
|
20
|
+
net-http (0.4.1)
|
21
|
+
uri
|
22
|
+
parallel (1.24.0)
|
23
|
+
parser (3.3.1.0)
|
24
|
+
ast (~> 2.4.1)
|
25
|
+
racc
|
26
|
+
racc (1.8.0)
|
27
|
+
rainbow (3.1.1)
|
28
|
+
rake (13.2.1)
|
29
|
+
regexp_parser (2.9.2)
|
30
|
+
rexml (3.2.8)
|
31
|
+
strscan (>= 3.0.9)
|
32
|
+
rubocop (1.64.0)
|
33
|
+
json (~> 2.3)
|
34
|
+
language_server-protocol (>= 3.17.0)
|
35
|
+
parallel (~> 1.10)
|
36
|
+
parser (>= 3.3.0.2)
|
37
|
+
rainbow (>= 2.2.2, < 4.0)
|
38
|
+
regexp_parser (>= 1.8, < 3.0)
|
39
|
+
rexml (>= 3.2.5, < 4.0)
|
40
|
+
rubocop-ast (>= 1.31.1, < 2.0)
|
41
|
+
ruby-progressbar (~> 1.7)
|
42
|
+
unicode-display_width (>= 2.4.0, < 3.0)
|
43
|
+
rubocop-ast (1.31.3)
|
44
|
+
parser (>= 3.3.1.0)
|
45
|
+
ruby-progressbar (1.13.0)
|
46
|
+
strscan (3.1.0)
|
47
|
+
unicode-display_width (2.5.0)
|
48
|
+
uri (0.13.0)
|
49
|
+
|
50
|
+
PLATFORMS
|
51
|
+
ruby
|
52
|
+
x86_64-darwin-23
|
53
|
+
|
54
|
+
DEPENDENCIES
|
55
|
+
minitest (~> 5.0)
|
56
|
+
rake (~> 13.0)
|
57
|
+
replicate-client!
|
58
|
+
rubocop (~> 1.21)
|
59
|
+
|
60
|
+
BUNDLED WITH
|
61
|
+
2.5.7
|
data/README.md
ADDED
@@ -0,0 +1,202 @@
|
|
1
|
+
# ReplicateClient
|
2
|
+
|
3
|
+
**🚧 This gem is still under development 🚧**
|
4
|
+
|
5
|
+
## Installation
|
6
|
+
|
7
|
+
Install the gem and add it to the application's Gemfile by executing:
|
8
|
+
|
9
|
+
$ bundle add replicate-client
|
10
|
+
|
11
|
+
If bundler is not being used to manage dependencies, install the gem by executing:
|
12
|
+
|
13
|
+
$ gem install replicate-client
|
14
|
+
|
15
|
+
## Usage
|
16
|
+
|
17
|
+
### Configuration
|
18
|
+
|
19
|
+
You can configure the gem by calling the `.configure` method on the `ReplicateClient` module. The method accepts a block with the configuration options.
|
20
|
+
|
21
|
+
```ruby
|
22
|
+
ReplicateClient.configure do |config|
|
23
|
+
config.access_token = ENV["REPLICATE_ACCESS_TOKEN"] # Required
|
24
|
+
config.uri_base = "https://api.replicate.com/v1" # Optional
|
25
|
+
config.request_timeout = 5 # Optional (default: 120)
|
26
|
+
config.webhook_url = "https://example.com/replicate/webhook" # Optional
|
27
|
+
end
|
28
|
+
```
|
29
|
+
|
30
|
+
### Get a model
|
31
|
+
|
32
|
+
```ruby
|
33
|
+
model = ReplicateClient::Model.find("stability-ai/sdxl")
|
34
|
+
model = ReplicateClient::Model.find_by(owner: "stability-ai", name: "sdxl")
|
35
|
+
model = ReplicateClient::Model.find_by(owner: "stability-ai", name: "sdxl", version_id: "7762fd07cf82c948538e41f63f77d685e02b063e37e496e96eefd46c929f9bdc")
|
36
|
+
```
|
37
|
+
|
38
|
+
### Get a model version
|
39
|
+
|
40
|
+
```ruby
|
41
|
+
version = ReplicateClient::Model::Version.find_by(owner: "stability-ai", name: "sdxl", version_id: "7762fd07cf82c948538e41f63f77d685e02b063e37e496e96eefd46c929f9bdc")
|
42
|
+
|
43
|
+
model = ReplicateClient::Model.find_by(owner: "stability-ai", name: "sdxl", version_id: "7762fd07cf82c948538e41f63f77d685e02b063e37e496e96eefd46c929f9bdc")
|
44
|
+
version = model.version
|
45
|
+
```
|
46
|
+
|
47
|
+
### Get the latest version of a model
|
48
|
+
|
49
|
+
```ruby
|
50
|
+
model = ReplicateClient::Model.find("stability-ai/sdxl")
|
51
|
+
version = model.latest_version
|
52
|
+
```
|
53
|
+
|
54
|
+
### Get list of model versions
|
55
|
+
|
56
|
+
```ruby
|
57
|
+
model = ReplicateClient::Model.find("stability-ai/sdxl")
|
58
|
+
versions = model.versions
|
59
|
+
```
|
60
|
+
|
61
|
+
### Paginate through all models
|
62
|
+
|
63
|
+
```ruby
|
64
|
+
ReplicateClient::Model.auto_paging_each do |model|
|
65
|
+
puts model.name
|
66
|
+
end
|
67
|
+
```
|
68
|
+
|
69
|
+
### Create a prediction
|
70
|
+
|
71
|
+
```ruby
|
72
|
+
version = ReplicateClient::Model::Version.find_by(owner: "stability-ai", name: "sdxl", version_id: "7762fd07cf82c948538e41f63f77d685e02b063e37e496e96eefd46c929f9bdc")
|
73
|
+
prediction = version.create_prediction!(input: { my: "input" })
|
74
|
+
|
75
|
+
prediction = version.create_prediction!(input: { my: "input" }, webhook_url: "https://example.com/replicate/webhook")
|
76
|
+
|
77
|
+
prediction = version.create_prediction!(input: { my: "input" }, webhook_url: "https://example.com/replicate/webhook", webhook_events_filter: ["start", "completed"])
|
78
|
+
|
79
|
+
prediction = ReplicateClient::Prediction.create!(version: version, input: { my: "input" })
|
80
|
+
|
81
|
+
prediction = ReplicateClient::Prediction.create!(version: "7762fd07cf82c948538e41f63f77d685e02b063e37e496e96eefd46c929f9bdc", input: { my: "input" })
|
82
|
+
|
83
|
+
deployment = ReplicateClient::Deployment.find("851-labs/my-deployment")
|
84
|
+
prediction = deployment.create_prediction!(input: { my: "input" })
|
85
|
+
|
86
|
+
model = ReplicateClient::Model.find("stability-ai/sdxl")
|
87
|
+
prediction = model.create_prediction!(input: { my: "input" })
|
88
|
+
|
89
|
+
prediction = ReplicateClient::Prediction.create_for_official_model!(model: "stability-ai/sdxl", input: { my: "input" })
|
90
|
+
```
|
91
|
+
|
92
|
+
### Get a prediction
|
93
|
+
|
94
|
+
```ruby
|
95
|
+
prediction = ReplicateClient::Prediction.find("7762fd07cf82c948538e41f63f77d685e02b063e37e496e96eefd46c929f9bdc")
|
96
|
+
|
97
|
+
prediction = ReplicateClient::Prediction.find_by(id: "7762fd07cf82c948538e41f63f77d685e02b063e37e496e96eefd46c929f9bdc")
|
98
|
+
|
99
|
+
prediction = ReplicateClient::Prediction.find_by!(id: "7762fd07cf82c948538e41f63f77d685e02b063e37e496e96eefd46c929f9bdc")
|
100
|
+
```
|
101
|
+
|
102
|
+
### Reload a resource
|
103
|
+
|
104
|
+
```ruby
|
105
|
+
model = ReplicateClient::Model.find("stability-ai/sdxl")
|
106
|
+
model.reload!
|
107
|
+
|
108
|
+
version = ReplicateClient::Model::Version.find_by(owner: "stability-ai", name: "sdxl", version_id: "7762fd07cf82c948538e41f63f77d685e02b063e37e496e96eefd46c929f9bdc")
|
109
|
+
version.reload!
|
110
|
+
|
111
|
+
prediction = ReplicateClient::Prediction.find("7762fd07cf82c948538e41f63f77d685e02b063e37e496e96eefd46c929f9bdc")
|
112
|
+
prediction.reload!
|
113
|
+
```
|
114
|
+
|
115
|
+
### Delete a resource
|
116
|
+
|
117
|
+
```ruby
|
118
|
+
model = ReplicateClient::Model.find("stability-ai/sdxl")
|
119
|
+
model.delete!
|
120
|
+
```
|
121
|
+
|
122
|
+
### Get available hardware
|
123
|
+
|
124
|
+
```ruby
|
125
|
+
hardware = ReplicateClient::Hardware.all
|
126
|
+
```
|
127
|
+
|
128
|
+
### Create a deployment
|
129
|
+
|
130
|
+
```ruby
|
131
|
+
deployment = ReplicateClient::Deployment.create!(name: "851-deployment", model: "stability-ai/sdxl", hardware: "gpu-t4", min_instances: 1, max_instances: 1)
|
132
|
+
|
133
|
+
model = ReplicateClient::Model.find("stability-ai/sdxl")
|
134
|
+
deployment = ReplicateClient::Deployment.create!(name: "851-deployment", model: model, hardware: "gpu-t4", min_instances: 1, max_instances: 1)
|
135
|
+
|
136
|
+
hardware = ReplicateClient::Hardware.all.first
|
137
|
+
deployment = ReplicateClient::Deployment.create!(name: "851-deployment", model: model, hardware: hardware, min_instances: 1, max_instances: 1)
|
138
|
+
```
|
139
|
+
|
140
|
+
### Get a deployment
|
141
|
+
|
142
|
+
```ruby
|
143
|
+
deployment = ReplicateClient::Deployment.find("851-labs/my-deployment")
|
144
|
+
|
145
|
+
deployment = ReplicateClient::Deployment.find_by(owner: "851-labs", name: "my-deployment")
|
146
|
+
|
147
|
+
deployment = ReplicateClient::Deployment.find_by!(owner: "851-labs", name: "my-deployment")
|
148
|
+
```
|
149
|
+
|
150
|
+
### Paginate through all deployments
|
151
|
+
|
152
|
+
```ruby
|
153
|
+
ReplicateClient::Deployment.auto_paging_each do |deployment|
|
154
|
+
puts deployment.name
|
155
|
+
end
|
156
|
+
```
|
157
|
+
|
158
|
+
### Create a training
|
159
|
+
|
160
|
+
```ruby
|
161
|
+
training = ReplicateClient::Training.create!(owner: "851-labs", name: "my-training", version: "7762fd07cf82c948538e41f63f77d685e02b063e37e496e96eefd46c929f9bdc", destination: "851-labs/my-new-model", input: {})
|
162
|
+
|
163
|
+
sdxl = ReplicateClient::Model::Version.find_by(owner: "stability-ai", name: "sdxl", version_id: "7762fd07cf82c948538e41f63f77d685e02b063e37e496e96eefd46c929f9bdc")
|
164
|
+
destination_model = ReplicateClient::Model.find("851-labs/my-new-model")
|
165
|
+
training = ReplicateClient::Training.create!(owner: "851-labs", name: "my-training", version: sdxl, destination: destination_model, input: {})
|
166
|
+
|
167
|
+
sdxl = ReplicateClient::Model.find("stability-ai/sdxl")
|
168
|
+
destination_model = ReplicateClient::Model.find("851-labs/my-new-model")
|
169
|
+
training = ReplicateClient::Training.create_for_model!(model: sdxl, destination: destination_model, input: {})
|
170
|
+
```
|
171
|
+
|
172
|
+
### Get a training
|
173
|
+
|
174
|
+
```ruby
|
175
|
+
training = ReplicateClient::Training.find("b3kgfb2y9nrm00chdnkaam2dvz")
|
176
|
+
```
|
177
|
+
|
178
|
+
### Paginate through all trainings
|
179
|
+
|
180
|
+
```ruby
|
181
|
+
ReplicateClient::Training.auto_paging_each do |training|
|
182
|
+
puts training.name
|
183
|
+
end
|
184
|
+
```
|
185
|
+
|
186
|
+
### Cancel a training
|
187
|
+
|
188
|
+
```ruby
|
189
|
+
training = ReplicateClient::Training.find("b3kgfb2y9nrm00chdnkaam2dvz")
|
190
|
+
training.cancel!
|
191
|
+
```
|
192
|
+
|
193
|
+
## Warning
|
194
|
+
|
195
|
+
Official models do not have versions, so their version ID will be nil.
|
196
|
+
|
197
|
+
## Development
|
198
|
+
|
199
|
+
After checking out the repo, run `bin/setup` to install dependencies. Then, run `rake test` to run the tests. You can also run `bin/console` for an interactive prompt that will allow you to experiment.
|
200
|
+
|
201
|
+
To install this gem onto your local machine, run `bundle exec rake install`. To release a new version, update the version number in `version.rb`, and then run `bundle exec rake release`, which will create a git tag for the version, push git commits and the created tag, and push the `.gem` file to [rubygems.org](https://rubygems.org).
|
202
|
+
|
data/Rakefile
ADDED
@@ -0,0 +1,16 @@
|
|
1
|
+
# frozen_string_literal: true

# Development tasks: `rake test` runs the Minitest suite, `rake rubocop`
# lints the code, and the default task runs both in that order.
require "bundler/gem_tasks"
require "rake/testtask"
require "rubocop/rake_task"

Rake::TestTask.new(:test) do |task|
  # Put both the test helpers and the library on the load path.
  task.libs.push("test", "lib")
  task.test_files = FileList["test/**/test_*.rb"]
end

RuboCop::RakeTask.new

task default: %i[test rubocop]
|
@@ -0,0 +1,116 @@
|
|
1
|
+
# frozen_string_literal: true

module ReplicateClient
  # Thin HTTP wrapper around the Replicate REST API.
  #
  # Every request is authenticated with the configured access token and sent to
  # `configuration.uri_base + path`. Successful JSON responses are parsed into a
  # Hash; unsuccessful responses are converted into ReplicateClient errors by
  # {#handle_error}.
  class Client
    # Initialize the client.
    #
    # @param configuration [ReplicateClient::Configuration] The configuration for the client.
    #
    # @return [ReplicateClient::Client]
    def initialize(configuration = ReplicateClient.configuration)
      @configuration = configuration
    end

    # Make a POST request to the API.
    #
    # @param path [String] The path to the API endpoint.
    # @param payload [Hash] The payload to send to the API. Keys with nil values are dropped.
    #
    # @return [Hash] The parsed JSON response from the API.
    # @raise [ReplicateClient::UnauthorizedError, ReplicateClient::ForbiddenError,
    #   ReplicateClient::NotFoundError, ReplicateClient::ServerError] On a non-success status.
    def post(path, payload)
      response = connection.post(build_url(path)) do |request|
        apply_headers(request)
        request.body = payload.compact.to_json
      end

      handle_error(response) unless response.success?

      JSON.parse(response.body)
    end

    # Make a GET request to the API.
    #
    # @param path [String] The path to the API endpoint (may include a query string).
    #
    # @return [Hash] The parsed JSON response from the API.
    # @raise [ReplicateClient::UnauthorizedError, ReplicateClient::ForbiddenError,
    #   ReplicateClient::NotFoundError, ReplicateClient::ServerError] On a non-success status.
    def get(path)
      response = connection.get(build_url(path)) do |request|
        apply_headers(request)
      end

      handle_error(response) unless response.success?

      JSON.parse(response.body)
    end

    # Make a DELETE request to the API.
    #
    # @param path [String] The path to the API endpoint.
    #
    # @return [nil] The response body is not parsed.
    # @raise [ReplicateClient::UnauthorizedError, ReplicateClient::ForbiddenError,
    #   ReplicateClient::NotFoundError, ReplicateClient::ServerError] On a non-success status.
    def delete(path)
      response = connection.delete(build_url(path)) do |request|
        apply_headers(request)
      end

      handle_error(response) unless response.success?
    end

    # Make a PATCH request to the API.
    #
    # @param path [String] The path to the API endpoint.
    # @param payload [Hash] The payload to send to the API. Keys with nil values are dropped.
    #
    # @return [Hash] The parsed JSON response from the API.
    # @raise [ReplicateClient::UnauthorizedError, ReplicateClient::ForbiddenError,
    #   ReplicateClient::NotFoundError, ReplicateClient::ServerError] On a non-success status.
    def patch(path, payload)
      response = connection.patch(build_url(path)) do |request|
        apply_headers(request)
        request.body = payload.compact.to_json
      end

      handle_error(response) unless response.success?

      JSON.parse(response.body)
    end

    # Translate an unsuccessful response into a ReplicateClient error.
    #
    # The raw response body is used as the exception message.
    #
    # @param response [Faraday::Response] The response from the API.
    #
    # @return [void] Never returns normally.
    # @raise [ReplicateClient::UnauthorizedError] For 401.
    # @raise [ReplicateClient::ForbiddenError] For 403.
    # @raise [ReplicateClient::NotFoundError] For 404.
    # @raise [ReplicateClient::ServerError] For any other status.
    def handle_error(response)
      case response.status
      when 401
        raise UnauthorizedError, response.body
      when 403
        raise ForbiddenError, response.body
      when 404
        raise NotFoundError, response.body
      else
        raise ServerError, response.body
      end
    end

    private

    # Set the authentication and content negotiation headers shared by every request.
    #
    # @param request [Faraday::Request] The request being built.
    #
    # @return [void]
    def apply_headers(request)
      request.headers["Authorization"] = "Bearer #{@configuration.access_token}"
      request.headers["Content-Type"] = "application/json"
      request.headers["Accept"] = "application/json"
    end

    # Build the URL for the API.
    #
    # @param path [String] The path to the API endpoint.
    #
    # @return [String]
    def build_url(path)
      "#{@configuration.uri_base}#{path}"
    end

    # Create a connection to the API.
    #
    # Both the read and the open timeouts use the configured request timeout.
    #
    # @return [Faraday::Connection]
    def connection
      Faraday.new do |faraday|
        faraday.request :url_encoded
        faraday.options.timeout = @configuration.request_timeout
        faraday.options.open_timeout = @configuration.request_timeout
      end
    end
  end
end
|
@@ -0,0 +1,215 @@
|
|
1
|
+
# frozen_string_literal: true

module ReplicateClient
  # A Replicate deployment, identified by its owner and name.
  #
  # Class methods wrap the `/deployments` endpoints (list, create, fetch,
  # delete). Instances hold the attributes returned by the API and expose
  # update, reload, destroy and prediction helpers.
  class Deployment
    INDEX_PATH = "/deployments"

    class << self
      # Iterate over every deployment, following pagination cursors.
      #
      # When called without a block, an Enumerator is returned so callers can
      # chain Enumerable methods (e.g. `auto_paging_each.map(&:name)`).
      #
      # @yield [ReplicateClient::Deployment] Yields a deployment.
      #
      # @return [void, Enumerator]
      def auto_paging_each(&block)
        return enum_for(:auto_paging_each) unless block

        cursor = nil

        loop do
          # The cursor was decoded out of the "next" URL, so it must be
          # re-encoded here; otherwise characters such as "+", "=" or "/"
          # in the cursor would be misread by the server.
          url_params = cursor ? "?#{URI.encode_www_form(cursor: cursor)}" : ""
          attributes = ReplicateClient.client.get("#{INDEX_PATH}#{url_params}")

          attributes["results"].map { |deployment| new(deployment) }.each(&block)

          cursor = next_cursor(attributes["next"])
          break if cursor.nil?
        end
      end

      # Create a new deployment.
      #
      # The version is taken from `version_id` when given. Otherwise it is the
      # `version_id` of a Model instance, or the latest version of the model
      # looked up by name.
      #
      # @param name [String] The name of the deployment.
      # @param model [ReplicateClient::Model, String] The model, or its identifier in "owner/name" format.
      # @param version_id [String, nil] The version ID of the model.
      # @param hardware [ReplicateClient::Hardware, String] The hardware, or its SKU.
      # @param min_instances [Integer] The minimum number of instances.
      # @param max_instances [Integer] The maximum number of instances.
      #
      # @return [ReplicateClient::Deployment]
      def create!(name:, model:, hardware:, min_instances:, max_instances:, version_id: nil)
        model_full_name = model.is_a?(Model) ? model.full_name : model
        hardware_sku = hardware.is_a?(Hardware) ? hardware.sku : hardware
        version = if version_id
                    version_id
                  elsif model.is_a?(Model)
                    model.version_id
                  else
                    Model.find(model).latest_version.id
                  end

        body = {
          name: name,
          model: model_full_name,
          version: version,
          hardware: hardware_sku,
          min_instances: min_instances,
          max_instances: max_instances
        }

        attributes = ReplicateClient.client.post(INDEX_PATH, body)
        new(attributes)
      end

      # Find a deployment by its full name.
      #
      # @param full_name [String] The full name of the deployment in "owner/name" format.
      #
      # @return [ReplicateClient::Deployment]
      # @raise [ReplicateClient::NotFoundError] If the deployment does not exist.
      def find(full_name)
        find_by!(**parse_full_name(full_name))
      end

      # Find a deployment by owner and name.
      #
      # @param owner [String] The owner of the deployment.
      # @param name [String] The name of the deployment.
      #
      # @return [ReplicateClient::Deployment]
      # @raise [ReplicateClient::NotFoundError] If the deployment does not exist.
      def find_by!(owner:, name:)
        path = build_path(owner: owner, name: name)
        attributes = ReplicateClient.client.get(path)
        new(attributes)
      end

      # Find a deployment by owner and name, returning nil when it does not exist.
      #
      # @param owner [String] The owner of the deployment.
      # @param name [String] The name of the deployment.
      #
      # @return [ReplicateClient::Deployment, nil]
      def find_by(owner:, name:)
        find_by!(owner: owner, name: name)
      rescue ReplicateClient::NotFoundError
        nil
      end

      # Delete a deployment.
      #
      # @param owner [String] The owner of the deployment.
      # @param name [String] The name of the deployment.
      #
      # @return [void]
      def destroy!(owner:, name:)
        path = build_path(owner: owner, name: name)
        ReplicateClient.client.delete(path)
      end

      # Build the path for a specific deployment.
      #
      # @param owner [String] The owner of the deployment.
      # @param name [String] The name of the deployment.
      #
      # @return [String]
      def build_path(owner:, name:)
        "#{INDEX_PATH}/#{owner}/#{name}"
      end

      # Parse the full name for a deployment.
      #
      # @param full_name [String] The full name of the deployment in "owner/name" format.
      #
      # @return [Hash] `{ owner:, name: }`; `name` is nil when there is no "/".
      def parse_full_name(full_name)
        parts = full_name.split("/")
        { owner: parts[0], name: parts[1] }
      end

      private

      # Extract the pagination cursor from a "next" page URL.
      #
      # @param next_url [String, nil] The "next" URL returned by the list endpoint.
      #
      # @return [String, nil] The decoded cursor, or nil when there is no further page.
      def next_cursor(next_url)
        return nil if next_url.nil? || next_url.empty?

        query = URI.parse(next_url).query
        return nil if query.nil?

        URI.decode_www_form(query).to_h["cursor"]
      end
    end

    # Attributes for deployment.
    attr_accessor :owner, :name, :current_release

    # Initialize a new deployment instance.
    #
    # @param attributes [Hash] The attributes of the deployment.
    #
    # @return [ReplicateClient::Deployment]
    def initialize(attributes)
      reset_attributes(attributes)
    end

    # Destroy the deployment.
    #
    # @return [void]
    def destroy!
      self.class.destroy!(owner: owner, name: name)
    end

    # Update the deployment. Only the arguments that are not nil are sent.
    #
    # @param hardware [ReplicateClient::Hardware, String, nil] The hardware, or its SKU.
    # @param min_instances [Integer, nil] The minimum number of instances.
    # @param max_instances [Integer, nil] The maximum number of instances.
    # @param version [ReplicateClient::Model::Version, String, nil] The model version, or its ID.
    #
    # @return [void]
    def update!(hardware: nil, min_instances: nil, max_instances: nil, version: nil)
      body = {
        hardware: hardware.is_a?(Hardware) ? hardware.sku : hardware,
        min_instances: min_instances,
        max_instances: max_instances,
        version: version.is_a?(Model::Version) ? version.id : version
      }.compact

      attributes = ReplicateClient.client.patch(path, body)
      reset_attributes(attributes)
    end

    # Reload the deployment.
    #
    # @return [void]
    def reload!
      attributes = ReplicateClient.client.get(path)
      reset_attributes(attributes)
    end

    # Build the path for the deployment.
    #
    # @return [String]
    def path
      self.class.build_path(owner: owner, name: name)
    end

    # Create a prediction for the deployment.
    #
    # The input may be passed either positionally (`create_prediction!(hash)`)
    # or as a keyword (`create_prediction!(input: hash)`). If both are given,
    # the keyword wins.
    #
    # @param positional_input [Hash, nil] The input for the prediction (positional form).
    # @param input [Hash, nil] The input for the prediction (keyword form).
    # @param webhook_url [String, nil] The URL to send webhook events to.
    # @param webhook_events_filter [Array<String>, nil] The events to send to the webhook.
    #
    # @return [ReplicateClient::Prediction]
    # @raise [ArgumentError] If no input is given.
    def create_prediction!(positional_input = nil, input: nil, webhook_url: nil, webhook_events_filter: nil)
      input = positional_input if input.nil?
      raise ArgumentError, "input is required" if input.nil?

      Prediction.create_for_deployment!(
        deployment: self,
        input: input,
        webhook_url: webhook_url,
        webhook_events_filter: webhook_events_filter
      )
    end

    private

    # Set the attributes of the deployment.
    #
    # @param attributes [Hash] The attributes of the deployment.
    #
    # @return [void]
    def reset_attributes(attributes)
      @owner = attributes["owner"]
      @name = attributes["name"]
      @current_release = attributes["current_release"]
    end
  end
end
|