replicate-client 0.1.9 → 0.1.10

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
checksums.yaml CHANGED
@@ -1,7 +1,7 @@
1
1
  ---
2
2
  SHA256:
3
- metadata.gz: e9834788e800bcbdd19fe6e813b8c19c30b4d62ca0fe5035fe05d32b95126ac7
4
- data.tar.gz: da4c94e22f0f3e884a690ab7c4d4701c812b17ec304efd2e8e956ea557aebf9f
3
+ metadata.gz: fad1f0756031bf4b2d7847c59836b4fcc570ed2354ccd8e771056f2f3d26fa0e
4
+ data.tar.gz: e8c163b483dba74ce529a8a7af809807706fcdd45d7ae3edaec37124b7259765
5
5
  SHA512:
6
- metadata.gz: 60b067804aae73acf3bdf193936fa504bc748f023b3a55d2a168fc3014fd8a920b3b6a276ef2f5fa2c76d51c57ad2354cb324bd2e6495d7ba95b7d7283ab39e2
7
- data.tar.gz: e8bba1499966efe2ec8deb92bcf9cc352368586b39f06309a151dbfb18d3f6de85ed69c36be86340d18b820989f36efde8666cc2d14718026713cd9c00fc9437
6
+ metadata.gz: f4292065e85be45396415204d7898b76a4aecc498e44b05354178ef59103adbffc27c453429dd976371ca9c17893b443b119d3faa672a518cdf853e1438a84f6
7
+ data.tar.gz: 55194df7b411c76e99ff369e4186fe3ed6d8cd25c79a1b2d6a63f53319c7baafde517057e61e7d643f6795679b1ca57e1490fff7dd63fecd1266d53d73493606
data/Gemfile.lock CHANGED
@@ -1,7 +1,7 @@
1
1
  PATH
2
2
  remote: .
3
3
  specs:
4
- replicate-client (0.1.8)
4
+ replicate-client (0.1.10)
5
5
  faraday (>= 1)
6
6
 
7
7
  GEM
data/README.md CHANGED
@@ -1,16 +1,14 @@
1
1
  # ReplicateClient
2
2
 
3
- **🚧 This gem is still under development 🚧**
4
-
5
3
  ## Installation
6
4
 
7
5
  Install the gem and add to the application"s Gemfile by executing:
8
6
 
9
- $ bundle add replicate
7
+ $ bundle add replicate-client
10
8
 
11
9
  If bundler is not being used to manage dependencies, install the gem by executing:
12
10
 
13
- $ gem install replicate
11
+ $ gem install replicate-client
14
12
 
15
13
  ## Usage
16
14
 
@@ -199,4 +197,3 @@ Official models will not have vesions. The version id will be nil.
199
197
  After checking out the repo, run `bin/setup` to install dependencies. Then, run `rake test` to run the tests. You can also run `bin/console` for an interactive prompt that will allow you to experiment.
200
198
 
