replicate-client 0.1.0

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -0,0 +1,53 @@
1
+ # frozen_string_literal: true
2
+
3
module ReplicateClient
  # A hardware option (a SKU plus a human-readable name) exposed by the Replicate API.
  class Hardware
    INDEX_PATH = "/hardware"

    class << self
      # Fetch every hardware option from the API.
      #
      # @return [Array<ReplicateClient::Hardware>]
      def all
        ReplicateClient.client.get(INDEX_PATH).map do |hardware_attributes|
          new(hardware_attributes)
        end
      end

      # Look up a single hardware option by SKU.
      #
      # Re-fetches the full list through {.all} on every call and scans it.
      #
      # @param sku [String] The SKU to match exactly.
      #
      # @return [ReplicateClient::Hardware, nil] nil when no option has that SKU.
      def find_by(sku:)
        all.find { |candidate| candidate.sku == sku }
      end
    end

    # @!attribute [rw] sku
    #   The SKU of the hardware.
    #   @return [String]
    # @!attribute [rw] name
    #   The name of the hardware.
    #   @return [String]
    attr_accessor :sku, :name

    # Build a hardware instance from an API response entry.
    #
    # @param attributes [Hash] String-keyed attributes; "sku" and "name" are read.
    #
    # @return [ReplicateClient::Hardware]
    def initialize(attributes)
      @sku, @name = attributes.values_at("sku", "name")
    end

    # Human-readable form, e.g. "Nvidia A100 (80GB) GPU (gpu-a100-large)".
    #
    # @return [String]
    def to_s
      "#{name} (#{sku})"
    end
  end
end
@@ -0,0 +1,132 @@
1
+ # frozen_string_literal: true
2
+
3
module ReplicateClient
  class Model
    # A single version of a Replicate model, as returned by the model versions API.
    class Version
      INDEX_PATH = "/versions"

      class << self
        # Find a version of a model.
        #
        # @param owner [String] The owner of the model.
        # @param name [String] The name of the model.
        # @param version_id [String] The version id of the model.
        #
        # @return [ReplicateClient::Model::Version]
        # @raise [ReplicateClient::NotFoundError] when the client reports the version is missing
        #   (this is the error {.find_by} rescues).
        def find_by!(owner:, name:, version_id:)
          path = build_path(owner: owner, name: name, version_id: version_id)
          response = ReplicateClient.client.get(path)
          new(response)
        end

        # Find a version of a model, returning nil instead of raising when it is not found.
        #
        # @param owner [String] The owner of the model.
        # @param name [String] The name of the model.
        # @param version_id [String] The version id of the model.
        #
        # @return [ReplicateClient::Model::Version, nil]
        def find_by(owner:, name:, version_id:)
          find_by!(owner: owner, name: name, version_id: version_id)
        rescue ReplicateClient::NotFoundError
          nil
        end

        # Get all versions of a model, fetching every page.
        #
        # @param owner [String] The owner of the model.
        # @param name [String] The name of the model.
        #
        # @return [Array<ReplicateClient::Model::Version>]
        def where(owner:, name:)
          auto_paging_each(owner: owner, name: name).to_a
        end

        # Iterate over every version of a model, following the API's "next" links
        # until a page has no further cursor.
        #
        # @param owner [String] The owner of the model.
        # @param name [String] The name of the model.
        # @yield [ReplicateClient::Model::Version] Yields each version in API order.
        #
        # @return [nil, Enumerator] nil when a block is given; otherwise an Enumerator
        #   that performs the requests as it is consumed.
        def auto_paging_each(owner:, name:, &block)
          unless block
            enumerator = Enumerator.new do |yielder|
              auto_paging_each(owner: owner, name: name) { |version| yielder << version }
            end
            return enumerator
          end

          versions_path = "#{Model.build_path(owner: owner, name: name)}#{INDEX_PATH}"
          cursor = nil

          loop do
            # The cursor was decoded out of the previous "next" URL, so it must be
            # re-encoded: raw "+", "/", "=" or "&" would corrupt the query string.
            query = cursor ? "?#{URI.encode_www_form(cursor: cursor)}" : ""
            page = ReplicateClient.client.get("#{versions_path}#{query}")

            page["results"].map { |attributes| new(attributes) }.each(&block)

            cursor = next_cursor(page["next"])
            break if cursor.nil?
          end
        end

        # Build the API path for a model version.
        #
        # @param owner [String] The owner of the model.
        # @param name [String] The name of the model.
        # @param version_id [String] The version id of the model.
        #
        # @return [String]
        def build_path(owner:, name:, version_id:)
          model_path = Model.build_path(owner: owner, name: name)
          "#{model_path}#{INDEX_PATH}/#{version_id}"
        end

        private

        # Extract the decoded "cursor" query parameter from a pagination "next" URL.
        #
        # @param next_url [String, nil] The "next" value of a paginated response.
        #
        # @return [String, nil] nil when there is no next page or the URL has no query.
        def next_cursor(next_url)
          return nil if next_url.nil?

          query = URI.parse(next_url).query
          return nil if query.nil?

          URI.decode_www_form(query).to_h["cursor"]
        end
      end

      # The ID of the model version.
      #
      # @return [String]
      attr_accessor :id

      # The date the model version was created.
      #
      # @return [Time, nil] nil when the API response carries no "created_at".
      attr_accessor :created_at

      # The cog version of the model version.
      #
      # @return [String]
      attr_accessor :cog_version

      # The OpenAPI schema of the model version.
      #
      # @return [Hash]
      attr_accessor :openapi_schema

      # Build a version from an API response.
      #
      # @param attributes [Hash] String-keyed attributes ("id", "created_at", "cog_version", "openapi_schema").
      #
      # @return [ReplicateClient::Model::Version]
      def initialize(attributes)
        @id = attributes["id"]
        raw_created_at = attributes["created_at"]
        # Guard against a missing timestamp instead of raising TypeError from Time.parse(nil).
        @created_at = raw_created_at && Time.parse(raw_created_at)
        @cog_version = attributes["cog_version"]
        @openapi_schema = attributes["openapi_schema"]
      end

      # Create a new prediction for this version.
      #
      # @param input [Hash] The input data for the prediction.
      # @param webhook_url [String, nil] A URL to receive webhook notifications.
      # @param webhook_events_filter [Array, nil] The events to trigger webhook requests.
      #
      # @return [ReplicateClient::Prediction]
      def create_prediction!(input:, webhook_url: nil, webhook_events_filter: nil)
        Prediction.create!(
          version: self,
          input: input,
          webhook_url: webhook_url,
          webhook_events_filter: webhook_events_filter
        )
      end
    end
  end
