replicate-client 0.1.0

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -0,0 +1,287 @@
1
+ # frozen_string_literal: true
2
+
3
module ReplicateClient
  # A single prediction (model run) on Replicate.
  #
  # Instances wrap the attribute hash returned by the API. Class methods create, fetch and cancel
  # predictions through +ReplicateClient.client+.
  class Prediction
    INDEX_PATH = "/predictions"

    # Values of the API's "status" field.
    module Status
      STARTING = "starting"
      PROCESSING = "processing"
      SUCCEEDED = "succeeded"
      FAILED = "failed"
      CANCELED = "canceled"
    end

    class << self
      # Create a new prediction for a version.
      #
      # @param version [String, ReplicateClient::Model::Version] The version (or version ID) of the model to use.
      # @param input [Hash] The input data for the prediction.
      # @param webhook_url [String, nil] The URL to send webhook events to. Falls back to the configured
      #   +webhook_url+; the field is omitted from the request when neither is set.
      # @param webhook_events_filter [Array<Symbol, String>, nil] The events to send to the webhook.
      #
      # @return [ReplicateClient::Prediction]
      def create!(version:, input:, webhook_url: nil, webhook_events_filter: nil)
        version_id = version.is_a?(Model::Version) ? version.id : version
        body = { version: version_id, **request_body(input, webhook_url, webhook_events_filter) }

        new(ReplicateClient.client.post(INDEX_PATH, body))
      end

      # Create a new prediction for a deployment.
      #
      # @param deployment [String, ReplicateClient::Deployment] The deployment, either as an object responding
      #   to +#path+ or as an "owner/name" string.
      # @param input [Hash] The input data for the prediction.
      # @param webhook_url [String, nil] The URL to send webhook events to (falls back to the configured one).
      # @param webhook_events_filter [Array<Symbol, String>, nil] The events to send to the webhook.
      #
      # @return [ReplicateClient::Prediction]
      def create_for_deployment!(deployment:, input:, webhook_url: nil, webhook_events_filter: nil)
        # Strings have no #path, so build the deployment path for the documented "owner/name" form.
        deployment_path = deployment.is_a?(String) ? "/deployments/#{deployment}" : deployment.path
        body = request_body(input, webhook_url, webhook_events_filter)

        new(ReplicateClient.client.post("#{deployment_path}#{INDEX_PATH}", body))
      end

      # Create a new prediction for an official model.
      #
      # @param model [String, ReplicateClient::Model] The model, or its name as a string.
      # @param input [Hash] The input data for the prediction.
      # @param webhook_url [String, nil] The URL to send webhook events to (falls back to the configured one).
      # @param webhook_events_filter [Array<Symbol, String>, nil] The events to send to the webhook.
      #
      # @return [ReplicateClient::Prediction]
      def create_for_official_model!(model:, input:, webhook_url: nil, webhook_events_filter: nil)
        model_path = model.is_a?(Model) ? model.path : Model.build_path(**Model.parse_model_name(model))
        body = request_body(input, webhook_url, webhook_events_filter)

        new(ReplicateClient.client.post("#{model_path}#{INDEX_PATH}", body))
      end

      # Find a prediction.
      #
      # @param id [String] The ID of the prediction.
      #
      # @return [ReplicateClient::Prediction]
      # @raise [ReplicateClient::NotFoundError] presumably raised by the HTTP client when the ID is unknown.
      def find(id)
        attributes = ReplicateClient.client.get(build_path(id))
        new(attributes)
      end

      # Find a prediction, raising when it does not exist.
      #
      # @param id [String] The ID of the prediction.
      #
      # @return [ReplicateClient::Prediction]
      def find_by!(id:)
        find(id)
      end

      # Find a prediction, returning nil when it does not exist.
      #
      # @param id [String] The ID of the prediction.
      #
      # @return [ReplicateClient::Prediction, nil]
      def find_by(id:)
        find_by!(id: id)
      rescue ReplicateClient::NotFoundError
        nil
      end

      # Build the API path for a prediction.
      #
      # @param id [String] The ID of the prediction.
      #
      # @return [String]
      def build_path(id)
        "#{INDEX_PATH}/#{id}"
      end

      # Cancel a prediction.
      #
      # @param id [String] The ID of the prediction.
      #
      # @return [void]
      def cancel!(id)
        ReplicateClient.client.post("#{build_path(id)}/cancel")
      end

      private

      # Build the request body shared by every create method, omitting unset optional fields
      # rather than sending explicit nulls.
      #
      # @return [Hash]
      def request_body(input, webhook_url, webhook_events_filter)
        body = { input: input }

        webhook = webhook_url || ReplicateClient.configuration&.webhook_url
        body[:webhook] = webhook if webhook
        body[:webhook_events_filter] = Array(webhook_events_filter).map(&:to_s) if webhook_events_filter

        body
      end
    end

    # The ID of the prediction.
    #
    # @return [String]
    attr_accessor :id

    # The ID of the model version used for the prediction.
    #
    # @return [String]
    attr_accessor :version_id

    # The name of the model used for the prediction.
    #
    # @return [String]
    attr_accessor :model_name

    # The input data for the prediction.
    #
    # @return [Hash]
    attr_accessor :input

    # The raw "output" value from the API response (its shape depends on the model).
    #
    # @return [Object, nil]
    attr_accessor :output

    # The error message for the prediction.
    #
    # @return [String, nil]
    attr_accessor :error

    # The status of the prediction; one of the {Status} values.
    #
    # @return [String]
    attr_accessor :status

    # The raw "created_at" timestamp from the API (not parsed into a Time).
    #
    # @return [String, nil]
    attr_accessor :created_at

    # The raw "data_removed" value from the API.
    # NOTE(review): likely a boolean flag rather than a date — confirm against the API.
    #
    # @return [Object, nil]
    attr_accessor :data_removed

    # The raw "started_at" timestamp from the API (not parsed into a Time).
    #
    # @return [String, nil]
    attr_accessor :started_at

    # The raw "completed_at" timestamp from the API (not parsed into a Time).
    #
    # @return [String, nil]
    attr_accessor :completed_at

    # The metrics for the prediction.
    #
    # @return [Hash, nil]
    attr_accessor :metrics

    # The URLs for the prediction.
    #
    # @return [Hash, nil]
    attr_accessor :urls

    # @param attributes [Hash] The attribute hash returned by the API (string keys).
    def initialize(attributes)
      reset_attributes(attributes)
    end

    # Reload the prediction from the API.
    #
    # @return [ReplicateClient::Prediction] self, with refreshed attributes.
    def reload!
      attributes = ReplicateClient.client.get(Prediction.build_path(@id))
      reset_attributes(attributes)
      self
    end

    # The model used for the prediction (fetched lazily and memoized).
    #
    # @return [ReplicateClient::Model]
    def model
      @model ||= Model.find(@model_name, version_id: @version_id)
    end

    # The version of the model used for the prediction (memoized).
    #
    # @return [ReplicateClient::Model::Version]
    def version
      @version ||= model.version
    end

    # Cancel the prediction.
    #
    # @return [void]
    def cancel!
      Prediction.cancel!(id)
    end

    # @return [Boolean] whether the prediction succeeded.
    def succeeded?
      status == Status::SUCCEEDED
    end

    # @return [Boolean] whether the prediction failed.
    def failed?
      status == Status::FAILED
    end

    # @return [Boolean] whether the prediction was canceled.
    def canceled?
      status == Status::CANCELED
    end

    # @return [Boolean] whether the prediction is starting.
    def starting?
      status == Status::STARTING
    end

    # @return [Boolean] whether the prediction is processing.
    def processing?
      status == Status::PROCESSING
    end

    private

    # Copy the API attribute hash into instance variables and drop memoized associations.
    #
    # @param attributes [Hash] The attributes of the prediction.
    #
    # @return [void]
    def reset_attributes(attributes)
      @id = attributes["id"]
      @version_id = attributes["version"]
      @model_name = attributes["model"]
      @input = attributes["input"]
      @output = attributes["output"]
      @error = attributes["error"]
      @status = attributes["status"]
      @created_at = attributes["created_at"]
      @data_removed = attributes["data_removed"]
      @started_at = attributes["started_at"]
      @completed_at = attributes["completed_at"]
      @metrics = attributes["metrics"]
      @urls = attributes["urls"]

      # The model/version may differ after a reload, so invalidate the memoized lookups.
      @model = nil
      @version = nil
    end
  end