201
199
  To install this gem onto your local machine, run `bundle exec rake install`. To release a new version, update the version number in `version.rb`, and then run `bundle exec rake release`, which will create a git tag for the version, push git commits and the created tag, and push the `.gem` file to [rubygems.org](https://rubygems.org).
202
-
data/Rakefile CHANGED
@@ -6,7 +6,7 @@ require "rake/testtask"
6
6
  Rake::TestTask.new(:test) do |t|
7
7
  t.libs << "test"
8
8
  t.libs << "lib"
9
- t.test_files = FileList["test/**/test_*.rb"]
9
+ t.test_files = FileList["test/**/*_test.rb"]
10
10
  end
11
11
 
12
12
  require "rubocop/rake_task"
@@ -2,6 +2,11 @@
2
2
 
3
3
  module ReplicateClient
4
4
  class Client
5
+ # The configuration for the client.
6
+ #
7
+ # @return [ReplicateClient::Configuration]
8
+ attr_accessor :configuration
9
+
5
10
  # Initialize the client.
6
11
  #
7
12
  # @param configuration [ReplicateClient::Configuration] The configuration for the client.
@@ -7,17 +7,19 @@ module ReplicateClient
7
7
  class << self
8
8
  # List all deployments.
9
9
  #
10
+ # @param client [ReplicateClient::Client] The client to use for requests.
11
+ #
10
12
  # @yield [ReplicateClient::Deployment] Yields a deployment.
11
13
  #
12
14
  # @return [void]
13
- def auto_paging_each(&block)
15
+ def auto_paging_each(client: ReplicateClient.client, &block)
14
16
  cursor = nil
15
17
 
16
18
  loop do
17
19
  url_params = cursor ? "?cursor=#{cursor}" : ""
18
- attributes = ReplicateClient.client.get("#{INDEX_PATH}#{url_params}")
20
+ attributes = client.get("#{INDEX_PATH}#{url_params}")
19
21
 
20
- deployments = attributes["results"].map { |deployment| new(deployment) }
22
+ deployments = attributes["results"].map { |deployment| new(deployment, client: client) }
21
23
 
22
24
  deployments.each(&block)
23
25
 
@@ -34,9 +36,11 @@ module ReplicateClient
34
36
  # @param hardware [ReplicateClient::Hardware, String] The hardware SKU.
35
37
  # @param min_instances [Integer] The minimum number of instances.
36
38
  # @param max_instances [Integer] The maximum number of instances.
39
+ # @param client [ReplicateClient::Client] The client to use for requests.
37
40
  #
38
41
  # @return [ReplicateClient::Deployment]
39
- def create!(name:, model:, hardware:, min_instances:, max_instances:, version_id: nil)
42
+ def create!(name:, model:, hardware:, min_instances:, max_instances:, version_id: nil,
43
+ client: ReplicateClient.client)
40
44
  model_full_name = model.is_a?(Model) ? model.full_name : model
41
45
  hardware_sku = hardware.is_a?(Hardware) ? hardware.sku : hardware
42
46
  version = if version_id
@@ -44,7 +48,7 @@ module ReplicateClient
44
48
  elsif model.is_a?(Model)
45
49
  model.version_id
46
50
  else
47
- Model.find(model).latest_version.id
51
+ Model.find(model, client: client).latest_version.id
48
52
  end
49
53
 
50
54
  body = {
@@ -56,41 +60,44 @@ module ReplicateClient
56
60
  max_instances: max_instances
57
61
  }
58
62
 
59
- attributes = ReplicateClient.client.post(INDEX_PATH, body)
60
- new(attributes)
63
+ attributes = client.post(INDEX_PATH, body)
64
+ new(attributes, client: client)
61
65
  end
62
66
 
63
67
  # Find a deployment by owner and name.
64
68
  #
65
69
  # @param full_name [String] The full name of the deployment in "owner/name" format.
70
+ # @param client [ReplicateClient::Client] The client to use for requests.
66
71
  #
67
72
  # @return [ReplicateClient::Deployment]
68
- def find(full_name)
73
+ def find(full_name, client: ReplicateClient.client)
69
74
  path = build_path(**parse_full_name(full_name))
70
- attributes = ReplicateClient.client.get(path)
71
- new(attributes)
75
+ attributes = client.get(path)
76
+ new(attributes, client: client)
72
77
  end
73
78
 
74
79
  # Find a deployment by owner and name.
75
80
  #
76
81
  # @param owner [String] The owner of the deployment.
77
82
  # @param name [String] The name of the deployment.
83
+ # @param client [ReplicateClient::Client] The client to use for requests.
78
84
  #
79
85
  # @return [ReplicateClient::Deployment]
80
- def find_by!(owner:, name:)
86
+ def find_by!(owner:, name:, client: ReplicateClient.client)
81
87
  path = build_path(owner: owner, name: name)
82
- attributes = ReplicateClient.client.get(path)
83
- new(attributes)
88
+ attributes = client.get(path)
89
+ new(attributes, client: client)
84
90
  end
85
91
 
86
92
  # Find a deployment by owner and name.
87
93
  #
88
94
  # @param owner [String] The owner of the deployment.
89
95
  # @param name [String] The name of the deployment.
96
+ # @param client [ReplicateClient::Client] The client to use for requests.
90
97
  #
91
98
  # @return [ReplicateClient::Deployment, nil]
92
- def find_by(owner:, name:)
93
- find_by!(owner: owner, name: name)
99
+ def find_by(owner:, name:, client: ReplicateClient.client)
100
+ find_by!(owner: owner, name: name, client: client)
94
101
  rescue ReplicateClient::NotFoundError
95
102
  nil
96
103
  end
@@ -99,11 +106,12 @@ module ReplicateClient
99
106
  #
100
107
  # @param owner [String] The owner of the deployment.
101
108
  # @param name [String] The name of the deployment.
109
+ # @param client [ReplicateClient::Client] The client to use for requests.
102
110
  #
103
111
  # @return [void]
104
- def destroy!(owner:, name:)
112
+ def destroy!(owner:, name:, client: ReplicateClient.client)
105
113
  path = build_path(owner: owner, name: name)
106
- ReplicateClient.client.delete(path)
114
+ client.delete(path)
107
115
  end
108
116
 
109
117
  # Build the path for a specific deployment.
@@ -130,12 +138,19 @@ module ReplicateClient
130
138
  # Attributes for deployment.
131
139
  attr_accessor :owner, :name, :current_release
132
140
 
141
+ # The client used to make API requests for this deployment.
142
+ #
143
+ # @return [ReplicateClient::Client]
144
+ attr_accessor :client
145
+
133
146
  # Initialize a new deployment instance.
134
147
  #
135
148
  # @param attributes [Hash] The attributes of the deployment.
149
+ # @param client [ReplicateClient::Client] The client to use for requests.
136
150
  #
137
151
  # @return [ReplicateClient::Deployment]
138
- def initialize(attributes)
152
+ def initialize(attributes, client: ReplicateClient.client)
153
+ @client = client
139
154
  reset_attributes(attributes)
140
155
  end
141
156
 
@@ -143,7 +158,7 @@ module ReplicateClient
143
158
  #
144
159
  # @return [void]
145
160
  def destroy!
146
- self.class.destroy!(owner: owner, name: name)
161
+ self.class.destroy!(owner: owner, name: name, client: @client)
147
162
  end
148
163
 
149
164
  # Update the deployment.
@@ -155,8 +170,8 @@ module ReplicateClient
155
170
  #
156
171
  # @return [void]
157
172
  def update!(hardware: nil, min_instances: nil, max_instances: nil, version: nil)
158
- version_id = version.is_a?(Version) ? version.id : version
159
- path = build_path(owner: owner, name: name)
173
+ version_id = version.is_a?(ReplicateClient::Model::Version) ? version.id : version
174
+ path = self.class.build_path(owner: owner, name: name)
160
175
  body = {
161
176
  hardware: hardware,
162
177
  min_instances: min_instances,
@@ -164,7 +179,7 @@ module ReplicateClient
164
179
  version: version_id
165
180
  }.compact
166
181
 
167
- attributes = ReplicateClient.client.patch(path, body)
182
+ attributes = @client.patch(path, body)
168
183
  reset_attributes(attributes)
169
184
  end
170
185
 
@@ -172,7 +187,7 @@ module ReplicateClient
172
187
  #
173
188
  # @return [void]
174
189
  def reload!
175
- attributes = ReplicateClient.client.get(path)
190
+ attributes = @client.get(path)
176
191
  reset_attributes(attributes)
177
192
  end
178
193
 
@@ -195,7 +210,8 @@ module ReplicateClient
195
210
  deployment: self,
196
211
  input: input,
197
212
  webhook_url: webhook_url,
198
- webhook_events_filter: webhook_events_filter
213
+ webhook_events_filter: webhook_events_filter,
214
+ client: @client
199
215
  )
