openai.rb 0.0.0 → 0.0.3

Sign up to get free protection for your applications and to get access to all the features.
Files changed (41) hide show
  1. checksums.yaml +4 -4
  2. data/.github/workflows/main.yml +27 -0
  3. data/.rubocop.yml +18 -0
  4. data/.ruby-version +1 -1
  5. data/Gemfile +9 -5
  6. data/Gemfile.lock +29 -24
  7. data/README.md +401 -0
  8. data/bin/console +9 -4
  9. data/lib/openai/api/cache.rb +137 -0
  10. data/lib/openai/api/client.rb +86 -0
  11. data/lib/openai/api/resource.rb +232 -0
  12. data/lib/openai/api/response.rb +384 -0
  13. data/lib/openai/api.rb +75 -0
  14. data/lib/openai/chat.rb +125 -0
  15. data/lib/openai/tokenizer.rb +50 -0
  16. data/lib/openai/util.rb +47 -0
  17. data/lib/openai/version.rb +1 -1
  18. data/lib/openai.rb +38 -357
  19. data/openai.gemspec +9 -3
  20. data/spec/data/sample_french.mp3 +0 -0
  21. data/spec/data/sample_image.png +0 -0
  22. data/spec/data/sample_image_mask.png +0 -0
  23. data/spec/shared/api_resource_context.rb +22 -0
  24. data/spec/spec_helper.rb +4 -0
  25. data/spec/unit/openai/api/audio_spec.rb +78 -0
  26. data/spec/unit/openai/api/cache_spec.rb +115 -0
  27. data/spec/unit/openai/api/chat_completions_spec.rb +130 -0
  28. data/spec/unit/openai/api/completions_spec.rb +125 -0
  29. data/spec/unit/openai/api/edits_spec.rb +40 -0
  30. data/spec/unit/openai/api/embeddings_spec.rb +45 -0
  31. data/spec/unit/openai/api/files_spec.rb +163 -0
  32. data/spec/unit/openai/api/fine_tunes_spec.rb +322 -0
  33. data/spec/unit/openai/api/images_spec.rb +137 -0
  34. data/spec/unit/openai/api/models_spec.rb +98 -0
  35. data/spec/unit/openai/api/moderations_spec.rb +63 -0
  36. data/spec/unit/openai/api/response_spec.rb +203 -0
  37. data/spec/unit/openai/chat_spec.rb +32 -0
  38. data/spec/unit/openai/tokenizer_spec.rb +45 -0
  39. data/spec/unit/openai_spec.rb +47 -736
  40. metadata +97 -7
  41. data/bin/codegen +0 -371