end
@@ -0,0 +1,250 @@
1
+ # frozen_string_literal: true
2
+
3
module ReplicateClient
  # A training job on Replicate, wrapping the attribute hash returned by the API.
  class Training
    INDEX_PATH = "/trainings"

    # Values of the API's "status" field.
    module Status
      STARTING = "starting"
      PROCESSING = "processing"
      SUCCEEDED = "succeeded"
      FAILED = "failed"
      CANCELED = "canceled"
    end

    class << self
      # Iterate over all trainings, following the API's cursor-based pagination.
      #
      # @yield [ReplicateClient::Training] Yields each training in turn.
      #
      # @return [nil, Enumerator] An Enumerator (without fetching anything) when no block is given.
      def auto_paging_each(&block)
        return enum_for(:auto_paging_each) unless block

        cursor = nil

        loop do
          # The cursor was decoded from the "next" URL, so it must be re-encoded before reuse.
          query = cursor ? "?#{URI.encode_www_form(cursor: cursor)}" : ""
          page = ReplicateClient.client.get("#{INDEX_PATH}#{query}")

          Array(page["results"]).each { |attributes| block.call(new(attributes)) }

          cursor = next_cursor(page["next"])
          break if cursor.nil?
        end
      end

      # Create a new training.
      #
      # @param owner [String] The owner of the model.
      # @param name [String] The name of the model.
      # @param version [ReplicateClient::Model::Version, String] The version (or version ID) of the model to train.
      # @param destination [ReplicateClient::Model, String] The destination model instance or string in
      #   "owner/name" format.
      # @param input [Hash] The input data for the training.
      # @param webhook [String, nil] A URL to receive webhook notifications; omitted from the request when nil.
      # @param webhook_events_filter [Array, nil] The events to trigger webhook requests; omitted when nil.
      #
      # @return [ReplicateClient::Training]
      def create!(owner:, name:, version:, destination:, input:, webhook: nil, webhook_events_filter: nil)
        destination_name = destination.is_a?(ReplicateClient::Model) ? destination.full_name : destination
        version_id = version.is_a?(ReplicateClient::Model::Version) ? version.id : version

        body = { destination: destination_name, input: input }
        body[:webhook] = webhook if webhook
        body[:webhook_events_filter] = Array(webhook_events_filter).map(&:to_s) if webhook_events_filter

        attributes = ReplicateClient.client.post("/models/#{owner}/#{name}/versions/#{version_id}/trainings", body)
        new(attributes)
      end

      # Create a new training for a specific model, using that model's current version.
      #
      # @param model [ReplicateClient::Model, String] The model instance or a string representing the model ID.
      # @param destination [ReplicateClient::Model, String] The destination model or full name in "owner/name" format.
      # @param input [Hash] The input data for the training.
      # @param webhook [String, nil] A URL to receive webhook notifications.
      # @param webhook_events_filter [Array, nil] The events to trigger webhook requests.
      #
      # @return [ReplicateClient::Training]
      # @raise [ArgumentError] if the model cannot be resolved.
      def create_for_model!(model:, destination:, input:, webhook: nil, webhook_events_filter: nil)
        model_instance = model.is_a?(ReplicateClient::Model) ? model : ReplicateClient::Model.find(model)
        raise ArgumentError, "Invalid model" unless model_instance

        create!(
          owner: model_instance.owner,
          name: model_instance.name,
          version: model_instance.version_id,
          destination: destination,
          input: input,
          webhook: webhook,
          webhook_events_filter: webhook_events_filter
        )
      end

      # Find a training by id.
      #
      # @param id [String] The id of the training.
      #
      # @return [ReplicateClient::Training]
      def find(id)
        attributes = ReplicateClient.client.get(build_path(id: id))
        new(attributes)
      end

      # Cancel a training.
      #
      # @param id [String] The id of the training.
      #
      # @return [void]
      def cancel!(id)
        ReplicateClient.client.post("#{build_path(id: id)}/cancel")
      end

      # Build the API path for a specific training.
      #
      # @param id [String] The id of the training.
      #
      # @return [String]
      def build_path(id:)
        "#{INDEX_PATH}/#{id}"
      end

      private

      # Extract the "cursor" query parameter from the API's "next" page URL.
      #
      # @param next_url [String, nil]
      # @return [String, nil] nil when there is no further page.
      def next_cursor(next_url)
        return if next_url.nil?

        URI.decode_www_form(URI.parse(next_url).query.to_s).to_h["cursor"]
      end
    end

    # The unique identifier of the training.
    #
    # @return [String]
    attr_accessor :id

    # The full model name in the format "owner/name".
    #
    # @return [String]
    attr_accessor :model

    # The version ID of the model being trained.
    #
    # @return [String]
    attr_accessor :version

    # The input data provided for the training.
    #
    # @return [Hash]
    attr_accessor :input

    # The current status of the training; one of the {Status} values.
    #
    # @return [String]
    attr_accessor :status

    # The raw "created_at" timestamp from the API.
    #
    # @return [String]
    attr_accessor :created_at

    # The raw "started_at" timestamp from the API, or nil when absent from the response.
    #
    # @return [String, nil]
    attr_accessor :started_at

    # The raw "completed_at" timestamp from the API.
    #
    # @return [String, nil]
    attr_accessor :completed_at

    # The raw "output" value from the API, or nil when absent.
    # NOTE(review): presumably describes the trained version — confirm against the API.
    #
    # @return [Object, nil]
    attr_accessor :output

    # The raw "metrics" value from the API, or nil when absent.
    #
    # @return [Hash, nil]
    attr_accessor :metrics

    # The logs generated during the training process.
    #
    # @return [String]
    attr_accessor :logs

    # The error message, if any, encountered during the training process.
    #
    # @return [String, nil]
    attr_accessor :error

    # URLs related to the training, such as those for retrieving or canceling it.
    #
    # @return [Hash]
    attr_accessor :urls

    # Initialize a new training instance.
    #
    # @param attributes [Hash] The attributes of the training (string keys, as returned by the API).
    def initialize(attributes)
      reset_attributes(attributes)
    end

    # @return [Boolean] whether the training is starting.
    def starting?
      status == Status::STARTING
    end

    # @return [Boolean] whether the training is processing.
    def processing?
      status == Status::PROCESSING
    end

    # @return [Boolean] whether the training has succeeded.
    def succeeded?
      status == Status::SUCCEEDED
    end

    # @return [Boolean] whether the training has failed.
    def failed?
      status == Status::FAILED
    end

    # @return [Boolean] whether the training was canceled.
    def canceled?
      status == Status::CANCELED
    end

    # Cancel the training.
    #
    # @return [void]
    def cancel!
      ReplicateClient::Training.cancel!(id)
    end

    # Reload the training from the API.
    #
    # @return [ReplicateClient::Training] self, with refreshed attributes.
    def reload!
      attributes = ReplicateClient.client.get(Training.build_path(id: id))
      reset_attributes(attributes)
      self
    end

    private

    # Copy the API attribute hash into instance variables.
    #
    # @param attributes [Hash] The attributes of the training.
    #
    # @return [void]
    def reset_attributes(attributes)
      @id = attributes["id"]
      @model = attributes["model"]
      @version = attributes["version"]
      @input = attributes["input"]
      @status = attributes["status"]
      @created_at = attributes["created_at"]
      @started_at = attributes["started_at"]
      @completed_at = attributes["completed_at"]
      @output = attributes["output"]
      @metrics = attributes["metrics"]
      @logs = attributes["logs"]
      @error = attributes["error"]
      @urls = attributes["urls"]
    end
  end