200
216
  end
201
217
 
@@ -7,19 +7,22 @@ module ReplicateClient
7
7
  class << self
8
8
  # List all available hardware.
9
9
  #
10
+ # @param client [ReplicateClient::Client] The client to use for requests.
11
+ #
10
12
  # @return [Array<ReplicateClient::Hardware>]
11
- def all
12
- response = ReplicateClient.client.get(INDEX_PATH)
13
+ def all(client: ReplicateClient.client)
14
+ response = client.get(INDEX_PATH)
13
15
  response.map { |attributes| new(attributes) }
14
16
  end
15
17
 
16
18
  # Find hardware by SKU.
17
19
  #
18
20
  # @param sku [String] The SKU of the hardware.
21
+ # @param client [ReplicateClient::Client] The client to use for requests.
19
22
  #
20
23
  # @return [ReplicateClient::Hardware, nil]
21
- def find_by(sku:)
22
- all.find { |hardware| hardware.sku == sku }
24
+ def find_by(sku:, client: ReplicateClient.client)
25
+ all(client: client).find { |hardware| hardware.sku == sku }
23
26
  end
24
27
  end
25
28
 
@@ -11,12 +11,13 @@ module ReplicateClient
11
11
  # @param owner [String] The owner of the model.
12
12
  # @param name [String] The name of the model.
13
13
  # @param version_id [String] The version id of the model.
14
+ # @param client [ReplicateClient::Client] The client to use for requests.
14
15
  #
15
16
  # @return [ReplicateClient::Model::Version]
16
- def find_by!(owner:, name:, version_id:)
17
+ def find_by!(owner:, name:, version_id:, client: ReplicateClient.client)
17
18
  path = build_path(owner: owner, name: name, version_id: version_id)
18
- response = ReplicateClient.client.get(path)
19
- new(response)
19
+ response = client.get(path)
20
+ new(response, client: client)
20
21
  end
21
22
 
22
23
  # Find a version of a model.
@@ -24,10 +25,11 @@ module ReplicateClient
24
25
  # @param owner [String] The owner of the model.
25
26
  # @param name [String] The name of the model.
26
27
  # @param version_id [String] The version id of the model.
28
+ # @param client [ReplicateClient::Client] The client to use for requests.
27
29
  #
28
30
  # @return [ReplicateClient::Model::Version]
29
- def find_by(owner:, name:, version_id:)
30
- find_by!(owner: owner, name: name, version_id: version_id)
31
+ def find_by(owner:, name:, version_id:, client: ReplicateClient.client)
32
+ find_by!(owner: owner, name: name, version_id: version_id, client: client)
31
33
  rescue ReplicateClient::NotFoundError
32
34
  nil
33
35
  end
@@ -36,12 +38,13 @@ module ReplicateClient
36
38
  #
37
39
  # @param owner [String] The owner of the model.
38
40
  # @param name [String] The name of the model.
41
+ # @param client [ReplicateClient::Client] The client to use for requests.
39
42
  #
40
43
  # @return [Array<ReplicateClient::Model::Version>]
41
- def where(owner:, name:)
44
+ def where(owner:, name:, client: ReplicateClient.client)
42
45
  versions = []
43
46
 
44
- auto_paging_each(owner: owner, name: name) do |version|
47
+ auto_paging_each(owner: owner, name: name, client: client) do |version|
45
48
  versions << version
46
49
  end
47
50
 
@@ -52,18 +55,19 @@ module ReplicateClient
52
55
  #
53
56
  # @param name [String] The name of the model.
54
57
  # @param owner [String] The owner of the model.
58
+ # @param client [ReplicateClient::Client] The client to use for requests.
55
59
  # @yield [ReplicateClient::Model] Yields a model.
56
60
  #
57
61
  # @return [void]
58
- def auto_paging_each(owner:, name:, &block)
62
+ def auto_paging_each(owner:, name:, client: ReplicateClient.client, &block)
59
63
  cursor = nil
60
64
  model_path = Model.build_path(owner: owner, name: name)
61
65
 
62
66
  loop do
63
67
  url_params = cursor ? "?cursor=#{cursor}" : ""
64
- attributes = ReplicateClient.client.get("#{model_path}#{INDEX_PATH}#{url_params}")
68
+ attributes = client.get("#{model_path}#{INDEX_PATH}#{url_params}")
65
69
 