@@ -0,0 +1,384 @@
1
+ # frozen_string_literal: true
2
+
3
+ class OpenAI
4
+ class API
5
class Response
  include Concord.new(:internal_data)
  include AbstractType

  # Raised by a required field reader when a key along its path is absent.
  class MissingFieldError < StandardError
    include Anima.new(:path, :missing_key, :actual_payload)

    # @return [String] the missing key, the path being read and the full payload
    def message
      <<~ERROR
        Missing field #{missing_key.inspect} in response payload!
        Was attempting to access value at path `#{path}`.
        Payload: #{JSON.pretty_generate(actual_payload)}
      ERROR
    end
  end

  class << self
    private

    # Ordered list of field names declared on this class, used by #inspect.
    # It lives in a class-level instance variable, so every subclass has its own list.
    attr_accessor :field_registry
  end

  # Records +field_name+ so #inspect can list it.
  def self.register_field(field_name)
    self.field_registry ||= []
    field_registry << field_name
  end

  # Parses +raw_json+ with symbolized keys and wraps the result in this class.
  def self.from_json(raw_json)
    new(JSON.parse(raw_json, symbolize_names: true))
  end

  # NOTE: IceNine.deep_freeze freezes the given object graph in place (no copy),
  # so a hash passed in by the caller becomes frozen too.
  def initialize(internal_data)
    super(IceNine.deep_freeze(internal_data))
  end

  # Declares a required field by defining a reader method +name+.
  #
  # @param name [Symbol] name of the generated reader
  # @param path [Array<Symbol>, Symbol] keys to walk in the payload (defaults to [name])
  # @param wrapper [Class, nil] Response subclass used to wrap the value
  #   (element-wise when the value is an Array)
  # The reader raises MissingFieldError when any key along +path+ is absent.
  def self.field(name, path: [name], wrapper: nil)
    register_field(name)

    define_method(name) do
      field(path, wrapper: wrapper)
    end
  end

  # Declares an optional field. It works like .field, but the generated reader
  # returns nil when the final key of +path+ is absent.
  # The default +path+ is [name], matching .field; a bare Symbol is also accepted.
  def self.optional_field(name, path: [name], wrapper: nil)
    register_field(name)

    define_method(name) do
      optional_field(path, wrapper: wrapper)
    end
  end

  # @return [Hash] the deep-frozen payload this response wraps
  def original_payload
    internal_data
  end

  # Renders every declared field, e.g. "#<OpenAI::API::Response::Usage prompt_tokens=1 ...>".
  # Reading a required field that is missing raises MissingFieldError.
  def inspect
    attr_list = field_list.map do |field_name|
      "#{field_name}=#{__send__(field_name).inspect}"
    end.join(' ')
    "#<#{self.class} #{attr_list}>"
  end

  private

  # The registry is private in the public API, so we reach it with a private
  # dispatch on our own class. Falls back to [] when no fields were declared.
  def field_list
    self.class.__send__(:field_registry) || []
  end

  # Reads an optional value. Only the final key may be absent; every key before it
  # is required and raises MissingFieldError when missing.
  def optional_field(key_path, wrapper: nil)
    *head, tail = Array(key_path)

    parent = field(head)
    return unless parent.key?(tail)

    wrap_value(parent.fetch(tail), wrapper)
  end

  # Walks +key_path+ through the payload and wraps the value it finds.
  # Raises MissingFieldError naming the first absent key.
  def field(key_path, wrapper: nil)
    path = Array(key_path)
    value = path.reduce(internal_data) do |object, key|
      object.fetch(key) do
        raise MissingFieldError.new(
          path: path,
          missing_key: key,
          actual_payload: internal_data
        )
      end
    end

    wrap_value(value, wrapper)
  end

  # Wraps +value+ in +wrapper+, element-wise when +value+ is an Array.
  # nil (a JSON null) is returned as-is. Wrapping it would produce a response
  # object whose readers all fail with NoMethodError.
  def wrap_value(value, wrapper)
    return value if wrapper.nil? || value.nil?

    if value.is_a?(Array)
      value.map { |item| wrapper.new(item) }
    else
      wrapper.new(value)
    end
  end

  class Usage < Response
    field :prompt_tokens
    field :completion_tokens
    field :total_tokens
  end

  class Completion < Response
    class Choice < Response
      field :text
      field :index
      field :logprobs
      field :finish_reason
    end

    field :id
    field :object
    field :created
    field :model
    field :choices, wrapper: Choice
    optional_field :usage, wrapper: Usage

    # @return [Choice] the only choice
    # @raise [Util::OneError] unless there is exactly one choice
    def choice
      Util.one(choices)
    end

    # This is a convenience method for getting the response text when there is exactly
    # one choice.
    #
    # @see #choice
    def response_text
      choice.text
    end
  end

  class ChatCompletion < Response
    class Choice < Response
      class Message < Response
        field :role
        field :content
      end

      field :index
      field :message, wrapper: Message
      field :finish_reason
    end

    field :id
    field :object
    field :created
    field :choices, wrapper: Choice
    field :usage, wrapper: Usage

    # This is a convenience method for the common use case where you have exactly
    # one choice and you want to get the message out.
    #
    # @see #response_text
    def response
      Util.one(choices).message
    end

    # This is a convenience method for getting the response text when there is exactly
    # one choice.
    #
    # @see #response
    def response_text
      response.content
    end
  end

  # A single streamed chunk of a chat completion.
  class ChatCompletionChunk < Response
    class Delta < Response
      optional_field :role
      optional_field :_content, path: %i[content]

      # @return [String] the content fragment, or "" when the chunk has none
      def content
        _content.to_s
      end
    end

    class Choice < Response
      field :delta, wrapper: Delta
    end

    field :id
    field :object
    field :created
    field :model
    field :choices, wrapper: Choice

    # This is a convenience method for the common use case where you have exactly
    # one choice and you want to get the message out.
    #
    # @see #response_text
    def response
      Util.one(choices).delta
    end

    # This is a convenience method for getting the response text when there is exactly
    # one choice.
    #
    # @see #response
    def response_text
      response.content
    end
  end

  class Embedding < Response
    class EmbeddingData < Response
      field :object
      field :embedding
      field :index
    end

    # Embedding usage has no completion tokens, hence its own Usage class.
    class Usage < Response
      field :prompt_tokens
      field :total_tokens
    end

    field :object
    field :data, wrapper: EmbeddingData
    field :model
    field :usage, wrapper: Usage
  end

  class Model < Response
    field :id
    field :object
    field :owned_by
    field :permission
  end

  class Moderation < Response
    # Payload keys such as "hate/threatening" are not valid Ruby method names,
    # so they are mapped to snake_case readers via explicit paths.
    class Category < Response
      field :hate
      field :hate_threatening, path: %i[hate/threatening]
      field :self_harm, path: %i[self-harm]
      field :sexual
      field :sexual_minors, path: %i[sexual/minors]
      field :violence
      field :violence_graphic, path: %i[violence/graphic]
    end

    class CategoryScore < Response
      field :hate
      field :hate_threatening, path: %i[hate/threatening]
      field :self_harm, path: %i[self-harm]
      field :sexual
      field :sexual_minors, path: %i[sexual/minors]
      field :violence
      field :violence_graphic, path: %i[violence/graphic]
    end

    class Result < Response
      field :categories, wrapper: Category
      field :category_scores, wrapper: CategoryScore
      field :flagged
    end

    field :id
    field :model
    field :results, wrapper: Result
  end

  class ListModel < Response
    field :data, wrapper: Model
  end

  class Edit < Response
    class Choice < Response
      field :text
      field :index
    end

    field :object
    field :created
    field :choices, wrapper: Choice
    field :usage, wrapper: Usage
  end

  class ImageGeneration < Response
    class Image < Response
      field :url
    end

    field :created
    field :data, wrapper: Image
  end

  class ImageEdit < Response
    class ImageEditData < Response
      field :url
    end

    field :created
    field :data, wrapper: ImageEditData
  end

  class ImageVariation < Response
    class ImageVariationData < Response
      field :url
    end

    field :created
    field :data, wrapper: ImageVariationData
  end

  class File < Response
    field :id
    field :object
    field :bytes
    field :created_at
    field :filename
    field :purpose
    optional_field :deleted?, path: :deleted
  end

  class FileList < Response
    field :data, wrapper: File
    field :object
  end

  class FineTune < Response
    class Event < Response
      field :object
      field :created_at
      field :level
      field :message
    end

    class Hyperparams < Response
      field :batch_size
      field :learning_rate_multiplier
      field :n_epochs
      field :prompt_loss_weight
    end

    class File < Response
      field :id
      field :object
      field :bytes
      field :created_at
      field :filename
      field :purpose
    end

    field :id
    field :object
    field :model
    field :created_at
    field :events, wrapper: Event
    field :fine_tuned_model
    field :hyperparams, wrapper: Hyperparams
    field :organization_id
    field :result_files, wrapper: File
    field :status
    field :validation_files, wrapper: File
    field :training_files, wrapper: File
    field :updated_at
  end

  class FineTuneList < Response
    field :object
    field :data, wrapper: FineTune
  end

  class FineTuneEventList < Response
    field :data, wrapper: FineTune::Event
    field :object
  end

  class Transcription < Response
    field :text
  end
