replicate-client 0.1.0
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- checksums.yaml +7 -0
- data/.rubocop.yml +42 -0
- data/.ruby-version +1 -0
- data/Gemfile +9 -0
- data/Gemfile.lock +61 -0
- data/README.md +202 -0
- data/Rakefile +16 -0
- data/lib/replicate-client/client.rb +116 -0
- data/lib/replicate-client/deployment.rb +215 -0
- data/lib/replicate-client/hardware.rb +53 -0
- data/lib/replicate-client/model/version.rb +132 -0
- data/lib/replicate-client/model.rb +335 -0
- data/lib/replicate-client/prediction.rb +287 -0
- data/lib/replicate-client/training.rb +250 -0
- data/lib/replicate-client/version.rb +5 -0
- data/lib/replicate-client.rb +78 -0
- metadata +76 -0
@@ -0,0 +1,53 @@
|
|
1
|
+
# frozen_string_literal: true
|
2
|
+
|
3
|
+
module ReplicateClient
  # A hardware option, identified by a SKU and a human-readable name.
  # Records are built from the entries returned by a GET on INDEX_PATH.
  class Hardware
    INDEX_PATH = "/hardware"

    # The SKU of the hardware.
    #
    # @return [String]
    attr_accessor :sku

    # The name of the hardware.
    #
    # @return [String]
    attr_accessor :name

    # Fetch every hardware entry from the hardware index.
    #
    # @return [Array<ReplicateClient::Hardware>]
    def self.all
      ReplicateClient.client.get(INDEX_PATH).map { |entry| new(entry) }
    end

    # Look up a single hardware entry by its SKU.
    #
    # Loads the complete list via {.all} and returns the first match.
    #
    # @param sku [String] The SKU of the hardware.
    #
    # @return [ReplicateClient::Hardware, nil] nil when no entry has that SKU.
    def self.find_by(sku:)
      all.find { |candidate| candidate.sku == sku }
    end

    # Build a hardware instance from an attributes hash.
    #
    # @param attributes [Hash] String-keyed attributes; "sku" and "name" are read.
    #
    # @return [ReplicateClient::Hardware]
    def initialize(attributes)
      @sku, @name = attributes.values_at("sku", "name")
    end

    # Human-readable representation in the form "name (sku)".
    #
    # @return [String]
    def to_s
      format("%s (%s)", name, sku)
    end
  end