66
- versions = attributes["results"].map { |version| new(version) }
70
+ versions = attributes["results"].map { |version| new(version, client: client) }
67
71
 
68
72
  versions.each(&block)
69
73
 
@@ -105,7 +109,13 @@ module ReplicateClient
105
109
  # @return [Hash]
106
110
  attr_accessor :openapi_schema
107
111
 
108
- def initialize(attributes)
112
+ # The client used to make API requests for this model version.
113
+ #
114
+ # @return [ReplicateClient::Client]
115
+ attr_accessor :client
116
+
117
+ def initialize(attributes, client: ReplicateClient.client)
118
+ @client = client
109
119
  @id = attributes["id"]
110
120
  @created_at = Time.parse(attributes["created_at"])
111
121
  @cog_version = attributes["cog_version"]
@@ -124,7 +134,8 @@ module ReplicateClient
124
134
  version: self,
125
135
  input: input,
126
136
  webhook_url: webhook_url,
127
- webhook_events_filter: webhook_events_filter
137
+ webhook_events_filter: webhook_events_filter,
138
+ client: @client
128
139
  )
129
140
  end
130
141
 
@@ -16,12 +16,13 @@ module ReplicateClient
16
16
  #
17
17
  # @param owner [String] The owner of the model.
18
18
  # @param name [String] The name of the model.
19
+ # @param client [ReplicateClient::Client] The client to use for requests.
19
20
  #
20
21
  # @return [ReplicateClient::Model]
21
- def find_by!(owner:, name:, version_id: nil)
22
+ def find_by!(owner:, name:, version_id: nil, client: ReplicateClient.client)
22
23
  path = build_path(owner: owner, name: name)
23
- response = ReplicateClient.client.get(path)
24
- new(response, version_id: version_id)
24
+ response = client.get(path)
25
+ new(response, version_id: version_id, client: client)
25
26
  end
26
27
 
27
28
  # Find a model.
@@ -29,10 +30,11 @@ module ReplicateClient
29
30
  # @param owner [String] The owner of the model.
30
31
  # @param name [String] The name of the model.
31
32
  # @param version_id [String] The version id of the model to use.
33
+ # @param client [ReplicateClient::Client] The client to use for requests.
32
34
  #
33
35
  # @return [ReplicateClient::Model]
34
- def find_by(owner:, name:, version_id: nil)
35
- find_by!(owner: owner, name: name, version_id: version_id)
36
+ def find_by(owner:, name:, version_id: nil, client: ReplicateClient.client)
37
+ find_by!(owner: owner, name: name, version_id: version_id, client: client)
36
38
  rescue ReplicateClient::NotFoundError
37
39
  nil
38
40
  end
@@ -42,10 +44,11 @@ module ReplicateClient
42
44
  #
43
45
  # @param name [String] The name of the model.
44
46
  # @param version_id [String] The version id of the model to use.
47
+ # @param client [ReplicateClient::Client] The client to use for requests.
45
48
  #
46
49
  # @return [ReplicateClient::Model]
47
- def find(name, version_id: nil)
48
- find_by!(**parse_model_name(name), version_id: version_id)
50
+ def find(name, version_id: nil, client: ReplicateClient.client)
51
+ find_by!(**parse_model_name(name), version_id: version_id, client: client)
49
52
  end
50
53
 
51
54
  # Build the path for the model.
@@ -60,17 +63,19 @@ module ReplicateClient
60
63
 
61
64
  # Paginate through all models.
62
65
  #
66
+ # @param client [ReplicateClient::Client] The client to use for requests.
67
+ #
63
68
  # @yield [ReplicateClient::Model] Yields a model.
64
69
  #
65
70
  # @return [void]
66
- def auto_paging_each(&block)
71
+ def auto_paging_each(client: ReplicateClient.client, &block)
67
72
  cursor = nil
68
73
 
69
74
  loop do
70
75
  url_params = cursor ? "?cursor=#{cursor}" : ""
71
- attributes = ReplicateClient.client.get("#{INDEX_PATH}#{url_params}")
76
+ attributes = client.get("#{INDEX_PATH}#{url_params}")
72
77
 
73
- models = attributes["results"].map { |model| new(model) }
78
+ models = attributes["results"].map { |model| new(model, client: client) }
74
79
 
75
80
  models.each(&block)
76
81
 
@@ -90,6 +95,7 @@ module ReplicateClient
90
95
  # @param paper_url [String, nil] A URL for the model’s paper.
91
96
  # @param license_url [String, nil] A URL for the model’s license.
92
97
  # @param cover_image_url [String, nil] A URL for the model’s cover image.
98
+ # @param client [ReplicateClient::Client] The client to use for requests.
93
99
  #
94
100
  # @return [ReplicateClient::Model]
95
101
  def create!(
@@ -101,7 +107,8 @@ module ReplicateClient
101
107
  github_url: nil,
102
108
  paper_url: nil,
103
109
  license_url: nil,
104
- cover_image_url: nil
110
+ cover_image_url: nil,
111
+ client: ReplicateClient.client
105
112
  )
106
113
  new_attributes = {
107
114
  owner: owner,
@@ -115,9 +122,9 @@ module ReplicateClient
115
122
  cover_image_url: cover_image_url
116
123
  }
117
124
 