end
@@ -0,0 +1,5 @@
1
+ # frozen_string_literal: true
2
+
3
module ReplicateClient
  # The released version of the replicate-client gem.
  VERSION = "0.1.0"
end
@@ -0,0 +1,78 @@
1
+ # frozen_string_literal: true
2
+
3
require "faraday"
require "time"
require "uri"

require_relative "replicate-client/client"
require_relative "replicate-client/prediction"
require_relative "replicate-client/version"
require_relative "replicate-client/model"
require_relative "replicate-client/hardware"
require_relative "replicate-client/training"
require_relative "replicate-client/deployment"
13
+
14
module ReplicateClient
  # Base class for every error raised by this gem, so callers can rescue them all at once.
  class Error < StandardError; end
  class UnauthorizedError < Error; end
  class NotFoundError < Error; end
  class ServerError < Error; end
  class ConfigurationError < Error; end
  class ForbiddenError < Error; end

  # Settings used to build the API client.
  class Configuration
    DEFAULT_URI_BASE = "https://api.replicate.com/v1"
    DEFAULT_REQUEST_TIMEOUT = 120

    # The access token for the API.
    #
    # @return [String, nil]
    attr_accessor :access_token

    # The base URI for the API.
    #
    # @return [String]
    attr_accessor :uri_base

    # The request timeout in seconds.
    #
    # @return [Integer]
    attr_accessor :request_timeout

    # The default URL to send webhook events to.
    #
    # @return [String, nil]
    attr_accessor :webhook_url

    # Initialize the configuration with defaults.
    def initialize
      @access_token = nil
      @webhook_url = nil
      @uri_base = DEFAULT_URI_BASE
      @request_timeout = DEFAULT_REQUEST_TIMEOUT
    end
  end

  class << self
    # The active configuration. It is created with defaults on first access, so callers never
    # see nil even if {configure} was never called.
    #
    # @return [ReplicateClient::Configuration]
    def configuration
      @configuration ||= Configuration.new
    end

    # Replace the configuration. The memoized client is discarded so the next call to
    # {client} is built from the new settings.
    #
    # @param config [ReplicateClient::Configuration]
    def configuration=(config)
      @configuration = config
      @client = nil
    end

    # Configure the client.
    #
    # @yield [ReplicateClient::Configuration] The configuration to mutate.
    #
    # @return [ReplicateClient::Configuration]
    def configure
      yield(configuration) if block_given?
      # A previously built client may have captured settings that just changed.
      @client = nil
      configuration
    end

    # The client for the API, built lazily from the current configuration.
    #
    # @return [ReplicateClient::Client]
    def client
      @client ||= Client.new(configuration)
    end
  end