end
383
+ end
384
+ end
data/lib/openai/api.rb ADDED
@@ -0,0 +1,75 @@
1
+ # frozen_string_literal: true
2
+
3
+ class OpenAI
4
class API
  include Concord.new(:client)

  # Raised when the API answers with a non-2xx status; wraps the HTTP response.
  class Error < StandardError
    include Concord::Public.new(:http_response)

    # Builds the most specific error for +http_response+.
    #
    # Returns ContextLengthExceeded when the JSON body has error.code set to
    # "context_length_exceeded"; otherwise returns a plain error. Bodies that are not
    # JSON, or whose JSON has an unexpected shape (e.g. `null`, an array, or
    # {"error": "some text"}), also yield a plain error instead of crashing.
    #
    # @param http_response [#body, #status]
    # @return [Error]
    def self.parse(http_response)
      data = JSON.parse(http_response.body.to_s, symbolize_names: true)

      if error_code(data) == 'context_length_exceeded'
        Error::ContextLengthExceeded.new(http_response)
      else
        new(http_response)
      end
    rescue JSON::ParserError
      new(http_response)
    end

    # Extracts error.code from a parsed body.
    # Returns nil unless both the body and its :error entry are hashes. (Hash#dig would
    # raise TypeError when :error holds a String.)
    def self.error_code(data)
      return unless data.is_a?(Hash)

      error = data[:error]
      error[:code] if error.is_a?(Hash)
    end
    private_class_method :error_code

    # @return [String] the unexpected status followed by the raw response body
    def message
      <<~ERROR
        Unexpected response status! Expected 2xx but got: #{http_response.status}

        Body:

        #{http_response.body}
      ERROR
    end

    # Raised when the request exceeded the model's context window.
    class ContextLengthExceeded < self
    end
  end

  # Resource accessors. Each call returns a new resource object bound to +client+.
  # The resource classes are defined elsewhere (presumably lib/openai/api/resource.rb).

  def completions
    API::Completion.new(client)
  end

  def chat_completions
    API::ChatCompletion.new(client)
  end

  def embeddings
    API::Embedding.new(client)
  end

  def models
    API::Model.new(client)
  end

  def edits
    API::Edit.new(client)
  end

  def files
    API::File.new(client)
  end

  def fine_tunes
    API::FineTune.new(client)
  end

  def images
    API::Image.new(client)
  end

  def audio
    API::Audio.new(client)
  end

  def moderations
    API::Moderation.new(client)
  end