118
- attributes = ReplicateClient.client.post("/models", new_attributes)
125
+ attributes = client.post("/models", new_attributes)
119
126
 
120
- new(attributes)
127
+ new(attributes, client: client)
121
128
  end
122
129
 
123
130
  # Parse the model name.
@@ -201,13 +208,20 @@ module ReplicateClient
201
208
  # @return [Hash]
202
209
  attr_accessor :latest_version_id
203
210
 
211
+ # The client used to make API requests for this model.
212
+ #
213
+ # @return [ReplicateClient::Client]
214
+ attr_accessor :client
215
+
204
216
  # Initialize a new model.
205
217
  #
206
218
  # @param attributes [Hash] The attributes of the model.
207
219
  # @param version_id [String] The version of the model to use.
220
+ # @param client [ReplicateClient::Client] The client to use for requests.
208
221
  #
209
222
  # @return [ReplicateClient::Model]
210
- def initialize(attributes, version_id: nil)
223
+ def initialize(attributes, version_id: nil, client: ReplicateClient.client)
224
+ @client = client
211
225
  reset_attributes(attributes, version_id: version_id)
212
226
  end
213
227
 
@@ -222,7 +236,7 @@ module ReplicateClient
222
236
  #
223
237
  # @return [void]
224
238
  def destroy!
225
- ReplicateClient.client.delete(path)
239
+ @client.delete(path)
226
240
  end
227
241
 
228
242
  # The path of the current version.
@@ -236,21 +250,21 @@ module ReplicateClient
236
250
  #
237
251
  # @return [ReplicateClient::Model::Version]
238
252
  def version
239
- @version ||= Version.find_by!(owner: owner, name: name, version_id: version_id)
253
+ @version ||= Version.find_by!(owner: owner, name: name, version_id: version_id, client: @client)
240
254
  end
241
255
 
242
256
  # The latest version of the model.
243
257
  #
244
258
  # @return [ReplicateClient::Model::Version]
245
259
  def latest_version
246
- @latest_version ||= Version.find_by!(owner: owner, name: name, version_id: latest_version_id)
260
+ @latest_version ||= Version.find_by!(owner: owner, name: name, version_id: latest_version_id, client: @client)
247
261
  end
248
262
 
249
263
  # The versions of the model.
250
264
  #
251
265
  # @return [Array<ReplicateClient::Model::Version>]
252
266
  def versions
253
- @versions ||= Version.where(owner: owner, name: name)
267
+ @versions ||= Version.where(owner: owner, name: name, client: @client)
254
268
  end
255
269
 
256
270
  # Create a new prediction for the model.
@@ -264,14 +278,16 @@ module ReplicateClient
264
278
  model: self,
265
279
  input: input,
266
280
  webhook_url: webhook_url,
267
- webhook_events_filter: webhook_events_filter
281
+ webhook_events_filter: webhook_events_filter,
282
+ client: @client
268
283
  )
269
284
  else
270
285
  Prediction.create!(
271
286
  version: version_id,
272
287
  input: input,
273
288
  webhook_url: webhook_url,
274
- webhook_events_filter: webhook_events_filter
289
+ webhook_events_filter: webhook_events_filter,
290
+ client: @client
275
291
  )
276
292
  end
277
293
  end
@@ -280,7 +296,7 @@ module ReplicateClient
280
296
  #
281
297
  # @return [void]
282
298
  def reload!
283
- attributes = ReplicateClient.client.get(path)
299
+ attributes = @client.get(path)
284
300
  reset_attributes(attributes, version_id: version_id)
285
301
  end
286
302
 
@@ -20,21 +20,23 @@ module ReplicateClient
20
20
  # @param webhook_url [String] The URL to send webhook events to.
21
21
  # @param webhook_events_filter [Array<Symbol>] The events to send to the webhook.
22
22
  # @param sync [Boolean] Whether to wait for the prediction to complete.
23
+ # @param client [ReplicateClient::Client] The client to use for the prediction.
23
24
  #
24
25
  # @return [ReplicateClient::Prediction]
25
- def create!(version:, input:, webhook_url: nil, webhook_events_filter: nil, sync: false)
26
+ def create!(version:, input:, webhook_url: nil, webhook_events_filter: nil, sync: false,
27
+ client: ReplicateClient.client)
26
28
  args = {
27
29
  version: version.is_a?(Model::Version) ? version.id : version,
28
30
  input: input,
29
- webhook: webhook_url || ReplicateClient.configuration.webhook_url,
31
+ webhook: webhook_url || client.configuration.webhook_url,
30
32
  webhook_events_filter: webhook_events_filter&.map(&:to_s)
31
33
  }
32
34
 
33
35
  headers = sync ? { "Prefer" => "wait" } : {}
34
36
 
35
- prediction = ReplicateClient.client.post(INDEX_PATH, args, headers:)
37
+ prediction = client.post(INDEX_PATH, args, headers:)
36
38
 
37
- new(prediction)
39
+ new(prediction, client: client)
38
40
  end
39
41
 
40
42
  # Create a new prediction for a deployment.
@@ -44,12 +46,14 @@ module ReplicateClient
44
46
  # @param webhook_url [String] The URL to send webhook events to.
45
47
  # @param webhook_events_filter [Array<Symbol>] The events to send to the webhook.
46
48
  # @param sync [Boolean] Whether to wait for the prediction to complete.