end
metadata ADDED
@@ -0,0 +1,76 @@
1
+ --- !ruby/object:Gem::Specification
2
+ name: replicate-client
3
+ version: !ruby/object:Gem::Version
4
+ version: 0.1.0
5
+ platform: ruby
6
+ authors:
7
+ - Dylan Player
8
+ autorequire:
9
+ bindir: exe
10
+ cert_chain: []
11
+ date: 2024-08-20 00:00:00.000000000 Z
12
+ dependencies:
13
+ - !ruby/object:Gem::Dependency
14
+ name: faraday
15
+ requirement: !ruby/object:Gem::Requirement
16
+ requirements:
17
+ - - ">="
18
+ - !ruby/object:Gem::Version
19
+ version: '1'
20
+ type: :runtime
21
+ prerelease: false
22
+ version_requirements: !ruby/object:Gem::Requirement
23
+ requirements:
24
+ - - ">="
25
+ - !ruby/object:Gem::Version
26
+ version: '1'
27
+ description:
28
+ email:
29
+ - dylan@851.sh
30
+ executables: []
31
+ extensions: []
32
+ extra_rdoc_files: []
33
+ files:
34
+ - ".rubocop.yml"
35
+ - ".ruby-version"
36
+ - Gemfile
37
+ - Gemfile.lock
38
+ - README.md
39
+ - Rakefile
40
+ - lib/replicate-client.rb
41
+ - lib/replicate-client/client.rb
42
+ - lib/replicate-client/deployment.rb
43
+ - lib/replicate-client/hardware.rb
44
+ - lib/replicate-client/model.rb
45
+ - lib/replicate-client/model/version.rb
46
+ - lib/replicate-client/prediction.rb
47
+ - lib/replicate-client/training.rb
48
+ - lib/replicate-client/version.rb
49
+ homepage: https://github.com/851-labs/replicate
50
+ licenses:
51
+ - MIT
52
+ metadata:
53
+ allowed_push_host: https://rubygems.org
54
+ homepage_uri: https://github.com/851-labs/replicate
55
+ source_code_uri: https://github.com/851-labs/replicate
56
+ rubygems_mfa_required: 'true'
57
+ post_install_message:
58
+ rdoc_options: []
59
+ require_paths:
60
+ - lib
61
+ required_ruby_version: !ruby/object:Gem::Requirement
62
+ requirements:
63
+ - - ">="
64
+ - !ruby/object:Gem::Version
65
+ version: 3.3.0
66
+ required_rubygems_version: !ruby/object:Gem::Requirement
67
+ requirements:
68
+ - - ">="
69
+ - !ruby/object:Gem::Version
70
+ version: '0'
71
+ requirements: []
72
+ rubygems_version: 3.5.3
73
+ signing_key:
74
+ specification_version: 4
75
+ summary: Ruby client for Replicate API.
76
+ test_files: []