end
|
@@ -0,0 +1,132 @@
|
|
1
|
+
# frozen_string_literal: true
|
2
|
+
|
3
|
+
module ReplicateClient
  class Model
    # A single version of a model. Versions are addressed under the owning
    # model's path followed by INDEX_PATH.
    class Version
      INDEX_PATH = "/versions"

      class << self
        # Fetch one version of a model, letting any client error propagate.
        #
        # @param owner [String] The owner of the model.
        # @param name [String] The name of the model.
        # @param version_id [String] The version id of the model.
        #
        # @return [ReplicateClient::Model::Version]
        def find_by!(owner:, name:, version_id:)
          path = build_path(owner: owner, name: name, version_id: version_id)
          new(ReplicateClient.client.get(path))
        end

        # Fetch one version of a model, returning nil when the lookup raises
        # ReplicateClient::NotFoundError.
        #
        # @param owner [String] The owner of the model.
        # @param name [String] The name of the model.
        # @param version_id [String] The version id of the model.
        #
        # @return [ReplicateClient::Model::Version, nil]
        def find_by(owner:, name:, version_id:)
          find_by!(owner: owner, name: name, version_id: version_id)
        rescue ReplicateClient::NotFoundError
          nil
        end

        # Get all versions of a model by walking every page of the index.
        #
        # @param owner [String] The owner of the model.
        # @param name [String] The name of the model.
        #
        # @return [Array<ReplicateClient::Model::Version>]
        def where(owner:, name:)
          versions = []
          auto_paging_each(owner: owner, name: name) { |version| versions << version }
          versions
        end

        # Paginate through all versions of a model.
        #
        # Each page is requested from the model's versions index; the cursor
        # for the following page is taken from the "next" URL of the response.
        #
        # @param owner [String] The owner of the model.
        # @param name [String] The name of the model.
        # @yield [ReplicateClient::Model::Version] Yields each version.
        #
        # @return [void]
        def auto_paging_each(owner:, name:, &block)
          index_path = "#{Model.build_path(owner: owner, name: name)}#{INDEX_PATH}"
          cursor = nil

          loop do
            # The cursor was URL-decoded when extracted, so it must be encoded
            # again; otherwise characters such as "+" or "=" would be mangled.
            query = cursor ? "?#{URI.encode_www_form(cursor: cursor)}" : ""
            page = ReplicateClient.client.get("#{index_path}#{query}")

            page["results"].map { |attributes| new(attributes) }.each(&block)

            cursor = next_cursor(page["next"])
            break if cursor.nil?
          end
        end

        # Build the path for the model version.
        #
        # @param owner [String] The owner of the model.
        # @param name [String] The name of the model.
        # @param version_id [String] The version id of the model.
        #
        # @return [String]
        def build_path(owner:, name:, version_id:)
          "#{Model.build_path(owner: owner, name: name)}#{INDEX_PATH}/#{version_id}"
        end

        private

        # Extract the decoded "cursor" query parameter from a "next" page URL.
        #
        # @param next_url [String, nil] The "next" value of a paginated response.
        #
        # @return [String, nil] nil when there is no next URL or it carries no cursor.
        def next_cursor(next_url)
          return nil if next_url.nil?

          query = URI.parse(next_url).query
          return nil if query.nil?

          URI.decode_www_form(query).to_h["cursor"]
        end
      end

      # The ID of the model version.
      #
      # @return [String]
      attr_accessor :id

      # The date the model version was created.
      #
      # @return [Time, nil] nil when the attributes carried no "created_at".
      attr_accessor :created_at

      # The cog version of the model version.
      #
      # @return [String]
      attr_accessor :cog_version

      # The OpenAPI schema of the model version.
      #
      # @return [Hash]
      attr_accessor :openapi_schema

      # Build a version from an attributes hash.
      #
      # @param attributes [Hash] String-keyed attributes; "id", "created_at",
      #   "cog_version" and "openapi_schema" are read.
      #
      # @return [ReplicateClient::Model::Version]
      def initialize(attributes)
        created_at = attributes["created_at"]

        @id = attributes["id"]
        # Time.parse raises TypeError on nil, so only parse a present timestamp.
        @created_at = created_at && Time.parse(created_at)
        @cog_version = attributes["cog_version"]
        @openapi_schema = attributes["openapi_schema"]
      end

      # Create a new prediction that runs on this version.
      #
      # @param input [Hash] The input data for the prediction.
      # @param webhook_url [String, nil] A URL to receive webhook notifications.
      # @param webhook_events_filter [Array, nil] The events to trigger webhook requests.
      #
      # @return [ReplicateClient::Prediction]
      def create_prediction!(input:, webhook_url: nil, webhook_events_filter: nil)
        Prediction.create!(
          version: self,
          input: input,
          webhook_url: webhook_url,
          webhook_events_filter: webhook_events_filter
        )
      end
    end
  end