49
+ # @param client [ReplicateClient::Client] The client to use for the prediction.
47
50
  #
48
51
  # @return [ReplicateClient::Prediction]
49
- def create_for_deployment!(deployment:, input:, webhook_url: nil, webhook_events_filter: nil, sync: false)
52
+ def create_for_deployment!(deployment:, input:, webhook_url: nil, webhook_events_filter: nil, sync: false,
53
+ client: ReplicateClient.client)
50
54
  args = {
51
55
  input: input,
52
- webhook: webhook_url || ReplicateClient.configuration.webhook_url,
56
+ webhook: webhook_url || client.configuration.webhook_url,
53
57
  webhook_events_filter: webhook_events_filter&.map(&:to_s)
54
58
  }
55
59
 
@@ -57,9 +61,9 @@ module ReplicateClient
57
61
 
58
62
  deployment_path = deployment.is_a?(Deployment) ? deployment.path : "#{Deployment::INDEX_PATH}/#{deployment}"
59
63
 
60
- prediction = ReplicateClient.client.post("#{deployment_path}#{INDEX_PATH}", args, headers:)
64
+ prediction = client.post("#{deployment_path}#{INDEX_PATH}", args, headers:)
61
65
 
62
- new(prediction)
66
+ new(prediction, client: client)
63
67
  end
64
68
 
65
69
  # Create a new prediction for a model.
@@ -69,50 +73,55 @@ module ReplicateClient
69
73
  # @param webhook_url [String] The URL to send webhook events to.
70
74
  # @param webhook_events_filter [Array<Symbol>] The events to send to the webhook.
71
75
  # @param sync [Boolean] Whether to wait for the prediction to complete.
76
+ # @param client [ReplicateClient::Client] The client to use for the prediction.
72
77
  #
73
78
  # @return [ReplicateClient::Prediction]
74
- def create_for_official_model!(model:, input:, webhook_url: nil, webhook_events_filter: nil, sync: false)
79
+ def create_for_official_model!(model:, input:, webhook_url: nil, webhook_events_filter: nil, sync: false,
80
+ client: ReplicateClient.client)
75
81
  model_path = model.is_a?(Model) ? model.path : Model.build_path(**Model.parse_model_name(model))
76
82
 
77
83
  args = {
78
84
  input: input,
79
- webhook: webhook_url || ReplicateClient.configuration.webhook_url,
85
+ webhook: webhook_url || client.configuration.webhook_url,
80
86
  webhook_events_filter: webhook_events_filter&.map(&:to_s)
81
87
  }
82
88
 
83
89
  headers = sync ? { "Prefer" => "wait" } : {}
84
90
 
85
- prediction = ReplicateClient.client.post("#{model_path}#{INDEX_PATH}", args, headers:)
91
+ prediction = client.post("#{model_path}#{INDEX_PATH}", args, headers:)
86
92
 
87
- new(prediction)
93
+ new(prediction, client: client)
88
94
  end
89
95
 
90
96
  # Find a prediction.
91
97
  #
92
98
  # @param id [String] The ID of the prediction.
99
+ # @param client [ReplicateClient::Client] The client to use for the prediction.
93
100
  #
94
101
  # @return [ReplicateClient::Prediction]
95
- def find(id)
96
- attributes = ReplicateClient.client.get(build_path(id))
97
- new(attributes)
102
+ def find(id, client: ReplicateClient.client)
103
+ attributes = client.get(build_path(id))
104
+ new(attributes, client: client)
98
105
  end
99
106
 
100
107
  # Find a prediction.
101
108
  #
102
109
  # @param id [String] The ID of the prediction.
110
+ # @param client [ReplicateClient::Client] The client to use for the prediction.
103
111
  #
104
112
  # @return [ReplicateClient::Prediction]
105
- def find_by!(id:)
106
- find(id)
113
+ def find_by!(id:, client: ReplicateClient.client)
114
+ find(id, client: client)
107
115
  end
108
116
 
109
117
  # Find a prediction.
110
118
  #
111
119
  # @param id [String] The ID of the prediction.
120
+ # @param client [ReplicateClient::Client] The client to use for the prediction.
112
121
  #
113
122
  # @return [ReplicateClient::Prediction]
114
- def find_by(id:)
115
- find_by!(id: id)
123
+ def find_by(id:, client: ReplicateClient.client)
124
+ find_by!(id: id, client: client)
116
125
  rescue ReplicateClient::NotFoundError
117
126
  nil
118
127
  end
@@ -129,10 +138,11 @@ module ReplicateClient
129
138
  # Cancel a prediction.
130
139
  #
131
140
  # @param id [String] The ID of the prediction.
141
+ # @param client [ReplicateClient::Client] The client to use for the prediction.
132
142
  #
133
143
  # @return [void]
134
- def cancel!(id)
135
- ReplicateClient.client.post("#{build_path(id)}/cancel")
144
+ def cancel!(id, client: ReplicateClient.client)
145
+ client.post("#{build_path(id)}/cancel")
136
146
  end
137
147
  end
138
148
 
@@ -206,7 +216,13 @@ module ReplicateClient
206
216
  # @return [String]
207
217
  attr_accessor :logs
208
218
 
209
- def initialize(attributes)
219
+ # The client for the prediction.
220
+ #
221
+ # @return [ReplicateClient::Client]
222
+ attr_accessor :client
223
+
224
+ def initialize(attributes, client: ReplicateClient.client)
225
+ @client = client
210
226
  reset_attributes(attributes)
