replicate-client 0.1.6 → 0.1.8

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
checksums.yaml CHANGED
@@ -1,7 +1,7 @@
1
1
  ---
2
2
  SHA256:
3
- metadata.gz: 5556c24e737090c5fb8425f5e947822e4b43dee0fbda7703d3fb1bc0f2bf4dd7
4
- data.tar.gz: 6340ba5606a808c35fa4bf9f181ea82aa85c6108906ec93aaa0c5387d6a1a11e
3
+ metadata.gz: 61acc30d7e6406c1aa6f797a45ab6927945d399f3ce49934a2ada361a818df18
4
+ data.tar.gz: 6e0023bde6ce5e0897314e3f1bde97e1b7d99b9dd1de776568a883822852a874
5
5
  SHA512:
6
- metadata.gz: c934a95922c925246279b3c59f6412c309508d926e14f1a0734c63728580db0ca0323a148fde5596563f703d2c8ee5adf08bdc89a87b4d52d4aac95e997dcf3a
7
- data.tar.gz: 47e28051e876e484f43f993dc78a9fe9f14c3ddab29241974dc35693713fce0a845bb4a2c3f744d36a6a20b6680a40230921a6f89600a2c19ba365aef907c1e8
6
+ metadata.gz: ee9a5f6e05fbf0b8af2a34c4890dae5ec925d47b5ed0c9afceb866e07c1be71a3d2c50f98c1ca569b84bb03d17abc2f665e48a6927b0243fab2180ff2f43df44
7
+ data.tar.gz: 692003bec1af63306d896535c434b888ce9162431839809ae140b811a5c40e9323027143096b5e8ea0a1270fd79f6f23a4b544ed33d2090f43f4d7c0d2d449c6
data/Gemfile.lock CHANGED
@@ -1,7 +1,7 @@
1
1
  PATH
2
2
  remote: .
3
3
  specs:
4
- replicate-client (0.1.4)
4
+ replicate-client (0.1.7)
5
5
  faraday (>= 1)
6
6
 
7
7
  GEM
@@ -15,13 +15,15 @@ module ReplicateClient
15
15
  #
16
16
  # @param path [String] The path to the API endpoint.
17
17
  # @param payload [Hash] The payload to send to the API.
18
+ # @param headers [Hash] The headers to send to the API.
18
19
  #
19
20
  # @return [Hash] The response from the API.
20
- def post(path, payload)
21
+ def post(path, payload, headers: {})
21
22
  response = connection.post(build_url(path)) do |request|
22
23
  request.headers["Authorization"] = "Bearer #{@configuration.access_token}"
23
24
  request.headers["Content-Type"] = "application/json"
24
25
  request.headers["Accept"] = "application/json"
26
+ request.headers.merge!(headers)
25
27
  request.body = payload.compact.to_json
26
28
  end
27
29
 
@@ -127,6 +127,37 @@ module ReplicateClient
127
127
  webhook_events_filter: webhook_events_filter
128
128
  )
129
129
  end
130
+
131
+ # Get the prediction input schema from the openapi schema.
132
+ #
133
+ # @return [Hash] The prediction input schema.
134
+ def prediction_input_schema
135
+ resolve_ref(openapi_schema.dig("components", "schemas", "PredictionRequest", "properties", "input"))
136
+ end
137
+
138
+ # Get the training input schema from the openapi schema.
139
+ #
140
+ # @return [Hash] The training input schema.
141
+ def training_input_schema
142
+ resolve_ref(openapi_schema.dig("components", "schemas", "TrainingRequest", "properties", "input"))
143
+ end
144
+
145
+ private
146
+
147
+ # Resolve a reference in the openapi schema.
148
+ #
149
+ # @param schema [Hash] The schema to resolve.
150
+ #
151
+ # @return [Hash] The resolved schema.
152
+ def resolve_ref(schema)
153
+ if schema.is_a?(Hash) && schema["$ref"]
154
+ ref_path = schema["$ref"].split("/").drop(1)
155
+ resolved_schema = openapi_schema.dig(*ref_path)
156
+ resolve_ref(resolved_schema)
157
+ else
158
+ schema
159
+ end
160
+ end
130
161
  end
131
162
  end
132
163
  end
@@ -19,9 +19,10 @@ module ReplicateClient
19
19
  # @param input [Hash] The input data for the prediction.
20
20
  # @param webhook_url [String] The URL to send webhook events to.
21
21
  # @param webhook_events_filter [Array<Symbol>] The events to send to the webhook.