end
75
+ end
@@ -0,0 +1,125 @@
1
+ # frozen_string_literal: true
2
+
3
+ class OpenAI
4
class Chat
  include Anima.new(:messages, :api_settings, :openai, :config)
  using Util::Colorize

  # @param messages [Array<Hash, Message>] the conversation so far. Hash entries are
  #   converted with Message.new, so they need exactly the keys :role and :content.
  # @param settings [Hash] parameters sent with every chat completion request
  #   (must contain :model, which the token count reads)
  # @param config [Config] presentation options used when logging
  # @param kwargs [Hash] remaining Anima attributes, such as +openai+
  def initialize(messages:, settings: {}, config: Config.create, **kwargs)
    messages = messages.map do |msg|
      msg.is_a?(Hash) ? Message.new(msg) : msg
    end

    super(
      messages: messages,
      api_settings: settings,
      config: config,
      **kwargs
    )
  end

  # @return [Chat] a copy whose config has +configuration+ applied
  def configure(**configuration)
    with(config: config.with(configuration))
  end

  # Each add_*_message returns a new Chat with the message appended.
  def add_user_message(message)
    add_message('user', message)
  end
  alias user add_user_message

  def add_system_message(message)
    add_message('system', message)
  end
  alias system add_system_message

  def add_assistant_message(message)
    add_message('assistant', message)
  end
  alias assistant add_assistant_message

  # Sends the conversation to the chat completions API.
  # Returns a new Chat with the reply appended; this chat is not modified.
  #
  # When the API reports that the context length was exceeded, the oldest
  # non-system message is dropped and the request is retried. Once only system
  # messages remain, RuntimeError "Context length exceeded." is raised.
  #
  # @return [Chat]
  def submit
    openai.logger.info("[Chat] [tokens=#{total_tokens}] Submitting messages:\n\n#{to_log_format}")

    begin
      response = openai.api.chat_completions.create(
        **api_settings,
        messages: raw_messages
      )
    rescue OpenAI::API::Error::ContextLengthExceeded
      shifted = shift_history
      # Nothing left that may be dropped: give up instead of resending the same request.
      raise 'Context length exceeded.' unless shifted

      openai.logger.warn('[Chat] Context length exceeded. Shifting chat')
      return shifted.submit
    end

    msg = response.choices.first.message

    add_message(msg.role, msg.content).tap do |new_chat|
      openai.logger.info("[Chat] Response:\n\n#{new_chat.last_message.to_log_format(config)}")
    end
  end

  # @return [Message, nil] the most recent message
  def last_message
    messages.last
  end

  # @return [String] every message in log format, separated by blank lines
  def to_log_format
    messages.map do |msg|
      msg.to_log_format(config)
    end.join("\n\n")
  end

  private

  # Returns a copy of this chat without its oldest non-system message.
  # Returns nil when every message is a system message. System prompts are never dropped.
  def shift_history
    drop_index = messages.index { |msg| msg.role != 'system' }
    return unless drop_index

    with(messages: messages.reject.with_index { |_msg, index| index == drop_index })
  end

  # Token estimate for logging. It counts only the message contents joined by spaces,
  # using the tokenizer for api_settings[:model].
  def total_tokens
    openai.tokens.for_model(api_settings.fetch(:model)).num_tokens(messages.map(&:content).join(' '))
  end

  # @return [Array<Hash>] messages in the shape sent to the API
  def raw_messages
    messages.map(&:to_h)
  end

  def add_message(role, content)
    with_message(role: role, content: content)
  end

  # The appended Hash is converted to a Message by #initialize when #with rebuilds the chat.
  def with_message(message)
    with(messages: messages + [message])
  end

  # Presentation settings for log output.
  class Config
    include Anima.new(:assistant_name)

    # @return [Config] defaults: the assistant is labelled "assistant"
    def self.create
      new(assistant_name: 'assistant')
    end
  end

  # One chat message with a role ('user', 'system' or 'assistant') and its content.
  class Message
    include Anima.new(:role, :content)

    # Prefixes the content with a colored, upper-cased role label.
    # The assistant label comes from +config.assistant_name+.
    #
    # @raise [RuntimeError] for any other role
    def to_log_format(config)
      prefix =
        case role
        when 'user' then "#{role}:".upcase.green
        when 'system' then "#{role}:".upcase.yellow
        when 'assistant' then "#{config.assistant_name}:".upcase.red
        else
          raise "Unknown role: #{role}"
        end

      "#{prefix} #{content}"
    end
  end