211
227
  end
212
228
 
@@ -214,7 +230,7 @@ module ReplicateClient
214
230
  #
215
231
  # @return [ReplicateClient::Prediction]
216
232
  def reload!
217
- attributes = ReplicateClient.client.get(Prediction.build_path(@id))
233
+ attributes = @client.get(Prediction.build_path(@id))
218
234
  reset_attributes(attributes)
219
235
  end
220
236
 
@@ -236,7 +252,7 @@ module ReplicateClient
236
252
  #
237
253
  # @return [void]
238
254
  def cancel!
239
- Prediction.cancel!(id)
255
+ Prediction.cancel!(id, client: @client)
240
256
  end
241
257
 
242
258
  # Check if the prediction is succeeded.
@@ -15,17 +15,19 @@ module ReplicateClient
15
15
  class << self
16
16
  # List all trainings.
17
17
  #
18
+ # @param client [ReplicateClient::Client] The client to use for requests.
19
+ #
18
20
  # @yield [ReplicateClient::Training] Yields a training.
19
21
  #
20
22
  # @return [void]
21
- def auto_paging_each(&block)
23
+ def auto_paging_each(client: ReplicateClient.client, &block)
22
24
  cursor = nil
23
25
 
24
26
  loop do
25
27
  url_params = cursor ? "?cursor=#{cursor}" : ""
26
- attributes = ReplicateClient.client.get("#{INDEX_PATH}#{url_params}")
28
+ attributes = client.get("#{INDEX_PATH}#{url_params}")
27
29
 
28
- trainings = attributes["results"].map { |training| new(training) }
30
+ trainings = attributes["results"].map { |training| new(training, client: client) }
29
31
 
30
32
  trainings.each(&block)
31
33
 
@@ -44,9 +46,11 @@ module ReplicateClient
44
46
  # @param input [Hash] The input data for the training.
45
47
  # @param webhook_url [String, nil] A URL to receive webhook notifications.
46
48
  # @param webhook_events_filter [Array, nil] The events to trigger webhook requests.
49
+ # @param client [ReplicateClient::Client] The client to use for requests.
47
50
  #
48
51
  # @return [ReplicateClient::Training]
49
- def create!(owner:, name:, version:, destination:, input:, webhook_url: nil, webhook_events_filter: nil)
52
+ def create!(owner:, name:, version:, destination:, input:, webhook_url: nil, webhook_events_filter: nil,
53
+ client: ReplicateClient.client)
50
54
  destination_str = destination.is_a?(ReplicateClient::Model) ? destination.full_name : destination
51
55
  version_id = version.is_a?(ReplicateClient::Model::Version) ? version.id : version
52
56
 
@@ -54,12 +58,12 @@ module ReplicateClient
54
58
  body = {
55
59
  destination: destination_str,
56
60
  input: input,
57
- webhook: webhook_url || ReplicateClient.configuration.webhook_url,
61
+ webhook: webhook_url || client.configuration.webhook_url,
58
62
  webhook_events_filter: webhook_events_filter
59
63
  }
60
64
 
61
- attributes = ReplicateClient.client.post(path, body)
62
- new(attributes)
65
+ attributes = client.post(path, body)
66
+ new(attributes, client: client)
63
67
  end
64
68
 
65
69
  # Create a new training for a specific model.
@@ -69,10 +73,17 @@ module ReplicateClient
69
73
  # @param input [Hash] The input data for the training.
70
74
  # @param webhook_url [String, nil] A URL to receive webhook notifications.
71
75
  # @param webhook_events_filter [Array, nil] The events to trigger webhook requests.
76
+ # @param client [ReplicateClient::Client] The client to use for requests.
72
77
  #
73
78
  # @return [ReplicateClient::Training]
74
- def create_for_model!(model:, destination:, input:, webhook_url: nil, webhook_events_filter: nil)
75
- model_instance = model.is_a?(ReplicateClient::Model) ? model : ReplicateClient::Model.find(model)
79
+ def create_for_model!(model:, destination:, input:, webhook_url: nil, webhook_events_filter: nil,
80
+ client: ReplicateClient.client)
81
+ model_instance = if model.is_a?(ReplicateClient::Model)
82
+ model
83
+ else
84
+ ReplicateClient::Model.find(model,
85
+ client: client)
86
+ end
76
87
  raise ArgumentError, "Invalid model" unless model_instance
77
88
 
78
89
  create!(
@@ -81,30 +92,33 @@ module ReplicateClient
81
92
  version: model_instance.version_id,
82
93
  destination: destination,
83
94
  input: input,
84
- webhook_url: webhook_url || ReplicateClient.configuration.webhook_url,
85
- webhook_events_filter: webhook_events_filter
95
+ webhook_url: webhook_url || client.configuration.webhook_url,
96
+ webhook_events_filter: webhook_events_filter,
97
+ client: client
86
98
  )
87
99
  end
88
100
 
89
101
  # Find a training by id.
90
102
  #
91
103
  # @param id [String] The id of the training.
104
+ # @param client [ReplicateClient::Client] The client to use for requests.
92
105
  #
93
106
  # @return [ReplicateClient::Training]
94
- def find(id)
107
+ def find(id, client: ReplicateClient.client)
95
108
  path = build_path(id: id)