end
@@ -0,0 +1,335 @@
1
+ # frozen_string_literal: true
2
+
3
+ require_relative "model/version"
4
+
5
module ReplicateClient
  # A model hosted on Replicate, optionally pinned to a specific version.
  class Model
    INDEX_PATH = "/models"

    # Allowed values for {Model#visibility}.
    module Visibility
      PUBLIC = "public"
      PRIVATE = "private"
    end

    class << self
      # Find a model.
      #
      # @param owner [String] The owner of the model.
      # @param name [String] The name of the model.
      # @param version_id [String, nil] The version id to pin; defaults to the model's latest version.
      #
      # @return [ReplicateClient::Model]
      # @raise [ReplicateClient::NotFoundError] when the client reports the model is missing
      #   (this is the error {.find_by} rescues).
      def find_by!(owner:, name:, version_id: nil)
        path = build_path(owner: owner, name: name)
        response = ReplicateClient.client.get(path)
        new(response, version_id: version_id)
      end

      # Find a model, returning nil instead of raising when it is not found.
      #
      # @param owner [String] The owner of the model.
      # @param name [String] The name of the model.
      # @param version_id [String, nil] The version id of the model to use.
      #
      # @return [ReplicateClient::Model, nil]
      def find_by(owner:, name:, version_id: nil)
        find_by!(owner: owner, name: name, version_id: version_id)
      rescue ReplicateClient::NotFoundError
        nil
      end

      # Find a model by its full name.
      # The name should be in the format "owner/name".
      #
      # @param name [String] The full name of the model.
      # @param version_id [String, nil] The version id of the model to use.
      #
      # @return [ReplicateClient::Model]
      def find(name, version_id: nil)
        find_by!(**parse_model_name(name), version_id: version_id)
      end

      # Build the API path for a model.
      #
      # @param owner [String] The owner of the model.
      # @param name [String] The name of the model.
      #
      # @return [String]
      def build_path(owner:, name:)
        "#{INDEX_PATH}/#{owner}/#{name}"
      end

      # Iterate over every model, following the API's "next" links until a page
      # has no further cursor.
      #
      # @yield [ReplicateClient::Model] Yields each model in API order.
      #
      # @return [nil, Enumerator] nil when a block is given; otherwise an Enumerator
      #   that performs the requests as it is consumed.
      def auto_paging_each(&block)
        unless block
          enumerator = Enumerator.new do |yielder|
            auto_paging_each { |model| yielder << model }
          end
          return enumerator
        end

        cursor = nil

        loop do
          # The cursor was decoded out of the previous "next" URL, so it must be
          # re-encoded: raw "+", "/", "=" or "&" would corrupt the query string.
          query = cursor ? "?#{URI.encode_www_form(cursor: cursor)}" : ""
          page = ReplicateClient.client.get("#{INDEX_PATH}#{query}")

          page["results"].map { |attributes| new(attributes) }.each(&block)

          cursor = next_cursor(page["next"])
          break if cursor.nil?
        end
      end

      # Create a new model.
      #
      # @param owner [String] The owner of the model.
      # @param name [String] The name of the model.
      # @param description [String] A description of the model.
      # @param visibility [String] "public" or "private".
      # @param hardware [String] The SKU for the hardware used to run the model.
      # @param github_url [String, nil] A URL for the model’s source code on GitHub.
      # @param paper_url [String, nil] A URL for the model’s paper.
      # @param license_url [String, nil] A URL for the model’s license.
      # @param cover_image_url [String, nil] A URL for the model’s cover image.
      #
      # @return [ReplicateClient::Model]
      def create!(
        owner:,
        name:,
        description:,
        visibility:,
        hardware:,
        github_url: nil,
        paper_url: nil,
        license_url: nil,
        cover_image_url: nil
      )
        new_attributes = {
          owner: owner,
          name: name,
          description: description,
          visibility: visibility,
          hardware: hardware,
          github_url: github_url,
          paper_url: paper_url,
          license_url: license_url,
          cover_image_url: cover_image_url
        }

        attributes = ReplicateClient.client.post(INDEX_PATH, new_attributes)

        new(attributes)
      end

      # Split a full model name into its owner and name parts.
      #
      # @param model_name [String] The full name, "owner/name".
      #
      # @return [Hash] { owner:, name: }; a part is nil when absent from the input.
      def parse_model_name(model_name)
        owner, name = model_name.split("/")

        {
          owner: owner,
          name: name
        }
      end

      private

      # Extract the decoded "cursor" query parameter from a pagination "next" URL.
      #
      # @param next_url [String, nil] The "next" value of a paginated response.
      #
      # @return [String, nil] nil when there is no next page or the URL has no query.
      def next_cursor(next_url)
        return nil if next_url.nil?

        query = URI.parse(next_url).query
        return nil if query.nil?

        URI.decode_www_form(query).to_h["cursor"]
      end
    end

    # The URL of the model.
    #
    # @return [String]
    attr_accessor :url

    # The name of the user or organization that will own the model.
    #
    # @return [String]
    attr_accessor :owner

    # The name of the model.
    #
    # @return [String]
    attr_accessor :name

    # A description of the model.
    #
    # @return [String]
    attr_accessor :description

    # Whether the model should be public or private. A public model can be viewed and run by anyone, whereas
    # a private model can be viewed and run only by the user or organization members that own the model.
    #
    # @return [String] "public" or "private"
    attr_accessor :visibility

    # A URL for the model’s source code on GitHub.
    #
    # @return [String]
    attr_accessor :github_url

    # A URL for the model’s paper.
    #
    # @return [String]
    attr_accessor :paper_url

    # A URL for the model’s license.
    #
    # @return [String]
    attr_accessor :license_url

    # The number of times the model has been run.
    #
    # @return [Integer]
    attr_accessor :run_count

    # A URL for the model’s cover image. This should be an image file.
    #
    # @return [String]
    attr_accessor :cover_image_url

    # The default example of the model.
    #
    # @return [Hash]
    attr_accessor :default_example

    # The version id this instance uses: the pinned id if one was given, else the latest version's id.
    #
    # @return [String, nil]
    attr_accessor :version_id

    # The id of the latest version of the model.
    #
    # @return [String, nil] nil when the API response has no "latest_version".
    attr_accessor :latest_version_id

    # Initialize a new model.
    #
    # @param attributes [Hash] The string-keyed attributes of the model.
    # @param version_id [String, nil] The version of the model to pin; nil follows the latest version.
    #
    # @return [ReplicateClient::Model]
    def initialize(attributes, version_id: nil)
      # Remember whether the caller pinned a version so reload! can preserve it.
      @version_pinned = !version_id.nil?
      reset_attributes(attributes, version_id: version_id)
    end

    # The path of the model.
    #
    # @return [String]
    def path
      self.class.build_path(owner: owner, name: name)
    end

    # Delete the model.
    #
    # @return [void]
    def destroy!
      ReplicateClient.client.delete(path)
    end

    # The path of the current version.
    #
    # @return [String]
    def version_path
      Version.build_path(owner: owner, name: name, version_id: version_id)
    end

    # The version of the model this instance uses (memoized).
    #
    # @return [ReplicateClient::Model::Version, nil] nil when there is no version id to look up.
    def version
      return nil if version_id.nil?

      @version ||= Version.find_by!(owner: owner, name: name, version_id: version_id)
    end

    # The latest version of the model (memoized).
    #
    # @return [ReplicateClient::Model::Version, nil] nil when the model has no latest version.
    def latest_version
      return nil if latest_version_id.nil?

      @latest_version ||= Version.find_by!(owner: owner, name: name, version_id: latest_version_id)
    end

    # The versions of the model (memoized; fetches every page on first call).
    #
    # @return [Array<ReplicateClient::Model::Version>]
    def versions
      @versions ||= Version.where(owner: owner, name: name)
    end

    # Create a new prediction for the model.
    #
    # Uses the official-model endpoint when no version id is known, otherwise
    # creates the prediction against {#version_id}.
    #
    # @param input [Hash] The input data for the prediction.
    # @param webhook_url [String, nil] A URL to receive webhook notifications.
    # @param webhook_events_filter [Array, nil] The events to trigger webhook requests.
    #
    # @return [ReplicateClient::Prediction]
    def create_prediction!(input:, webhook_url: nil, webhook_events_filter: nil)
      if version_id.nil?
        Prediction.create_for_official_model!(
          model: self,
          input: input,
          webhook_url: webhook_url,
          webhook_events_filter: webhook_events_filter
        )
      else
        Prediction.create!(
          version: version_id,
          input: input,
          webhook_url: webhook_url,
          webhook_events_filter: webhook_events_filter
        )
      end
    end

    # Re-fetch the model and refresh its attributes, clearing memoized versions.
    #
    # A version that was pinned at construction (or set to something other than
    # the latest version) is kept; otherwise {#version_id} follows the new latest version.
    #
    # @return [void]
    def reload!
      kept_version_id = version_id_to_keep
      attributes = ReplicateClient.client.get(path)
      reset_attributes(attributes, version_id: kept_version_id)
    end

    # Check if the model is public.
    #
    # @return [Boolean]
    def public?
      visibility == Visibility::PUBLIC
    end

    # Check if the model is private.
    #
    # @return [Boolean]
    def private?
      visibility == Visibility::PRIVATE
    end

    # Returns the full name of the model in "owner/name" format.
    #
    # @return [String]
    def full_name
      "#{owner}/#{name}"
    end

    private

    # The version id reload! should preserve: the current id when it was pinned or
    # differs from the latest version, nil when the model is just following latest.
    #
    # @return [String, nil]
    def version_id_to_keep
      return version_id if @version_pinned
      return nil if version_id.nil? || version_id == latest_version_id

      version_id
    end

    # Set the attributes of the model and clear memoized version lookups.
    #
    # @param attributes [Hash] The string-keyed attributes of the model.
    # @param version_id [String, nil] The version of the model to use; nil means latest.
    #
    # @return [void]
    def reset_attributes(attributes, version_id: nil)
      @url = attributes["url"]
      @owner = attributes["owner"]
      @name = attributes["name"]
      @description = attributes["description"]
      @visibility = attributes["visibility"]
      @github_url = attributes["github_url"]
      @paper_url = attributes["paper_url"]
      @license_url = attributes["license_url"]
      @run_count = attributes["run_count"]
      @cover_image_url = attributes["cover_image_url"]
      @default_example = attributes["default_example"]
      @latest_version_id = attributes.dig("latest_version", "id")
      @version_id = version_id || @latest_version_id

      @version = nil
      @versions = nil
      @latest_version = nil
    end
  end
end