openai.rb 0.0.0 → 0.0.3

Files changed (41)
  1. checksums.yaml +4 -4
  2. data/.github/workflows/main.yml +27 -0
  3. data/.rubocop.yml +18 -0
  4. data/.ruby-version +1 -1
  5. data/Gemfile +9 -5
  6. data/Gemfile.lock +29 -24
  7. data/README.md +401 -0
  8. data/bin/console +9 -4
  9. data/lib/openai/api/cache.rb +137 -0
  10. data/lib/openai/api/client.rb +86 -0
  11. data/lib/openai/api/resource.rb +232 -0
  12. data/lib/openai/api/response.rb +384 -0
  13. data/lib/openai/api.rb +75 -0
  14. data/lib/openai/chat.rb +125 -0
  15. data/lib/openai/tokenizer.rb +50 -0
  16. data/lib/openai/util.rb +47 -0
  17. data/lib/openai/version.rb +1 -1
  18. data/lib/openai.rb +38 -357
  19. data/openai.gemspec +9 -3
  20. data/spec/data/sample_french.mp3 +0 -0
  21. data/spec/data/sample_image.png +0 -0
  22. data/spec/data/sample_image_mask.png +0 -0
  23. data/spec/shared/api_resource_context.rb +22 -0
  24. data/spec/spec_helper.rb +4 -0
  25. data/spec/unit/openai/api/audio_spec.rb +78 -0
  26. data/spec/unit/openai/api/cache_spec.rb +115 -0
  27. data/spec/unit/openai/api/chat_completions_spec.rb +130 -0
  28. data/spec/unit/openai/api/completions_spec.rb +125 -0
  29. data/spec/unit/openai/api/edits_spec.rb +40 -0
  30. data/spec/unit/openai/api/embeddings_spec.rb +45 -0
  31. data/spec/unit/openai/api/files_spec.rb +163 -0
  32. data/spec/unit/openai/api/fine_tunes_spec.rb +322 -0
  33. data/spec/unit/openai/api/images_spec.rb +137 -0
  34. data/spec/unit/openai/api/models_spec.rb +98 -0
  35. data/spec/unit/openai/api/moderations_spec.rb +63 -0
  36. data/spec/unit/openai/api/response_spec.rb +203 -0
  37. data/spec/unit/openai/chat_spec.rb +32 -0
  38. data/spec/unit/openai/tokenizer_spec.rb +45 -0
  39. data/spec/unit/openai_spec.rb +47 -736
  40. metadata +97 -7
  41. data/bin/codegen +0 -371
data/lib/openai.rb CHANGED
@@ -1,381 +1,62 @@
  # frozen_string_literal: true

+ require 'pathname'
+ require 'logger'
+
  require 'concord'
  require 'anima'
+ require 'abstract_type'
  require 'http'
  require 'addressable'
-
+ require 'ice_nine'
+ require 'tiktoken_ruby'
+
+ require 'openai/util'
+ require 'openai/tokenizer'
+ require 'openai/chat'
+ require 'openai/api'
+ require 'openai/api/cache'
+ require 'openai/api/client'
+ require 'openai/api/resource'
+ require 'openai/api/response'
  require 'openai/version'

  class OpenAI
- include Concord.new(:api_key, :http)
-
- ResponseError = Class.new(StandardError)
-
- HOST = Addressable::URI.parse('https://api.openai.com/v1')
+ include Concord.new(:api_client, :logger)

- def initialize(api_key, http: HTTP)
- super(api_key, http)
- end
-
- def create_completion(model:, **kwargs)
- Response::Completion.from_json(
- post('/v1/completions', model: model, **kwargs)
- )
- end
+ public :logger

- def create_chat_completion(model:, messages:, **kwargs)
- Response::ChatCompletion.from_json(
- post('/v1/chat/completions', model: model, messages: messages, **kwargs)
- )
- end
-
- def create_embedding(model:, input:, **kwargs)
- Response::Embedding.from_json(
- post('/v1/embeddings', model: model, input: input, **kwargs)
- )
- end
-
- def list_models
- Response::ListModel.from_json(get('/v1/models'))
- end
+ ROOT = Pathname.new(__dir__).parent.expand_path.freeze

- def get_model(model_id)
- Response::Model.from_json(
- get("/v1/models/#{model_id}")
- )
- end
-
- def create_edit(model:, instruction:, **kwargs)
- Response::Edit.from_json(
- post('/v1/edits', model: model, instruction: instruction, **kwargs)
- )
- end
-
- def create_image_generation(prompt:, **kwargs)
- Response::ImageGeneration.from_json(
- post('/v1/images/generations', prompt: prompt, **kwargs)
- )
- end
-
- def create_file(file:, purpose:)
- absolute_path = Pathname.new(file).expand_path.to_s
- form_file = HTTP::FormData::File.new(absolute_path)
- Response::File.from_json(
- post_form_multipart('/v1/files', file: form_file, purpose: purpose)
- )
- end
-
- def list_files
- Response::FileList.from_json(
- get('/v1/files')
- )
- end
+ def self.create(api_key, cache: nil, logger: Logger.new('/dev/null'))
+ client = API::Client.new(api_key)

- def delete_file(file_id)
- Response::File.from_json(
- delete("/v1/files/#{file_id}")
- )
- end
-
- def get_file(file_id)
- Response::File.from_json(
- get("/v1/files/#{file_id}")
- )
- end
-
- def get_file_content(file_id)
- get("/v1/files/#{file_id}/content")
- end
-
- def list_fine_tunes
- Response::FineTuneList.from_json(
- get('/v1/fine-tunes')
- )
- end
-
- def create_fine_tune(training_file:, **kwargs)
- Response::FineTune.from_json(
- post('/v1/fine-tunes', training_file: training_file, **kwargs)
- )
- end
-
- def get_fine_tune(fine_tune_id)
- Response::FineTune.from_json(
- get("/v1/fine-tunes/#{fine_tune_id}")
- )
- end
-
- def cancel_fine_tune(fine_tune_id)
- Response::FineTune.from_json(
- post("/v1/fine-tunes/#{fine_tune_id}/cancel")
- )
- end
-
- def transcribe_audio(file:, model:, **kwargs)
- absolute_path = Pathname.new(file).expand_path.to_s
- form_file = HTTP::FormData::File.new(absolute_path)
- Response::Transcription.from_json(
- post_form_multipart(
- '/v1/audio/transcriptions',
- file: form_file,
- model: model,
- **kwargs
+ if cache.is_a?(Pathname) && cache.directory?
+ client = API::Cache.new(
+ client,
+ API::Cache::Strategy::FileSystem.new(cache)
  )
- )
- end
-
- def inspect
- "#<#{self.class}>"
- end
-
- private
-
- def get(route)
- unwrap_response(json_http_client.get(url_for(route)))
- end
-
- def delete(route)
- unwrap_response(json_http_client.delete(url_for(route)))
- end
-
- def post(route, **body)
- unwrap_response(json_http_client.post(url_for(route), json: body))
- end
-
- def post_form_multipart(route, **body)
- unwrap_response(http_client.post(url_for(route), form: body))
- end
-
- def url_for(route)
- HOST.join(route).to_str
- end
-
- def unwrap_response(response)
- unless response.status.success?
- raise ResponseError, "Unexpected response #{response.status}\nBody:\n#{response.body}"
  end

- response.body.to_s
+ new(client, logger)
  end

- def json_http_client
- http_client.headers('Content-Type' => 'application/json')
- end
+ private_class_method :new

- def http_client
- http.headers('Authorization' => "Bearer #{api_key}")
+ def api
+ API.new(api_client)
  end

- class Response
- class JSONPayload
- include Concord.new(:internal_data)
-
- def self.from_json(raw_json)
- new(JSON.parse(raw_json, symbolize_names: true))
- end
-
- def self.field(name, path: [name], wrapper: nil)
- given_wrapper = wrapper
- define_method(name) do
- field(path, wrapper: given_wrapper)
- end
- end
-
- def self.optional_field(name, path: name)
- define_method(name) do
- optional_field(path)
- end
- end
-
- def original_payload
- internal_data
- end
-
- private
-
- def optional_field(key_path)
- *head, tail = key_path
-
- field(head)[tail]
- end
-
- def field(key_path, wrapper: nil)
- value = key_path.reduce(internal_data, :fetch)
- return value unless wrapper
-
- if value.is_a?(Array)
- value.map { |item| wrapper.new(item) }
- else
- wrapper.new(value)
- end
- end
- end
-
- class Completion < JSONPayload
- class Choice < JSONPayload
- field :text
- field :index
- field :logprobs
- field :finish_reason
- end
-
- class Usage < JSONPayload
- field :prompt_tokens
- field :completion_tokens
- field :total_tokens
- end
-
- field :id
- field :object
- field :created
- field :model
- field :choices, wrapper: Choice
- field :usage, wrapper: Usage
- end
-
- class ChatCompletion < JSONPayload
- class Choice < JSONPayload
- class Message < JSONPayload
- field :role
- field :content
- end
-
- field :index
- field :message, wrapper: Message
- field :finish_reason
- end
-
- class Usage < JSONPayload
- field :prompt_tokens
- field :completion_tokens
- field :total_tokens
- end
-
- field :id
- field :object
- field :created
- field :choices, wrapper: Choice
- field :usage, wrapper: Usage
- end
-
- class Embedding < JSONPayload
- class EmbeddingData < JSONPayload
- field :object
- field :embedding
- field :index
- end
-
- class Usage < JSONPayload
- field :prompt_tokens
- field :total_tokens
- end
-
- field :object
- field :data, wrapper: EmbeddingData
- field :model
- field :usage, wrapper: Usage
- end
-
- class Model < JSONPayload
- field :id
- field :object
- field :owned_by
- field :permission
- end
-
- class ListModel < JSONPayload
- field :data, wrapper: Model
- end
-
- class Edit < JSONPayload
- class Choice < JSONPayload
- field :text
- field :index
- end
-
- class Usage < JSONPayload
- field :prompt_tokens
- field :completion_tokens
- field :total_tokens
- end
-
- field :object
- field :created
- field :choices, wrapper: Choice
- field :usage, wrapper: Usage
- end
-
- class ImageGeneration < JSONPayload
- class Image < JSONPayload
- field :url
- end
-
- field :created
- field :data, wrapper: Image
- end
-
- class File < JSONPayload
- field :id
- field :object
- field :bytes
- field :created_at
- field :filename
- field :purpose
- optional_field :deleted?, path: :deleted
- end
-
- class FileList < JSONPayload
- field :data, wrapper: File
- field :object
- end
-
- class FineTune < JSONPayload
- class Event < JSONPayload
- field :object
- field :created_at
- field :level
- field :message
- end
-
- class Hyperparams < JSONPayload
- field :batch_size
- field :learning_rate_multiplier
- field :n_epochs
- field :prompt_loss_weight
- end
-
- class File < JSONPayload
- field :id
- field :object
- field :bytes
- field :created_at
- field :filename
- field :purpose
- end
-
- field :id
- field :object
- field :model
- field :created_at
- field :events, wrapper: Event
- field :fine_tuned_model
- field :hyperparams, wrapper: Hyperparams
- field :organization_id
- field :result_files, wrapper: File
- field :status
- field :validation_files, wrapper: File
- field :training_files, wrapper: File
- field :updated_at
- end
-
- class FineTuneList < JSONPayload
- field :object
- field :data, wrapper: FineTune
- end
+ def tokenizer
+ Tokenizer.new
+ end
+ alias tokens tokenizer

- class Transcription < JSONPayload
- field :text
- end
+ def chat(model:, history: [], **kwargs)
+ Chat.new(
+ openai: self,
+ settings: kwargs.merge(model: model),
+ messages: history
+ )
  end
  end
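
In short, the monolithic client with its per-endpoint methods and nested Response classes is replaced by a small facade built via OpenAI.create. A minimal usage sketch of the new entry point, inferred from the diff above and the specs below; the cache directory path and model name are illustrative assumptions, not values from this changeset:

require 'openai'

# OpenAI.new is now private; OpenAI.create wires an API::Client (optionally
# wrapped in a filesystem cache) and a logger into the facade.
openai = OpenAI.create(
  ENV.fetch('OPENAI_API_KEY'),
  cache: Pathname.new('tmp/openai-cache'),  # assumed path; caching only engages if the directory exists
  logger: Logger.new($stdout)
)

openai.api     # => OpenAI::API, the versioned HTTP resources
openai.tokens  # => OpenAI::Tokenizer (also available as #tokenizer)

# Stateful chat wrapper; extra keyword arguments become request settings.
chat = openai.chat(model: 'gpt-3.5-turbo', history: [])  # model name assumed
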
data/openai.gemspec CHANGED
@@ -16,7 +16,13 @@ Gem::Specification.new do |spec|
  spec.require_paths = %w[lib]
  spec.executables = []

- spec.add_dependency 'anima', '~> 0.3'
- spec.add_dependency 'concord', '~> 0.1'
- spec.add_dependency 'http', '~> 5.1'
+ spec.required_ruby_version = '>= 2.7'
+
+ spec.add_dependency 'abstract_type', '~> 0.0.7'
+ spec.add_dependency 'anima', '~> 0.3'
+ spec.add_dependency 'concord', '~> 0.1'
+ spec.add_dependency 'http', '>= 4.4', '< 6.0'
+ spec.add_dependency 'ice_nine', '~> 0.11.x'
+ spec.add_dependency 'memoizable', '~> 0.4.2'
+ spec.add_dependency 'tiktoken_ruby', '~> 0.0.3'
  end
data/spec/data/sample_french.mp3 ADDED (Binary file)
data/spec/data/sample_image.png ADDED (Binary file)
data/spec/data/sample_image_mask.png ADDED (Binary file)
data/spec/shared/api_resource_context.rb ADDED
@@ -0,0 +1,22 @@
+ # frozen_string_literal: true
+
+ RSpec.shared_context 'an API Resource' do
+ let(:api) { OpenAI::API.new(api_client) }
+ let(:api_client) { OpenAI::API::Client.new('sk-123', http: http) }
+ let(:http) { class_spy(HTTP) }
+ let(:response_status_code) { 200 }
+
+ let(:response) do
+ instance_double(
+ HTTP::Response,
+ status: HTTP::Response::Status.new(response_status_code),
+ body: JSON.dump(response_body)
+ )
+ end
+
+ before do
+ allow(http).to receive(:post).and_return(response)
+ allow(http).to receive(:get).and_return(response)
+ allow(http).to receive(:delete).and_return(response)
+ end
+ end
data/spec/spec_helper.rb CHANGED
@@ -9,6 +9,10 @@ module OpenAISpec
  SPEC_ROOT = ROOT.join('spec')
  end

+ OpenAISpec::SPEC_ROOT.glob('shared/*.rb').shuffle.each do |shared_spec|
+ require(shared_spec)
+ end
+
  RSpec.configure do |config|
  # Enable focused tests and run all tests if nothing is focused
  config.filter_run_when_matching(:focus)
data/spec/unit/openai/api/audio_spec.rb ADDED
@@ -0,0 +1,78 @@
+ # frozen_string_literal: true
+
+ RSpec.describe OpenAI::API, '#audio' do
+ include_context 'an API Resource'
+
+ let(:resource) { api.audio }
+ let(:sample_audio) { OpenAISpec::SPEC_ROOT.join('data/sample.mp3') }
+
+ context 'when transcribing audio' do
+ let(:response_body) do
+ {
+ "text": "Imagine the wildest idea that you've ever had, and you're curious about how it might scale to something that's a 100, a 1,000 times bigger. This is a place where you can get to do that."
+ }
+ end
+
+ it 'can transcribe audio' do
+ transcription = resource.transcribe(
+ file: sample_audio,
+ model: 'model-1234'
+ )
+
+ expect(http)
+ .to have_received(:post)
+ .with(
+ 'https://api.openai.com/v1/audio/transcriptions',
+ hash_including(
+ form: hash_including(
+ {
+ file: instance_of(HTTP::FormData::File),
+ model: 'model-1234'
+ }
+ )
+ )
+ )
+
+ expect(transcription.text).to eql("Imagine the wildest idea that you've ever had, and you're curious about how it might scale to something that's a 100, a 1,000 times bigger. This is a place where you can get to do that.")
+ end
+ end
+
+ context 'when translating audio' do
+ let(:sample_audio) { OpenAISpec::SPEC_ROOT.join('data/sample_french.mp3') }
+
+ let(:response_body) do
+ {
+ "text": 'Hello, my name is Wolfgang and I come from Germany. Where are you heading today?'
+ }
+ end
+
+ it 'can translate audio' do
+ translation = resource.translate(
+ file: sample_audio,
+ model: 'model-id',
+ prompt: 'Hello, my name is Wolfgang and I come from Germany. Where are you heading today?',
+ response_format: 'text',
+ temperature: 0.5
+ )
+
+ expect(http)
+ .to have_received(:post)
+ .with(
+ 'https://api.openai.com/v1/audio/translations',
+ hash_including(
+ form: hash_including(
+ {
+ file: instance_of(HTTP::FormData::File),
+ model: 'model-id',
+ prompt: 'Hello, my name is Wolfgang and I come from Germany. Where are you heading today?',
+ response_format: 'text',
+ temperature: 0.5
+ }
+ )
+ )
+ )
+
+ expect(translation.text).to eql('Hello, my name is Wolfgang and I come from Germany. Where are you heading today?')
+ end
+ end
+ end
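
The shared context above shows the wiring these specs assume: an OpenAI::API facade over an API::Client. A minimal sketch of calling the audio resource with that wiring; the file paths and the 'whisper-1' model name are illustrative assumptions, since the spec only uses placeholder ids:

api = OpenAI::API.new(OpenAI::API::Client.new(ENV.fetch('OPENAI_API_KEY')))

# POSTs multipart form data to /v1/audio/transcriptions, as asserted above.
text = api.audio.transcribe(file: 'interview.mp3', model: 'whisper-1').text

# POSTs to /v1/audio/translations; extra keyword arguments pass through as form fields.
english = api.audio.translate(file: 'entretien_fr.mp3', model: 'whisper-1').text
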
data/spec/unit/openai/api/cache_spec.rb ADDED
@@ -0,0 +1,115 @@
+ # frozen_string_literal: true
+
+ RSpec.describe OpenAI::API::Cache do
+ let(:cached_client) do
+ described_class.new(client, cache_strategy)
+ end
+
+ let(:client) do
+ instance_double(OpenAI::API::Client, api_key: 'sk-123').tap do |double|
+ %i[get post post_form_multipart delete].each do |method|
+ allow(double).to receive(method).and_return(api_resource)
+ end
+ end
+ end
+
+ let(:api_resource) do
+ JSON.dump(text: 'Wow neat')
+ end
+
+ let(:cache_strategy) do
+ described_class::Strategy::Memory.new
+ end
+
+ it 'wraps the public API of API::Client' do
+ client_public_api =
+ OpenAI::API::Client.public_instance_methods(false) - %i[api_key inspect]
+
+ client_public_api.each do |client_method|
+ expect(cached_client).to respond_to(client_method)
+ end
+ end
+
+ it 'can cache get requests' do
+ cached_client.get('/v1/foo')
+ cached_client.get('/v1/foo')
+ cached_client.get('/v1/bar')
+
+ expect(client).to have_received(:get).with('/v1/foo').once
+ expect(client).to have_received(:get).with('/v1/bar').once
+ end
+
+ it 'can cache JSON post requests' do
+ cached_client.post('/v1/foo', model: 'model1', prompt: 'prompt1') # miss
+ cached_client.post('/v1/foo', model: 'model1', prompt: 'prompt1') # hit
+ cached_client.post('/v1/foo', model: 'model1', prompt: 'prompt2') # miss
+ cached_client.post('/v1/bar', model: 'model1', prompt: 'prompt2') # miss
+ cached_client.post_form_multipart('/v1/foo', model: 'model1', prompt: 'prompt1') # miss
+
+ expect(client).to have_received(:post).thrice
+ expect(client).to have_received(:post_form_multipart).once
+ end
+
+ it 'does not cache delete requests' do
+ cached_client.delete('/v1/foo')
+ cached_client.delete('/v1/foo')
+
+ expect(client).to have_received(:delete).twice
+ end
+
+ it 'can cache multipart form post requests' do
+ cached_client.post_form_multipart('/v1/foo', model: 'model1', prompt: 'prompt1') # miss
+ cached_client.post_form_multipart('/v1/foo', model: 'model1', prompt: 'prompt1') # hit
+ cached_client.post_form_multipart('/v1/foo', model: 'model1', prompt: 'prompt2') # miss
+ cached_client.post_form_multipart('/v1/bar', model: 'model1', prompt: 'prompt2') # miss
+ cached_client.post('/v1/foo', model: 'model1', prompt: 'prompt1') # miss
+
+ expect(client).to have_received(:post_form_multipart).thrice
+ end
+
+ it 'writes unique and somewhat human readable cache keys' do
+ expect(cache_strategy.cached?('get_foo_9bfe1439')).to be(false)
+ cached_client.get('/v1/foo')
+ expect(cache_strategy.cached?('get_foo_9bfe1439')).to be(true)
+ end
+
+ it 'returns identical values for cache hits and misses' do
+ miss = cached_client.get('/v1/foo')
+ hit = cached_client.get('/v1/foo')
+
+ expect(miss).to eq(hit)
+ end
+
+ context 'when the API key changes' do
+ before do
+ allow(client).to receive(:api_key).and_return('sk-123', 'sk-123', 'sk-456')
+ end
+
+ it 'factors the API key into the cache calculation' do
+ cached_client.get('/v1/foo')
+ cached_client.get('/v1/foo')
+ cached_client.get('/v1/foo')
+
+ expect(client).to have_received(:get).with('/v1/foo').twice
+ end
+ end
+
+ context 'when using the filesystem cache strategy' do
+ let(:cache_strategy) do
+ described_class::Strategy::FileSystem.new(cache_dir)
+ end
+
+ let(:cache_dir) do
+ Pathname.new(Dir.mktmpdir)
+ end
+
+ it 'writes JSON files' do
+ cache_path = cache_dir.join('get_foo_9bfe1439.json')
+ expect(cache_path.exist?).to be(false)
+ cached_client.get('/v1/foo')
+ expect(cache_path.exist?).to be(true)
+
+ expect(cache_strategy.read('get_foo_9bfe1439')).to eq(api_resource)
+ end
+ end
+ end
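
For reference, OpenAI.create composes this cache layer automatically when given an existing directory, but the spec above also shows it can be wired by hand. A sketch of that manual composition, grounded in the spec; the exact cache key for /v1/models is inferred from the get_foo_9bfe1439 example and the hash suffix will differ:

require 'tmpdir'

client = OpenAI::API::Client.new(ENV.fetch('OPENAI_API_KEY'))
cached = OpenAI::API::Cache.new(
  client,
  OpenAI::API::Cache::Strategy::FileSystem.new(Pathname.new(Dir.mktmpdir))
  # OpenAI::API::Cache::Strategy::Memory.new is the in-process alternative
)

cached.get('/v1/models')  # miss: hits the network, writes get_models_<hash>.json
cached.get('/v1/models')  # hit: served from the cache file; same body as the miss
cached.delete('/v1/files/file-123')  # delete requests are never cached

api = OpenAI::API.new(cached)  # the facade works over the cached client as well
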