openai.rb 0.0.0 → 0.0.3
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- checksums.yaml +4 -4
- data/.github/workflows/main.yml +27 -0
- data/.rubocop.yml +18 -0
- data/.ruby-version +1 -1
- data/Gemfile +9 -5
- data/Gemfile.lock +29 -24
- data/README.md +401 -0
- data/bin/console +9 -4
- data/lib/openai/api/cache.rb +137 -0
- data/lib/openai/api/client.rb +86 -0
- data/lib/openai/api/resource.rb +232 -0
- data/lib/openai/api/response.rb +384 -0
- data/lib/openai/api.rb +75 -0
- data/lib/openai/chat.rb +125 -0
- data/lib/openai/tokenizer.rb +50 -0
- data/lib/openai/util.rb +47 -0
- data/lib/openai/version.rb +1 -1
- data/lib/openai.rb +38 -357
- data/openai.gemspec +9 -3
- data/spec/data/sample_french.mp3 +0 -0
- data/spec/data/sample_image.png +0 -0
- data/spec/data/sample_image_mask.png +0 -0
- data/spec/shared/api_resource_context.rb +22 -0
- data/spec/spec_helper.rb +4 -0
- data/spec/unit/openai/api/audio_spec.rb +78 -0
- data/spec/unit/openai/api/cache_spec.rb +115 -0
- data/spec/unit/openai/api/chat_completions_spec.rb +130 -0
- data/spec/unit/openai/api/completions_spec.rb +125 -0
- data/spec/unit/openai/api/edits_spec.rb +40 -0
- data/spec/unit/openai/api/embeddings_spec.rb +45 -0
- data/spec/unit/openai/api/files_spec.rb +163 -0
- data/spec/unit/openai/api/fine_tunes_spec.rb +322 -0
- data/spec/unit/openai/api/images_spec.rb +137 -0
- data/spec/unit/openai/api/models_spec.rb +98 -0
- data/spec/unit/openai/api/moderations_spec.rb +63 -0
- data/spec/unit/openai/api/response_spec.rb +203 -0
- data/spec/unit/openai/chat_spec.rb +32 -0
- data/spec/unit/openai/tokenizer_spec.rb +45 -0
- data/spec/unit/openai_spec.rb +47 -736
- metadata +97 -7
- data/bin/codegen +0 -371
data/lib/openai.rb
CHANGED
@@ -1,381 +1,62 @@
|
|
1
1
|
# frozen_string_literal: true
|
2
2
|
|
3
|
+
require 'pathname'
|
4
|
+
require 'logger'
|
5
|
+
|
3
6
|
require 'concord'
|
4
7
|
require 'anima'
|
8
|
+
require 'abstract_type'
|
5
9
|
require 'http'
|
6
10
|
require 'addressable'
|
7
|
-
|
11
|
+
require 'ice_nine'
|
12
|
+
require 'tiktoken_ruby'
|
13
|
+
|
14
|
+
require 'openai/util'
|
15
|
+
require 'openai/tokenizer'
|
16
|
+
require 'openai/chat'
|
17
|
+
require 'openai/api'
|
18
|
+
require 'openai/api/cache'
|
19
|
+
require 'openai/api/client'
|
20
|
+
require 'openai/api/resource'
|
21
|
+
require 'openai/api/response'
|
8
22
|
require 'openai/version'
|
9
23
|
|
10
24
|
class OpenAI
|
11
|
-
include Concord.new(:
|
12
|
-
|
13
|
-
ResponseError = Class.new(StandardError)
|
14
|
-
|
15
|
-
HOST = Addressable::URI.parse('https://api.openai.com/v1')
|
25
|
+
include Concord.new(:api_client, :logger)
|
16
26
|
|
17
|
-
|
18
|
-
super(api_key, http)
|
19
|
-
end
|
20
|
-
|
21
|
-
def create_completion(model:, **kwargs)
|
22
|
-
Response::Completion.from_json(
|
23
|
-
post('/v1/completions', model: model, **kwargs)
|
24
|
-
)
|
25
|
-
end
|
27
|
+
public :logger
|
26
28
|
|
27
|
-
|
28
|
-
Response::ChatCompletion.from_json(
|
29
|
-
post('/v1/chat/completions', model: model, messages: messages, **kwargs)
|
30
|
-
)
|
31
|
-
end
|
32
|
-
|
33
|
-
def create_embedding(model:, input:, **kwargs)
|
34
|
-
Response::Embedding.from_json(
|
35
|
-
post('/v1/embeddings', model: model, input: input, **kwargs)
|
36
|
-
)
|
37
|
-
end
|
38
|
-
|
39
|
-
def list_models
|
40
|
-
Response::ListModel.from_json(get('/v1/models'))
|
41
|
-
end
|
29
|
+
ROOT = Pathname.new(__dir__).parent.expand_path.freeze
|
42
30
|
|
43
|
-
def
|
44
|
-
|
45
|
-
get("/v1/models/#{model_id}")
|
46
|
-
)
|
47
|
-
end
|
48
|
-
|
49
|
-
def create_edit(model:, instruction:, **kwargs)
|
50
|
-
Response::Edit.from_json(
|
51
|
-
post('/v1/edits', model: model, instruction: instruction, **kwargs)
|
52
|
-
)
|
53
|
-
end
|
54
|
-
|
55
|
-
def create_image_generation(prompt:, **kwargs)
|
56
|
-
Response::ImageGeneration.from_json(
|
57
|
-
post('/v1/images/generations', prompt: prompt, **kwargs)
|
58
|
-
)
|
59
|
-
end
|
60
|
-
|
61
|
-
def create_file(file:, purpose:)
|
62
|
-
absolute_path = Pathname.new(file).expand_path.to_s
|
63
|
-
form_file = HTTP::FormData::File.new(absolute_path)
|
64
|
-
Response::File.from_json(
|
65
|
-
post_form_multipart('/v1/files', file: form_file, purpose: purpose)
|
66
|
-
)
|
67
|
-
end
|
68
|
-
|
69
|
-
def list_files
|
70
|
-
Response::FileList.from_json(
|
71
|
-
get('/v1/files')
|
72
|
-
)
|
73
|
-
end
|
31
|
+
def self.create(api_key, cache: nil, logger: Logger.new('/dev/null'))
|
32
|
+
client = API::Client.new(api_key)
|
74
33
|
|
75
|
-
|
76
|
-
|
77
|
-
|
78
|
-
|
79
|
-
end
|
80
|
-
|
81
|
-
def get_file(file_id)
|
82
|
-
Response::File.from_json(
|
83
|
-
get("/v1/files/#{file_id}")
|
84
|
-
)
|
85
|
-
end
|
86
|
-
|
87
|
-
def get_file_content(file_id)
|
88
|
-
get("/v1/files/#{file_id}/content")
|
89
|
-
end
|
90
|
-
|
91
|
-
def list_fine_tunes
|
92
|
-
Response::FineTuneList.from_json(
|
93
|
-
get('/v1/fine-tunes')
|
94
|
-
)
|
95
|
-
end
|
96
|
-
|
97
|
-
def create_fine_tune(training_file:, **kwargs)
|
98
|
-
Response::FineTune.from_json(
|
99
|
-
post('/v1/fine-tunes', training_file: training_file, **kwargs)
|
100
|
-
)
|
101
|
-
end
|
102
|
-
|
103
|
-
def get_fine_tune(fine_tune_id)
|
104
|
-
Response::FineTune.from_json(
|
105
|
-
get("/v1/fine-tunes/#{fine_tune_id}")
|
106
|
-
)
|
107
|
-
end
|
108
|
-
|
109
|
-
def cancel_fine_tune(fine_tune_id)
|
110
|
-
Response::FineTune.from_json(
|
111
|
-
post("/v1/fine-tunes/#{fine_tune_id}/cancel")
|
112
|
-
)
|
113
|
-
end
|
114
|
-
|
115
|
-
def transcribe_audio(file:, model:, **kwargs)
|
116
|
-
absolute_path = Pathname.new(file).expand_path.to_s
|
117
|
-
form_file = HTTP::FormData::File.new(absolute_path)
|
118
|
-
Response::Transcription.from_json(
|
119
|
-
post_form_multipart(
|
120
|
-
'/v1/audio/transcriptions',
|
121
|
-
file: form_file,
|
122
|
-
model: model,
|
123
|
-
**kwargs
|
34
|
+
if cache.is_a?(Pathname) && cache.directory?
|
35
|
+
client = API::Cache.new(
|
36
|
+
client,
|
37
|
+
API::Cache::Strategy::FileSystem.new(cache)
|
124
38
|
)
|
125
|
-
)
|
126
|
-
end
|
127
|
-
|
128
|
-
def inspect
|
129
|
-
"#<#{self.class}>"
|
130
|
-
end
|
131
|
-
|
132
|
-
private
|
133
|
-
|
134
|
-
def get(route)
|
135
|
-
unwrap_response(json_http_client.get(url_for(route)))
|
136
|
-
end
|
137
|
-
|
138
|
-
def delete(route)
|
139
|
-
unwrap_response(json_http_client.delete(url_for(route)))
|
140
|
-
end
|
141
|
-
|
142
|
-
def post(route, **body)
|
143
|
-
unwrap_response(json_http_client.post(url_for(route), json: body))
|
144
|
-
end
|
145
|
-
|
146
|
-
def post_form_multipart(route, **body)
|
147
|
-
unwrap_response(http_client.post(url_for(route), form: body))
|
148
|
-
end
|
149
|
-
|
150
|
-
def url_for(route)
|
151
|
-
HOST.join(route).to_str
|
152
|
-
end
|
153
|
-
|
154
|
-
def unwrap_response(response)
|
155
|
-
unless response.status.success?
|
156
|
-
raise ResponseError, "Unexpected response #{response.status}\nBody:\n#{response.body}"
|
157
39
|
end
|
158
40
|
|
159
|
-
|
41
|
+
new(client, logger)
|
160
42
|
end
|
161
43
|
|
162
|
-
|
163
|
-
http_client.headers('Content-Type' => 'application/json')
|
164
|
-
end
|
44
|
+
private_class_method :new
|
165
45
|
|
166
|
-
def
|
167
|
-
|
46
|
+
def api
|
47
|
+
API.new(api_client)
|
168
48
|
end
|
169
49
|
|
170
|
-
|
171
|
-
|
172
|
-
|
173
|
-
|
174
|
-
def self.from_json(raw_json)
|
175
|
-
new(JSON.parse(raw_json, symbolize_names: true))
|
176
|
-
end
|
177
|
-
|
178
|
-
def self.field(name, path: [name], wrapper: nil)
|
179
|
-
given_wrapper = wrapper
|
180
|
-
define_method(name) do
|
181
|
-
field(path, wrapper: given_wrapper)
|
182
|
-
end
|
183
|
-
end
|
184
|
-
|
185
|
-
def self.optional_field(name, path: name)
|
186
|
-
define_method(name) do
|
187
|
-
optional_field(path)
|
188
|
-
end
|
189
|
-
end
|
190
|
-
|
191
|
-
def original_payload
|
192
|
-
internal_data
|
193
|
-
end
|
194
|
-
|
195
|
-
private
|
196
|
-
|
197
|
-
def optional_field(key_path)
|
198
|
-
*head, tail = key_path
|
199
|
-
|
200
|
-
field(head)[tail]
|
201
|
-
end
|
202
|
-
|
203
|
-
def field(key_path, wrapper: nil)
|
204
|
-
value = key_path.reduce(internal_data, :fetch)
|
205
|
-
return value unless wrapper
|
206
|
-
|
207
|
-
if value.is_a?(Array)
|
208
|
-
value.map { |item| wrapper.new(item) }
|
209
|
-
else
|
210
|
-
wrapper.new(value)
|
211
|
-
end
|
212
|
-
end
|
213
|
-
end
|
214
|
-
|
215
|
-
class Completion < JSONPayload
|
216
|
-
class Choice < JSONPayload
|
217
|
-
field :text
|
218
|
-
field :index
|
219
|
-
field :logprobs
|
220
|
-
field :finish_reason
|
221
|
-
end
|
222
|
-
|
223
|
-
class Usage < JSONPayload
|
224
|
-
field :prompt_tokens
|
225
|
-
field :completion_tokens
|
226
|
-
field :total_tokens
|
227
|
-
end
|
228
|
-
|
229
|
-
field :id
|
230
|
-
field :object
|
231
|
-
field :created
|
232
|
-
field :model
|
233
|
-
field :choices, wrapper: Choice
|
234
|
-
field :usage, wrapper: Usage
|
235
|
-
end
|
236
|
-
|
237
|
-
class ChatCompletion < JSONPayload
|
238
|
-
class Choice < JSONPayload
|
239
|
-
class Message < JSONPayload
|
240
|
-
field :role
|
241
|
-
field :content
|
242
|
-
end
|
243
|
-
|
244
|
-
field :index
|
245
|
-
field :message, wrapper: Message
|
246
|
-
field :finish_reason
|
247
|
-
end
|
248
|
-
|
249
|
-
class Usage < JSONPayload
|
250
|
-
field :prompt_tokens
|
251
|
-
field :completion_tokens
|
252
|
-
field :total_tokens
|
253
|
-
end
|
254
|
-
|
255
|
-
field :id
|
256
|
-
field :object
|
257
|
-
field :created
|
258
|
-
field :choices, wrapper: Choice
|
259
|
-
field :usage, wrapper: Usage
|
260
|
-
end
|
261
|
-
|
262
|
-
class Embedding < JSONPayload
|
263
|
-
class EmbeddingData < JSONPayload
|
264
|
-
field :object
|
265
|
-
field :embedding
|
266
|
-
field :index
|
267
|
-
end
|
268
|
-
|
269
|
-
class Usage < JSONPayload
|
270
|
-
field :prompt_tokens
|
271
|
-
field :total_tokens
|
272
|
-
end
|
273
|
-
|
274
|
-
field :object
|
275
|
-
field :data, wrapper: EmbeddingData
|
276
|
-
field :model
|
277
|
-
field :usage, wrapper: Usage
|
278
|
-
end
|
279
|
-
|
280
|
-
class Model < JSONPayload
|
281
|
-
field :id
|
282
|
-
field :object
|
283
|
-
field :owned_by
|
284
|
-
field :permission
|
285
|
-
end
|
286
|
-
|
287
|
-
class ListModel < JSONPayload
|
288
|
-
field :data, wrapper: Model
|
289
|
-
end
|
290
|
-
|
291
|
-
class Edit < JSONPayload
|
292
|
-
class Choice < JSONPayload
|
293
|
-
field :text
|
294
|
-
field :index
|
295
|
-
end
|
296
|
-
|
297
|
-
class Usage < JSONPayload
|
298
|
-
field :prompt_tokens
|
299
|
-
field :completion_tokens
|
300
|
-
field :total_tokens
|
301
|
-
end
|
302
|
-
|
303
|
-
field :object
|
304
|
-
field :created
|
305
|
-
field :choices, wrapper: Choice
|
306
|
-
field :usage, wrapper: Usage
|
307
|
-
end
|
308
|
-
|
309
|
-
class ImageGeneration < JSONPayload
|
310
|
-
class Image < JSONPayload
|
311
|
-
field :url
|
312
|
-
end
|
313
|
-
|
314
|
-
field :created
|
315
|
-
field :data, wrapper: Image
|
316
|
-
end
|
317
|
-
|
318
|
-
class File < JSONPayload
|
319
|
-
field :id
|
320
|
-
field :object
|
321
|
-
field :bytes
|
322
|
-
field :created_at
|
323
|
-
field :filename
|
324
|
-
field :purpose
|
325
|
-
optional_field :deleted?, path: :deleted
|
326
|
-
end
|
327
|
-
|
328
|
-
class FileList < JSONPayload
|
329
|
-
field :data, wrapper: File
|
330
|
-
field :object
|
331
|
-
end
|
332
|
-
|
333
|
-
class FineTune < JSONPayload
|
334
|
-
class Event < JSONPayload
|
335
|
-
field :object
|
336
|
-
field :created_at
|
337
|
-
field :level
|
338
|
-
field :message
|
339
|
-
end
|
340
|
-
|
341
|
-
class Hyperparams < JSONPayload
|
342
|
-
field :batch_size
|
343
|
-
field :learning_rate_multiplier
|
344
|
-
field :n_epochs
|
345
|
-
field :prompt_loss_weight
|
346
|
-
end
|
347
|
-
|
348
|
-
class File < JSONPayload
|
349
|
-
field :id
|
350
|
-
field :object
|
351
|
-
field :bytes
|
352
|
-
field :created_at
|
353
|
-
field :filename
|
354
|
-
field :purpose
|
355
|
-
end
|
356
|
-
|
357
|
-
field :id
|
358
|
-
field :object
|
359
|
-
field :model
|
360
|
-
field :created_at
|
361
|
-
field :events, wrapper: Event
|
362
|
-
field :fine_tuned_model
|
363
|
-
field :hyperparams, wrapper: Hyperparams
|
364
|
-
field :organization_id
|
365
|
-
field :result_files, wrapper: File
|
366
|
-
field :status
|
367
|
-
field :validation_files, wrapper: File
|
368
|
-
field :training_files, wrapper: File
|
369
|
-
field :updated_at
|
370
|
-
end
|
371
|
-
|
372
|
-
class FineTuneList < JSONPayload
|
373
|
-
field :object
|
374
|
-
field :data, wrapper: FineTune
|
375
|
-
end
|
50
|
+
def tokenizer
|
51
|
+
Tokenizer.new
|
52
|
+
end
|
53
|
+
alias tokens tokenizer
|
376
54
|
|
377
|
-
|
378
|
-
|
379
|
-
|
55
|
+
def chat(model:, history: [], **kwargs)
|
56
|
+
Chat.new(
|
57
|
+
openai: self,
|
58
|
+
settings: kwargs.merge(model: model),
|
59
|
+
messages: history
|
60
|
+
)
|
380
61
|
end
|
381
62
|
end
|
data/openai.gemspec
CHANGED
@@ -16,7 +16,13 @@ Gem::Specification.new do |spec|
|
|
16
16
|
spec.require_paths = %w[lib]
|
17
17
|
spec.executables = []
|
18
18
|
|
19
|
-
spec.
|
20
|
-
|
21
|
-
spec.add_dependency '
|
19
|
+
spec.required_ruby_version = '>= 2.7'
|
20
|
+
|
21
|
+
spec.add_dependency 'abstract_type', '~> 0.0.7'
|
22
|
+
spec.add_dependency 'anima', '~> 0.3'
|
23
|
+
spec.add_dependency 'concord', '~> 0.1'
|
24
|
+
spec.add_dependency 'http', '>= 4.4', '< 6.0'
|
25
|
+
spec.add_dependency 'ice_nine', '~> 0.11.x'
|
26
|
+
spec.add_dependency 'memoizable', '~> 0.4.2'
|
27
|
+
spec.add_dependency 'tiktoken_ruby', '~> 0.0.3'
|
22
28
|
end
|
Binary file
|
Binary file
|
Binary file
|
@@ -0,0 +1,22 @@
|
|
1
|
+
# frozen_string_literal: true
|
2
|
+
|
3
|
+
RSpec.shared_context 'an API Resource' do
|
4
|
+
let(:api) { OpenAI::API.new(api_client) }
|
5
|
+
let(:api_client) { OpenAI::API::Client.new('sk-123', http: http) }
|
6
|
+
let(:http) { class_spy(HTTP) }
|
7
|
+
let(:response_status_code) { 200 }
|
8
|
+
|
9
|
+
let(:response) do
|
10
|
+
instance_double(
|
11
|
+
HTTP::Response,
|
12
|
+
status: HTTP::Response::Status.new(response_status_code),
|
13
|
+
body: JSON.dump(response_body)
|
14
|
+
)
|
15
|
+
end
|
16
|
+
|
17
|
+
before do
|
18
|
+
allow(http).to receive(:post).and_return(response)
|
19
|
+
allow(http).to receive(:get).and_return(response)
|
20
|
+
allow(http).to receive(:delete).and_return(response)
|
21
|
+
end
|
22
|
+
end
|
data/spec/spec_helper.rb
CHANGED
@@ -9,6 +9,10 @@ module OpenAISpec
|
|
9
9
|
SPEC_ROOT = ROOT.join('spec')
|
10
10
|
end
|
11
11
|
|
12
|
+
OpenAISpec::SPEC_ROOT.glob('shared/*.rb').shuffle.each do |shared_spec|
|
13
|
+
require(shared_spec)
|
14
|
+
end
|
15
|
+
|
12
16
|
RSpec.configure do |config|
|
13
17
|
# Enable focused tests and run all tests if nothing is focused
|
14
18
|
config.filter_run_when_matching(:focus)
|
@@ -0,0 +1,78 @@
|
|
1
|
+
# frozen_string_literal: true
|
2
|
+
|
3
|
+
RSpec.describe OpenAI::API, '#audio' do
|
4
|
+
include_context 'an API Resource'
|
5
|
+
|
6
|
+
let(:resource) { api.audio }
|
7
|
+
let(:sample_audio) { OpenAISpec::SPEC_ROOT.join('data/sample.mp3') }
|
8
|
+
|
9
|
+
context 'when transcribing audio' do
|
10
|
+
let(:response_body) do
|
11
|
+
{
|
12
|
+
"text": "Imagine the wildest idea that you've ever had, and you're curious about how it might scale to something that's a 100, a 1,000 times bigger. This is a place where you can get to do that."
|
13
|
+
}
|
14
|
+
end
|
15
|
+
|
16
|
+
it 'can transcribe audio' do
|
17
|
+
transcription = resource.transcribe(
|
18
|
+
file: sample_audio,
|
19
|
+
model: 'model-1234'
|
20
|
+
)
|
21
|
+
|
22
|
+
expect(http)
|
23
|
+
.to have_received(:post)
|
24
|
+
.with(
|
25
|
+
'https://api.openai.com/v1/audio/transcriptions',
|
26
|
+
hash_including(
|
27
|
+
form: hash_including(
|
28
|
+
{
|
29
|
+
file: instance_of(HTTP::FormData::File),
|
30
|
+
model: 'model-1234'
|
31
|
+
}
|
32
|
+
)
|
33
|
+
)
|
34
|
+
)
|
35
|
+
|
36
|
+
expect(transcription.text).to eql("Imagine the wildest idea that you've ever had, and you're curious about how it might scale to something that's a 100, a 1,000 times bigger. This is a place where you can get to do that.")
|
37
|
+
end
|
38
|
+
end
|
39
|
+
|
40
|
+
context 'when translating audio' do
|
41
|
+
let(:sample_audio) { OpenAISpec::SPEC_ROOT.join('data/sample_french.mp3') }
|
42
|
+
|
43
|
+
let(:response_body) do
|
44
|
+
{
|
45
|
+
"text": 'Hello, my name is Wolfgang and I come from Germany. Where are you heading today?'
|
46
|
+
}
|
47
|
+
end
|
48
|
+
|
49
|
+
it 'can translate audio' do
|
50
|
+
translation = resource.translate(
|
51
|
+
file: sample_audio,
|
52
|
+
model: 'model-id',
|
53
|
+
prompt: 'Hello, my name is Wolfgang and I come from Germany. Where are you heading today?',
|
54
|
+
response_format: 'text',
|
55
|
+
temperature: 0.5
|
56
|
+
)
|
57
|
+
|
58
|
+
expect(http)
|
59
|
+
.to have_received(:post)
|
60
|
+
.with(
|
61
|
+
'https://api.openai.com/v1/audio/translations',
|
62
|
+
hash_including(
|
63
|
+
form: hash_including(
|
64
|
+
{
|
65
|
+
file: instance_of(HTTP::FormData::File),
|
66
|
+
model: 'model-id',
|
67
|
+
prompt: 'Hello, my name is Wolfgang and I come from Germany. Where are you heading today?',
|
68
|
+
response_format: 'text',
|
69
|
+
temperature: 0.5
|
70
|
+
}
|
71
|
+
)
|
72
|
+
)
|
73
|
+
)
|
74
|
+
|
75
|
+
expect(translation.text).to eql('Hello, my name is Wolfgang and I come from Germany. Where are you heading today?')
|
76
|
+
end
|
77
|
+
end
|
78
|
+
end
|
@@ -0,0 +1,115 @@
|
|
1
|
+
# frozen_string_literal: true
|
2
|
+
|
3
|
+
RSpec.describe OpenAI::API::Cache do
|
4
|
+
let(:cached_client) do
|
5
|
+
described_class.new(client, cache_strategy)
|
6
|
+
end
|
7
|
+
|
8
|
+
let(:client) do
|
9
|
+
instance_double(OpenAI::API::Client, api_key: 'sk-123').tap do |double|
|
10
|
+
%i[get post post_form_multipart delete].each do |method|
|
11
|
+
allow(double).to receive(method).and_return(api_resource)
|
12
|
+
end
|
13
|
+
end
|
14
|
+
end
|
15
|
+
|
16
|
+
let(:api_resource) do
|
17
|
+
JSON.dump(text: 'Wow neat')
|
18
|
+
end
|
19
|
+
|
20
|
+
let(:cache_strategy) do
|
21
|
+
described_class::Strategy::Memory.new
|
22
|
+
end
|
23
|
+
|
24
|
+
it 'wraps the public API of API::Client' do
|
25
|
+
client_public_api =
|
26
|
+
OpenAI::API::Client.public_instance_methods(false) - %i[api_key inspect]
|
27
|
+
|
28
|
+
client_public_api.each do |client_method|
|
29
|
+
expect(cached_client).to respond_to(client_method)
|
30
|
+
end
|
31
|
+
end
|
32
|
+
|
33
|
+
it 'can cache get requests' do
|
34
|
+
cached_client.get('/v1/foo')
|
35
|
+
cached_client.get('/v1/foo')
|
36
|
+
cached_client.get('/v1/bar')
|
37
|
+
|
38
|
+
expect(client).to have_received(:get).with('/v1/foo').once
|
39
|
+
expect(client).to have_received(:get).with('/v1/bar').once
|
40
|
+
end
|
41
|
+
|
42
|
+
it 'can cache JSON post requests' do
|
43
|
+
cached_client.post('/v1/foo', model: 'model1', prompt: 'prompt1') # miss
|
44
|
+
cached_client.post('/v1/foo', model: 'model1', prompt: 'prompt1') # hit
|
45
|
+
cached_client.post('/v1/foo', model: 'model1', prompt: 'prompt2') # miss
|
46
|
+
cached_client.post('/v1/bar', model: 'model1', prompt: 'prompt2') # miss
|
47
|
+
cached_client.post_form_multipart('/v1/foo', model: 'model1', prompt: 'prompt1') # miss
|
48
|
+
|
49
|
+
expect(client).to have_received(:post).thrice
|
50
|
+
expect(client).to have_received(:post_form_multipart).once
|
51
|
+
end
|
52
|
+
|
53
|
+
it 'does not cache delete requests' do
|
54
|
+
cached_client.delete('/v1/foo')
|
55
|
+
cached_client.delete('/v1/foo')
|
56
|
+
|
57
|
+
expect(client).to have_received(:delete).twice
|
58
|
+
end
|
59
|
+
|
60
|
+
it 'can cache multipart form post requests' do
|
61
|
+
cached_client.post_form_multipart('/v1/foo', model: 'model1', prompt: 'prompt1') # miss
|
62
|
+
cached_client.post_form_multipart('/v1/foo', model: 'model1', prompt: 'prompt1') # hit
|
63
|
+
cached_client.post_form_multipart('/v1/foo', model: 'model1', prompt: 'prompt2') # miss
|
64
|
+
cached_client.post_form_multipart('/v1/bar', model: 'model1', prompt: 'prompt2') # miss
|
65
|
+
cached_client.post('/v1/foo', model: 'model1', prompt: 'prompt1') # miss
|
66
|
+
|
67
|
+
expect(client).to have_received(:post_form_multipart).thrice
|
68
|
+
end
|
69
|
+
|
70
|
+
it 'writes unique and somewhat human readable cache keys' do
|
71
|
+
expect(cache_strategy.cached?('get_foo_9bfe1439')).to be(false)
|
72
|
+
cached_client.get('/v1/foo')
|
73
|
+
expect(cache_strategy.cached?('get_foo_9bfe1439')).to be(true)
|
74
|
+
end
|
75
|
+
|
76
|
+
it 'returns identical values for cache hits and misses' do
|
77
|
+
miss = cached_client.get('/v1/foo')
|
78
|
+
hit = cached_client.get('/v1/foo')
|
79
|
+
|
80
|
+
expect(miss).to eq(hit)
|
81
|
+
end
|
82
|
+
|
83
|
+
context 'when the API key changes' do
|
84
|
+
before do
|
85
|
+
allow(client).to receive(:api_key).and_return('sk-123', 'sk-123', 'sk-456')
|
86
|
+
end
|
87
|
+
|
88
|
+
it 'factors the API key into the cache calculation' do
|
89
|
+
cached_client.get('/v1/foo')
|
90
|
+
cached_client.get('/v1/foo')
|
91
|
+
cached_client.get('/v1/foo')
|
92
|
+
|
93
|
+
expect(client).to have_received(:get).with('/v1/foo').twice
|
94
|
+
end
|
95
|
+
end
|
96
|
+
|
97
|
+
context 'when using the filesystem cache strategy' do
|
98
|
+
let(:cache_strategy) do
|
99
|
+
described_class::Strategy::FileSystem.new(cache_dir)
|
100
|
+
end
|
101
|
+
|
102
|
+
let(:cache_dir) do
|
103
|
+
Pathname.new(Dir.mktmpdir)
|
104
|
+
end
|
105
|
+
|
106
|
+
it 'writes JSON files' do
|
107
|
+
cache_path = cache_dir.join('get_foo_9bfe1439.json')
|
108
|
+
expect(cache_path.exist?).to be(false)
|
109
|
+
cached_client.get('/v1/foo')
|
110
|
+
expect(cache_path.exist?).to be(true)
|
111
|
+
|
112
|
+
expect(cache_strategy.read('get_foo_9bfe1439')).to eq(api_resource)
|
113
|
+
end
|
114
|
+
end
|
115
|
+
end
|