end
125
+ end
@@ -0,0 +1,50 @@
1
+ # frozen_string_literal: true
2
+
3
+ class OpenAI
4
class Tokenizer
  include Equalizer.new

  UnknownModel = Class.new(StandardError)
  UnknownEncoding = Class.new(StandardError)

  # Resolves the Tiktoken encoding associated with +model+.
  #
  # @param model [String] model name
  # @return [Encoding]
  # @raise [UnknownModel] when Tiktoken returns no encoding for +model+
  def for_model(model)
    resolved = Tiktoken.encoding_for_model(model)
    return Encoding.new(resolved.name) unless resolved.nil?

    raise UnknownModel, "Invalid model name or not recognized by Tiktoken: #{model.inspect}"
  end

  # Looks up a Tiktoken encoding directly by its name.
  #
  # @param encoding_name [String]
  # @return [Encoding]
  # @raise [UnknownEncoding] when Tiktoken returns no encoding for +encoding_name+
  def get(encoding_name)
    resolved = Tiktoken.get_encoding(encoding_name)
    return Encoding.new(resolved.name) unless resolved.nil?

    raise UnknownEncoding, "Invalid encoding name or not recognized by Tiktoken: #{encoding_name.inspect}"
  end

  # A Tiktoken encoding identified by its name.
  class Encoding
    include Concord.new(:name)

    # @return token ids for +text+, as produced by Tiktoken
    def encode(text)
      encoder.encode(text)
    end
    alias_method :tokenize, :encode

    # @return [String] text reconstructed from +tokens+
    def decode(tokens)
      encoder.decode(tokens)
    end

    # @return [Integer] how many tokens +text+ encodes to
    def num_tokens(text)
      tokens = encode(text)
      tokens.size
    end

    private

    # The Tiktoken encoder is looked up again on every call; it is not memoized here.
    def encoder
      Tiktoken.get_encoding(name)
    end
  end
end
50
+ end
@@ -0,0 +1,47 @@
1
+ # frozen_string_literal: true
2
+
3
+ class OpenAI
4
module Util
  OneError = Class.new(ArgumentError)

  # Returns the only element of +list+.
  #
  # @param list [#size, #first] collection expected to hold exactly one element
  # @return [Object] that element
  # @raise [OneError] when +list+ is empty or holds more than one element
  def self.one(list)
    count = list.size
    return list.first if count == 1

    raise OneError, "Expected exactly one element, got #{count}"
  end

  # String refinement adding ANSI foreground-color helpers ("text".red etc.).
  # Activate it with `using Util::Colorize`.
  module Colorize
    refine String do
      def cyan
        colorize(36)
      end

      def magenta
        colorize(35)
      end

      def blue
        colorize(34)
      end

      def yellow
        colorize(33)
      end

      def green
        colorize(32)
      end

      def red
        colorize(31)
      end

      private

      # Wraps the receiver in the escape sequence for +code+ followed by a reset ("\e[0m").
      def colorize(code)
        "\e[#{code}m#{self}\e[0m"
      end
    end
  end
end
47
+ end
@@ -1,5 +1,5 @@
1
1
  # frozen_string_literal: true
2
2
 
3
3
  class OpenAI
4
- VERSION = '0.0.0'
4
+ VERSION = '0.0.3'
5
5
  end