22
+ # @param sync [Boolean] Whether to wait for the prediction to complete.
22
23
  #
23
24
  # @return [ReplicateClient::Prediction]
24
- def create!(version:, input:, webhook_url: nil, webhook_events_filter: nil)
25
+ def create!(version:, input:, webhook_url: nil, webhook_events_filter: nil, sync: false)
25
26
  args = {
26
27
  version: version.is_a?(Model::Version) ? version.id : version,
27
28
  input: input,
@@ -29,7 +30,9 @@ module ReplicateClient
29
30
  webhook_events_filter: webhook_events_filter&.map(&:to_s)
30
31
  }
31
32
 
32
- prediction = ReplicateClient.client.post(INDEX_PATH, args)
33
+ headers = sync ? { "Prefer" => 'wait' } : {}
34
+
35
+ prediction = ReplicateClient.client.post(INDEX_PATH, args, headers:)
33
36
 
34
37
  new(prediction)
35
38
  end
@@ -40,16 +43,19 @@ module ReplicateClient
40
43
  # @param input [Hash] The input data for the prediction.
41
44
  # @param webhook_url [String] The URL to send webhook events to.
42
45
  # @param webhook_events_filter [Array<Symbol>] The events to send to the webhook.
46
+ # @param sync [Boolean] Whether to wait for the prediction to complete.
43
47
  #
44
48
  # @return [ReplicateClient::Prediction]
45
- def create_for_deployment!(deployment:, input:, webhook_url: nil, webhook_events_filter: nil)
49
+ def create_for_deployment!(deployment:, input:, webhook_url: nil, webhook_events_filter: nil, sync: false)
46
50
  args = {
47
51
  input: input,
48
52
  webhook: webhook_url || ReplicateClient.configuration.webhook_url,
49
53
  webhook_events_filter: webhook_events_filter&.map(&:to_s)
50
54
  }
51
55
 
52
- prediction = ReplicateClient.client.post("#{deployment.path}#{INDEX_PATH}", args)
56
+ headers = sync ? { "Prefer" => 'wait' } : {}
57
+
58
+ prediction = ReplicateClient.client.post("#{deployment.path}#{INDEX_PATH}", args, headers:)
53
59
 
54
60
  new(prediction)
55
61
  end
@@ -60,9 +66,10 @@ module ReplicateClient
60
66
  # @param input [Hash] The input data for the prediction.
61
67
  # @param webhook_url [String] The URL to send webhook events to.
62
68
  # @param webhook_events_filter [Array<Symbol>] The events to send to the webhook.
69
+ # @param sync [Boolean] Whether to wait for the prediction to complete.
63
70
  #
64
71
  # @return [ReplicateClient::Prediction]
65
- def create_for_official_model!(model:, input:, webhook_url: nil, webhook_events_filter: nil)
72
+ def create_for_official_model!(model:, input:, webhook_url: nil, webhook_events_filter: nil, sync: false)
66
73
  model_path = model.is_a?(Model) ? model.path : Model.build_path(**Model.parse_model_name(model))
67
74
 
68
75
  args = {
@@ -71,7 +78,9 @@ module ReplicateClient
71
78
  webhook_events_filter: webhook_events_filter&.map(&:to_s)
72
79
  }
73
80
 
74
- prediction = ReplicateClient.client.post("#{model_path}#{INDEX_PATH}", args)
81
+ headers = sync ? { "Prefer" => 'wait' } : {}
82
+
83
+ prediction = ReplicateClient.client.post("#{model_path}#{INDEX_PATH}", args, headers:)
75
84
 
76
85
  new(prediction)
77
86
  end
@@ -1,5 +1,5 @@
1
1
  # frozen_string_literal: true
2
2
 
3
3
  module ReplicateClient
4
- VERSION = "0.1.6"
4
+ VERSION = "0.1.8"
5
5
  end
metadata CHANGED
@@ -1,14 +1,14 @@
1
1
  --- !ruby/object:Gem::Specification
2
2
  name: replicate-client
3
3
  version: !ruby/object:Gem::Version
4
- version: 0.1.6
4
+ version: 0.1.8
5
5
  platform: ruby
6
6
  authors:
7
7
  - Dylan Player
8
8
  autorequire:
9
9
  bindir: exe
10
10
  cert_chain: []
11
- date: 2024-09-01 00:00:00.000000000 Z
11
+ date: 2025-01-19 00:00:00.000000000 Z
12
12
  dependencies:
13
13
  - !ruby/object:Gem::Dependency
14
14
  name: faraday