end
|
@@ -0,0 +1,335 @@
|
|
1
|
+
# frozen_string_literal: true
|
2
|
+
|
3
|
+
require_relative "model/version"
|
4
|
+
|
5
|
+
module ReplicateClient
  # A model, addressed as "owner/name" under INDEX_PATH. An instance also
  # tracks which version of the model it is currently set to use.
  class Model
    INDEX_PATH = "/models"

    module Visibility
      PUBLIC = "public"
      PRIVATE = "private"
    end

    class << self
      # Find a model, letting any client error propagate.
      #
      # @param owner [String] The owner of the model.
      # @param name [String] The name of the model.
      # @param version_id [String, nil] The version id of the model to use.
      #
      # @return [ReplicateClient::Model]
      def find_by!(owner:, name:, version_id: nil)
        response = ReplicateClient.client.get(build_path(owner: owner, name: name))
        new(response, version_id: version_id)
      end

      # Find a model, returning nil when the lookup raises
      # ReplicateClient::NotFoundError.
      #
      # @param owner [String] The owner of the model.
      # @param name [String] The name of the model.
      # @param version_id [String, nil] The version id of the model to use.
      #
      # @return [ReplicateClient::Model, nil]
      def find_by(owner:, name:, version_id: nil)
        find_by!(owner: owner, name: name, version_id: version_id)
      rescue ReplicateClient::NotFoundError
        nil
      end

      # Find a model by name.
      # The name should be in the format "owner/name".
      #
      # @param name [String] The name of the model.
      # @param version_id [String, nil] The version id of the model to use.
      #
      # @raise [ArgumentError] when the name is not in "owner/name" format.
      #
      # @return [ReplicateClient::Model]
      def find(name, version_id: nil)
        find_by!(**parse_model_name(name), version_id: version_id)
      end

      # Build the path for the model.
      #
      # @param owner [String] The owner of the model.
      # @param name [String] The name of the model.
      #
      # @return [String]
      def build_path(owner:, name:)
        "#{INDEX_PATH}/#{owner}/#{name}"
      end

      # Paginate through all models.
      #
      # Each page is requested from INDEX_PATH; the cursor for the following
      # page is taken from the "next" URL of the response.
      #
      # @yield [ReplicateClient::Model] Yields a model.
      #
      # @return [void]
      def auto_paging_each(&block)
        cursor = nil

        loop do
          # The cursor was URL-decoded when extracted, so it must be encoded
          # again; otherwise characters such as "+" or "=" would be mangled.
          query = cursor ? "?#{URI.encode_www_form(cursor: cursor)}" : ""
          page = ReplicateClient.client.get("#{INDEX_PATH}#{query}")

          page["results"].map { |attributes| new(attributes) }.each(&block)

          cursor = next_cursor(page["next"])
          break if cursor.nil?
        end
      end

      # Create a new model.
      #
      # @param owner [String] The owner of the model.
      # @param name [String] The name of the model.
      # @param description [String] A description of the model.
      # @param visibility [String] "public" or "private".
      # @param hardware [String] The SKU for the hardware used to run the model.
      # @param github_url [String, nil] A URL for the model's source code on GitHub.
      # @param paper_url [String, nil] A URL for the model's paper.
      # @param license_url [String, nil] A URL for the model's license.
      # @param cover_image_url [String, nil] A URL for the model's cover image.
      #
      # @return [ReplicateClient::Model]
      def create!(
        owner:,
        name:,
        description:,
        visibility:,
        hardware:,
        github_url: nil,
        paper_url: nil,
        license_url: nil,
        cover_image_url: nil
      )
        # Unset optional fields are left out of the payload rather than being
        # sent as explicit nulls.
        payload = {
          owner: owner,
          name: name,
          description: description,
          visibility: visibility,
          hardware: hardware,
          github_url: github_url,
          paper_url: paper_url,
          license_url: license_url,
          cover_image_url: cover_image_url
        }.compact

        new(ReplicateClient.client.post(INDEX_PATH, payload))
      end

      # Parse a model name in "owner/name" format.
      #
      # @param model_name [String] The name of the model.
      #
      # @raise [ArgumentError] when the owner or the name part is missing.
      #
      # @return [Hash] A hash with :owner and :name keys.
      def parse_model_name(model_name)
        owner, name = model_name.to_s.split("/")

        if owner.nil? || owner.empty? || name.nil? || name.empty?
          raise ArgumentError, "model name must be in the format \"owner/name\", got #{model_name.inspect}"
        end

        { owner: owner, name: name }
      end

      private

      # Extract the decoded "cursor" query parameter from a "next" page URL.
      #
      # @param next_url [String, nil] The "next" value of a paginated response.
      #
      # @return [String, nil] nil when there is no next URL or it carries no cursor.
      def next_cursor(next_url)
        return nil if next_url.nil?

        query = URI.parse(next_url).query
        return nil if query.nil?

        URI.decode_www_form(query).to_h["cursor"]
      end
    end

    # The URL of the model.
    #
    # @return [String]
    attr_accessor :url

    # The name of the user or organization that will own the model.
    #
    # @return [String]
    attr_accessor :owner

    # The name of the model.
    #
    # @return [String]
    attr_accessor :name

    # A description of the model.
    #
    # @return [String]
    attr_accessor :description

    # Whether the model should be public or private. A public model can be viewed and run by anyone, whereas
    # a private model can be viewed and run only by the user or organization members that own the model.
    #
    # @return [String] "public" or "private"
    attr_accessor :visibility

    # A URL for the model's source code on GitHub.
    #
    # @return [String]
    attr_accessor :github_url

    # A URL for the model's paper.
    #
    # @return [String]
    attr_accessor :paper_url

    # A URL for the model's license.
    #
    # @return [String]
    attr_accessor :license_url

    # The number of times the model has been run.
    #
    # @return [Integer]
    attr_accessor :run_count

    # A URL for the model's cover image. This should be an image file.
    #
    # @return [String]
    attr_accessor :cover_image_url

    # The default example of the model.
    #
    # @return [Hash]
    attr_accessor :default_example

    # The current version id of the model.
    #
    # @return [String, nil]
    attr_accessor :version_id

    # The id of the latest version of the model.
    #
    # @return [String, nil]
    attr_accessor :latest_version_id

    # Initialize a new model.
    #
    # @param attributes [Hash] The attributes of the model.
    # @param version_id [String, nil] The version of the model to use; defaults
    #   to the id of the latest version found in the attributes.
    #
    # @return [ReplicateClient::Model]
    def initialize(attributes, version_id: nil)
      reset_attributes(attributes, version_id: version_id)
    end

    # The path of the model.
    #
    # @return [String]
    def path
      self.class.build_path(owner: owner, name: name)
    end

    # Delete the model.
    #
    # @return [void]
    def destroy!
      ReplicateClient.client.delete(path)
    end

    # The path of the current version.
    #
    # @return [String]
    def version_path
      Version.build_path(owner: owner, name: name, version_id: version_id)
    end

    # The version of the model currently in use (memoized).
    #
    # @return [ReplicateClient::Model::Version, nil] nil when no version id is set.
    def version
      # Without an id the lookup path would end in "/versions/", which is not
      # a version resource.
      return nil if version_id.nil?

      @version ||= Version.find_by!(owner: owner, name: name, version_id: version_id)
    end

    # The latest version of the model (memoized).
    #
    # @return [ReplicateClient::Model::Version, nil] nil when the model has no latest version id.
    def latest_version
      return nil if latest_version_id.nil?

      @latest_version ||= Version.find_by!(owner: owner, name: name, version_id: latest_version_id)
    end

    # The versions of the model (memoized).
    #
    # @return [Array<ReplicateClient::Model::Version>]
    def versions
      @versions ||= Version.where(owner: owner, name: name)
    end

    # Create a new prediction for the model.
    #
    # When no version id is set, the prediction is created through
    # Prediction.create_for_official_model!; otherwise through
    # Prediction.create! with the current version id.
    #
    # @param input [Hash] The input data for the prediction.
    # @param webhook_url [String, nil] A URL to receive webhook notifications.
    # @param webhook_events_filter [Array, nil] The events to trigger webhook requests.
    #
    # @return [ReplicateClient::Prediction]
    def create_prediction!(input:, webhook_url: nil, webhook_events_filter: nil)
      if version_id.nil?
        Prediction.create_for_official_model!(
          model: self,
          input: input,
          webhook_url: webhook_url,
          webhook_events_filter: webhook_events_filter
        )
      else
        # NOTE(review): the version is passed here as an id String; confirm
        # Prediction.create! accepts a String for `version:`.
        Prediction.create!(
          version: version_id,
          input: input,
          webhook_url: webhook_url,
          webhook_events_filter: webhook_events_filter
        )
      end
    end

    # Reload the model from the API.
    #
    # A version id that differs from the latest version id (i.e. a pinned
    # version) survives the reload; otherwise the model keeps following the
    # latest version reported by the fresh attributes.
    #
    # @return [void]
    def reload!
      pinned_version_id = version_id == latest_version_id ? nil : version_id
      attributes = ReplicateClient.client.get(path)
      reset_attributes(attributes, version_id: pinned_version_id)
    end

    # Check if the model is public.
    #
    # @return [Boolean]
    def public?
      visibility == Visibility::PUBLIC
    end

    # Check if the model is private.
    #
    # @return [Boolean]
    def private?
      visibility == Visibility::PRIVATE
    end

    # Returns the full name of the model in "owner/name" format.
    #
    # @return [String]
    def full_name
      "#{owner}/#{name}"
    end

    private

    # Set the attributes of the model and clear every memoized version lookup.
    #
    # @param attributes [Hash] The attributes of the model.
    # @param version_id [String, nil] The version of the model to use.
    #
    # @return [void]
    def reset_attributes(attributes, version_id: nil)
      @url = attributes["url"]
      @owner = attributes["owner"]
      @name = attributes["name"]
      @description = attributes["description"]
      @visibility = attributes["visibility"]
      @github_url = attributes["github_url"]
      @paper_url = attributes["paper_url"]
      @license_url = attributes["license_url"]
      @run_count = attributes["run_count"]
      @cover_image_url = attributes["cover_image_url"]
      @default_example = attributes["default_example"]
      @latest_version_id = attributes.dig("latest_version", "id")
      @version_id = version_id || @latest_version_id

      # Memoized lookups may refer to stale ids; drop them.
      @version = nil
      @versions = nil
      @latest_version = nil
    end
  end
end
|