96
- attributes = ReplicateClient.client.get(path)
97
- new(attributes)
109
+ attributes = client.get(path)
110
+ new(attributes, client: client)
98
111
  end
99
112
 
100
113
  # Cancel a training.
101
114
  #
102
115
  # @param id [String] The id of the training.
116
+ # @param client [ReplicateClient::Client] The client to use for requests.
103
117
  #
104
118
  # @return [void]
105
- def cancel!(id)
119
+ def cancel!(id, client: ReplicateClient.client)
106
120
  path = "#{build_path(id: id)}/cancel"
107
- ReplicateClient.client.post(path)
121
+ client.post(path, {})
108
122
  end
109
123
 
110
124
  # Build the path for a specific training.
@@ -183,12 +197,19 @@ module ReplicateClient
183
197
  # @return [Hash, nil]
184
198
  attr_accessor :metrics
185
199
 
200
+ # The client used to make API requests for this training.
201
+ #
202
+ # @return [ReplicateClient::Client]
203
+ attr_accessor :client
204
+
186
205
  # Initialize a new training instance.
187
206
  #
188
207
  # @param attributes [Hash] The attributes of the training.
208
+ # @param client [ReplicateClient::Client] The client to use for requests.
189
209
  #
190
210
  # @return [ReplicateClient::Training]
191
- def initialize(attributes)
211
+ def initialize(attributes, client: ReplicateClient.client)
212
+ @client = client
192
213
  reset_attributes(attributes)
193
214
  end
194
215
 
@@ -231,14 +252,14 @@ module ReplicateClient
231
252
  #
232
253
  # @return [void]
233
254
  def cancel!
234
- ReplicateClient::Training.cancel!(id)
255
+ ReplicateClient::Training.cancel!(id, client: @client)
235
256
  end
236
257
 
237
258
  # Reload the training.
238
259
  #
239
260
  # @return [void]
240
261
  def reload!
241
- attributes = ReplicateClient.client.get(Training.build_path(id: id))
262
+ attributes = @client.get(Training.build_path(id: id))
242
263
  reset_attributes(attributes)
243
264
  end
244
265
 
@@ -246,7 +267,7 @@ module ReplicateClient
246
267
  #
247
268
  # @return [ReplicateClient::Model]
248
269
  def model
249
- @model ||= ReplicateClient::Model.find(model_full_name, version_id: version_id)
270
+ @model ||= ReplicateClient::Model.find(model_full_name, version_id: version_id, client: @client)
250
271
  end
251
272
 
252
273
  # The version instance of the training.
@@ -1,5 +1,5 @@
1
1
  # frozen_string_literal: true
2
2
 
3
3
  module ReplicateClient
4
- VERSION = "0.1.9"
4
+ VERSION = "0.1.10"
5
5
  end
@@ -0,0 +1,31 @@
1
+ # frozen_string_literal: true
2
+
3
+ require_relative "lib/replicate-client/version"
4
+
5
+ Gem::Specification.new do |spec|
6
+ spec.name = "replicate-client"
7
+ spec.version = ReplicateClient::VERSION
8
+ spec.authors = ["Dylan Player"]
9
+ spec.email = ["dylan@851.sh"]
10
+
11
+ spec.summary = "Ruby client for Replicate API."
12
+ spec.homepage = "https://github.com/851-labs/replicate"
13
+ spec.required_ruby_version = ">= 3.3.0"
14
+
15
+ spec.metadata["allowed_push_host"] = "https://rubygems.org"
16
+ spec.metadata["homepage_uri"] = spec.homepage
17
+ spec.metadata["source_code_uri"] = spec.homepage
18
+ spec.license = "MIT"
19
+
20
+ spec.files = Dir.chdir(File.expand_path(__dir__)) do
21
+ `git ls-files -z`.split("\x0").reject do |f|
22
+ (f == __FILE__) || f.match(%r{\A(?:(?:bin|test|spec|features)/|\.(?:git|travis|circleci)|appveyor)})
23
+ end
24
+ end
25
+ spec.bindir = "exe"
26
+ spec.executables = spec.files.grep(%r{\Aexe/}) { |f| File.basename(f) }
27
+ spec.require_paths = ["lib"]
28
+
29
+ spec.add_dependency("faraday", ">= 1")
30
+ spec.metadata["rubygems_mfa_required"] = "true"
31
+ end
metadata CHANGED
@@ -1,14 +1,14 @@
1
1
  --- !ruby/object:Gem::Specification
2
2
  name: replicate-client
3
3
  version: !ruby/object:Gem::Version
4
- version: 0.1.9
4
+ version: 0.1.10
5
5
  platform: ruby
6
6
  authors:
7
7
  - Dylan Player
8
8
  autorequire:
9
9
  bindir: exe
10
10
  cert_chain: []
11
- date: 2025-01-24 00:00:00.000000000 Z
11
+ date: 2025-09-17 00:00:00.000000000 Z
12
12
  dependencies:
13
13
  - !ruby/object:Gem::Dependency
14
14
  name: faraday
@@ -48,6 +48,7 @@ files:
48
48
  - lib/replicate-client/training.rb
49
49
  - lib/replicate-client/version.rb
50
50
  - lib/replicate-client/webhook.rb
51
+ - replicate-client.gemspec
51
52
  homepage: https://github.com/851-labs/replicate
52
53
  licenses:
53
54
  - MIT