instruct 0.1.0a1

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (38)
  1. checksums.yaml +7 -0
  2. data/LICENSE +202 -0
  3. data/README.md +387 -0
  4. data/SCRATCHPAD.md +489 -0
  5. data/lib/instruct/compile_erb.rb +39 -0
  6. data/lib/instruct/env.rb +27 -0
  7. data/lib/instruct/error.rb +4 -0
  8. data/lib/instruct/gen/completion_request.rb +63 -0
  9. data/lib/instruct/gen/completion_response.rb +66 -0
  10. data/lib/instruct/gen/gen.rb +70 -0
  11. data/lib/instruct/gen/generate_completion.rb +61 -0
  12. data/lib/instruct/helpers/erb_helper.rb +29 -0
  13. data/lib/instruct/helpers/gen_helper.rb +22 -0
  14. data/lib/instruct/helpers/helpers.rb +9 -0
  15. data/lib/instruct/helpers/model_helper.rb +13 -0
  16. data/lib/instruct/helpers/refinements.rb +54 -0
  17. data/lib/instruct/llms/anthropic/completion_model.rb +107 -0
  18. data/lib/instruct/llms/anthropic/messages_completion_response.rb +35 -0
  19. data/lib/instruct/llms/anthropic/middleware.rb +91 -0
  20. data/lib/instruct/llms/openai/chat_completion_response.rb +21 -0
  21. data/lib/instruct/llms/openai/completion_model.rb +129 -0
  22. data/lib/instruct/llms/openai/completion_response.rb +20 -0
  23. data/lib/instruct/llms/openai/middleware.rb +52 -0
  24. data/lib/instruct/middleware/chat_completion_middleware.rb +90 -0
  25. data/lib/instruct/middleware/chomp_middleware.rb +56 -0
  26. data/lib/instruct/model.rb +21 -0
  27. data/lib/instruct/prompt.rb +217 -0
  28. data/lib/instruct/rails/active_job_object_serializer.rb +23 -0
  29. data/lib/instruct/rails/active_record_coders.rb +36 -0
  30. data/lib/instruct/railtie.rb +15 -0
  31. data/lib/instruct/utils/middleware_chain.rb +48 -0
  32. data/lib/instruct/utils/serializable_with_version.rb +73 -0
  33. data/lib/instruct/utils/serializer.rb +70 -0
  34. data/lib/instruct/utils/symbolize_keys.rb +22 -0
  35. data/lib/instruct/utils/variables.rb +37 -0
  36. data/lib/instruct/version.rb +3 -0
  37. data/lib/instruct.rb +74 -0
  38. metadata +122 -0
data/lib/instruct/gen/gen.rb
@@ -0,0 +1,70 @@
+ module Instruct
+   class Gen
+     include Instruct::Serializable
+     set_instruct_class_id 2
+
+     attr_accessor :prompt, :model, :gen_kwargs
+     attr_accessor :capture_key, :capture_list_key
+     attr_reader :results
+
+     def initialize(prompt:, model:, **kwargs)
+       @prompt = prompt
+       @model = model
+       @gen_kwargs = kwargs
+       @results = []
+       @capture_key = nil
+       @capture_list_key = nil
+     end
+
+     def ==(other)
+       return false unless other.is_a?(Gen)
+       # Skip comparing prompt and results for now: otherwise a prompt whose
+       # gen has run would never equal one whose gen hasn't.
+       return false if @gen_kwargs != other.gen_kwargs
+       return false if @model.is_a?(String) && other.model.is_a?(String) && @model != other.model
+       @model.class == other.model.class
+     end
+
+     def capture(key, list: nil)
+       @capture_key, @capture_list_key = key, list
+       self
+     end
+
+     def completed?
+       @results.any?
+     end
+
+     # Calls the LLM API with the prompt and creates a completion.
+     # @param model a model object or the name of a model
+     # @param client_opts an optional hash of options passed to the API client when a client model is initialized from a string
+     # @yield an optional streaming block called with each chunk of the response when the response is streamed
+     def call(model: nil, **call_kwargs, &streaming_block)
+       gen_and_call_kwargs = gen_kwargs.merge(call_kwargs)
+       model = select_first_model_from(model, @model, Instruct.default_model, gen_and_call_kwargs:)
+
+       generate_completion = Instruct::GenerateCompletion.new(prompt:, model:, capture_key:, capture_list_key:, streaming_block:, gen_and_call_kwargs:)
+       completion = generate_completion.call(calling_gen: self)
+
+       @results << completion
+       completion
+     end
+
+     def to_s
+       if @results.empty?
+         "<Instruct::Gen>"
+       else
+         "<Instruct::Gen call_count=#{@results.length}>"
+       end
+     end
+
+     private
+
+     def select_first_model_from(*args, gen_and_call_kwargs:)
+       model = args.compact.first
+       model = Instruct::Model.from_string(model, **gen_and_call_kwargs) if model.is_a?(String)
+       model
+     end
+   end
+ end
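
For orientation, a minimal sketch of driving a Gen directly (the prompt text, model name, and capture key are illustrative; Instruct::Prompt and the model setup are assumed to be configured as elsewhere in this gem):

    gen = Instruct::Gen.new(prompt: Instruct::Prompt.new("List three Ruby web servers: "),
                            model: "claude-3-5-sonnet-latest")
    gen.capture(:servers)        # tag the eventual completion with a capture key
    completion = gen.call        # calls the LLM and records the result in gen.results
    gen.completed?               # => true once at least one result is stored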
data/lib/instruct/gen/generate_completion.rb
@@ -0,0 +1,61 @@
+ module Instruct
+   class GenerateCompletion
+     def initialize(prompt:, model:, streaming_block: nil, capture_key:, capture_list_key:, gen_and_call_kwargs:)
+       @prompt = prompt
+       @model = model
+       @streaming_block = streaming_block
+       @capture_key = capture_key
+       @capture_list_key = capture_list_key
+       @gen_and_call_kwargs = gen_and_call_kwargs
+       @run = false
+     end
+
+     def call(calling_gen:)
+       raise RuntimeError, "Cannot call a completed Gen" if @run
+       @run = true
+
+       @original_prompt = @prompt.dup
+       completion = Prompt::Completion.new
+       prompt = prompt_with_gen_attachment_removed(calling_gen)
+       @request = Gen::CompletionRequest.new(prompt: prompt, completion: completion, env: build_request_env)
+       if @streaming_block
+         @request.add_stream_handler do |response|
+           response = prepare_completion_for_return(response)
+           @streaming_block.call(response)
+         end
+       end
+       middleware = build_model_middleware_chain(@request)
+       response = middleware.execute(@request)
+       completion = response.attributed_string
+       prepare_completion_for_return(completion)
+     end
+
+     private
+
+     def prepare_completion_for_return(completion)
+       completion._prepare_for_return(prompt: @original_prompt, captured_key: @capture_key, captured_list_key: @capture_list_key, updated_prompt: @request.prompt)
+       completion
+     end
+
+     def build_request_env
+       @model.default_request_env.merge(@gen_and_call_kwargs)
+     end
+
+     def build_model_middleware_chain(request)
+       if @model.respond_to?(:middleware_chain)
+         @model.middleware_chain(request)
+       else
+         @model
+       end
+     end
+
+     def prompt_with_gen_attachment_removed(calling_gen)
+       if calling_gen && @prompt.attachment_at(@prompt.length - 1) == calling_gen
+         @prompt[...-1]
+       else
+         @prompt.dup
+       end
+     end
+   end
+ end
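
Because build_model_middleware_chain falls back to the model itself when it exposes no middleware_chain, anything answering execute(request) and default_request_env can stand in for a model. A hedged sketch of such a stub (FakeResponse and EchoModel are invented names; the real flow then calls _prepare_for_return on the returned string, which a bare prompt may or may not support):

    FakeResponse = Struct.new(:attributed_string)

    class EchoModel
      def default_request_env
        {}
      end

      # GenerateCompletion#call sends the request straight here when no
      # middleware chain is present; echoing the prompt back is enough
      # for a smoke test.
      def execute(request)
        FakeResponse.new(request.prompt)
      end
    end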
data/lib/instruct/helpers/erb_helper.rb
@@ -0,0 +1,29 @@
+ module Instruct
+   module Helpers
+     P_HELPER_ERROR_MESSAGE = "the p(rompt) helpers should be called with a block, p{<arg>}, not p(<arg>)"
+     module ERBHelper
+       def p(*args, &block)
+         raise ArgumentError, P_HELPER_ERROR_MESSAGE if args.length > 0
+         if block_given?
+           Instruct::CompileERB.new(template: yield, _binding: block.binding).prompt
+         else
+           P.new
+         end
+       end
+     end
+     class P
+       def system(*args, &block)
+         raise ArgumentError, P_HELPER_ERROR_MESSAGE if args.length > 0
+         Prompt.new("\nsystem: ", safe: true) + Instruct::CompileERB.new(template: yield, _binding: block.binding).prompt
+       end
+       def user(*args, &block)
+         raise ArgumentError, P_HELPER_ERROR_MESSAGE if args.length > 0
+         Prompt.new("\nuser: ", safe: true) + Instruct::CompileERB.new(template: yield, _binding: block.binding).prompt
+       end
+       def assistant(*args, &block)
+         raise ArgumentError, P_HELPER_ERROR_MESSAGE if args.length > 0
+         Prompt.new("\nassistant: ", safe: true) + Instruct::CompileERB.new(template: yield, _binding: block.binding).prompt
+       end
+     end
+   end
+ end
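
A usage sketch of the p helper (assuming Instruct::Helpers is included into the caller; the ERB template sees locals through the captured block binding):

    include Instruct::Helpers

    name = "Ada"
    prompt = p.system { "You are a terse assistant." } +
             p.user   { "Greet <%= name %> in one sentence." }
    # p { "..." } with no role compiles a bare ERB template into a Prompt;
    # p with no block returns the P role builder used above.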
data/lib/instruct/helpers/gen_helper.rb
@@ -0,0 +1,22 @@
+ module Instruct::Helpers
+   module GenHelper
+     # This helper is used to create a new Instruct::Gen object. It can be
+     # used in two ways: with a prompt or without. If a prompt is provided,
+     # the method will immediately return the generated completion. If no
+     # prompt is provided, the method will return a deferred completion that
+     # can be appended to a prompt.
+     # @param prompt [Instruct::Prompt, String, nil] The prompt to generate a completion for.
+     # @param model [Instruct::Model, String, nil] The model to use for generation.
+     # @param client_opts [Hash] Optional hash of options passed to the API client when a client model is initialized from a string.
+     def gen(prompt = nil, model: nil, **kwargs)
+       prompt = Instruct::Prompt.new(prompt) if prompt.class == String
+       model ||= self.respond_to?(:instruct_default_model) ? self.instruct_default_model : nil
+       gen = Instruct::Gen.new(prompt:, model:, **kwargs)
+
+       return gen.call if prompt
+
+       Instruct::Prompt.new.add_attachment(gen)
+     end
+   end
+ end
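
The two modes in practice, as a hedged sketch (model names and prompt text are illustrative):

    include Instruct::Helpers

    # Immediate: a prompt is given, so gen calls the model right away.
    answer = gen("Name a prime number under 10.", model: "gpt-4o-mini")

    # Deferred: no prompt, so gen returns a prompt fragment carrying the Gen
    # as an attachment, to be appended onto a larger prompt later.
    prompt = p { "Q: What is 2 + 2?\nA: " } + gen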
data/lib/instruct/helpers/helpers.rb
@@ -0,0 +1,9 @@
+ module Instruct
+   module Helpers
+     include Instruct::Helpers::GenHelper
+     include Instruct::Helpers::ERBHelper
+     include Instruct::Helpers::ModelHelper
+
+
+   end
+ end
data/lib/instruct/helpers/model_helper.rb
@@ -0,0 +1,13 @@
+ module Instruct::Helpers
+   module ModelHelper
+
+     def instruct_default_model
+       @_instruct_default_model ||= Instruct.default_model
+     end
+
+     def instruct_default_model=(string_or_model)
+       @_instruct_default_model = Instruct::Model.from_string_or_model(string_or_model)
+     end
+
+   end
+ end
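
A sketch of scoping a default model to one object rather than the whole process (the class name and model string are illustrative):

    class SummaryBuilder
      include Instruct::Helpers
    end

    builder = SummaryBuilder.new
    builder.instruct_default_model = "claude-3-5-haiku-latest"
    # subsequent gen calls on builder now fall back to this model instead
    # of the process-wide Instruct.default_model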
data/lib/instruct/helpers/refinements.rb
@@ -0,0 +1,54 @@
+ module Instruct
+   module Refinements
+     refine String do
+       # alias_method :old_double_arrow, :<<
+       # private :old_double_arrow
+       def <<(other)
+         if other.is_a?(Prompt) || other.is_a?(Prompt::Completion)
+           raise Instruct::Error, <<~ERR.chomp
+             A plain String cannot append an Instruct::Prompt (the become gem could make the string
+             become a prompt in place). Convert your string to an Instruct::Prompt first, either with
+             Instruct::Prompt.new or "safe string".prompt_safe.
+           ERR
+         else
+           super
+         end
+       end
+       # alias_method :instruct_old_plus, :+
+       # private :instruct_old_plus
+
+       def +(other)
+         if other.is_a?(Prompt) || other.is_a?(Prompt::Completion)
+           Prompt.new(self) + other
+         else
+           super
+         end
+       end
+
+       def prompt_safe
+         string = self.is_a?(AttributedString) ? self : Prompt.new(self)
+         string.add_attrs(safe: true)
+       end
+     end
+     # alias_method :instruct_old_plus, :+
+     # private :instruct_old_plus
+
+     # def +(other)
+     #   if other.is_a?(Instruct::Expression::Expression)
+     #     wrapped = Instruct::Expression::PlainText.new(self)
+     #     Instruct::Expression::Concat.new(wrapped, other)
+     #   else
+     #     instruct_old_plus(other)
+     #   end
+     # end
+
+     # end
+     # refine Object do
+     #   def erb(safe: nil, &block)
+     #     Instruct::Expression::ERBFuture.new(template: block.call, binding: block.binding, safe:)
+     #   end
+     #   def gen(**kwargs)
+     #     Instruct::Expression::LLMFuture.new(**kwargs)
+     #   end
+     # end
+   end
+ end
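
A sketch of the refinement in use (refinements are file-scoped; without it, plain String#+ behavior depends on Prompt's implementation, but the prompt's attributes would at best be lost):

    using Instruct::Refinements

    prompt = "You are a poet. " + Instruct::Prompt.new("Write a haiku.")
    # the left-hand String is promoted to a Prompt before concatenation

    trusted = "system: be brief".prompt_safe   # wraps and marks attrs safe: true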
data/lib/instruct/llms/anthropic/completion_model.rb
@@ -0,0 +1,107 @@
+ module Instruct
+   class Anthropic
+     include Instruct::Serializable
+     set_instruct_class_id 200
+
+     # @param client_or_model_name [Anthropic::Client, String] Client instance or model name string
+     # @param model [String] Required model name to use for completion when a client is given as the first arg
+     attr_reader :default_request_env
+     def initialize(client_or_model_name = "claude-3-5-sonnet-latest", middlewares: [], **kwargs)
+       @middlewares = middlewares
+       @default_request_env = kwargs
+       @cached_clients = {}
+
+       if client_or_model_name.is_a? ::Anthropic::Client
+         @client = client_or_model_name
+         @model_name = kwargs.delete(:model) if kwargs[:model]
+         raise ArgumentError, "model: keyword argument must be a model name string when initializing with a client (see https://docs.anthropic.com/claude/docs/models-overview)" if @model_name.nil? || @model_name.empty?
+       elsif client_or_model_name.is_a? String
+         @model_name = client_or_model_name
+         raise ArgumentError, "Model name must not be blank (see https://docs.anthropic.com/claude/docs/models-overview)" if @model_name.empty?
+       else
+         raise ArgumentError, "arg must be a model name string (see https://docs.anthropic.com/claude/docs/models-overview) or an instance of Anthropic::Client"
+       end
+
+       append_default_middleware_if_not_added
+       set_access_token_from_env_if_needed
+     end
+
+     def middleware_chain(req)
+       @middleware_chain ||= Instruct::MiddlewareChain.new(middlewares: (@middlewares || []) << self)
+     end
+
+     def call(req, _next:)
+       client = build_client(req.env[:anthropic_client_opts])
+       messages_params = req.env[:anthropic_messages_opts] || {}
+       messages_params[:model] = @model_name if messages_params[:model].nil?
+       warn_about_latest_model_if_needed(messages_params[:model])
+       messages_params[:max_tokens] ||= max_tokens_if_not_set(messages_params[:model])
+       messages_params.merge!(req.prompt_object)
+
+       response = Instruct::Anthropic::MessagesCompletionResponse.new(**req.response_kwargs)
+       messages_params[:stream] = Proc.new { |chunk| response.call(chunk) }
+
+       begin
+         Instruct.logger.info("Sending Anthropic Messages Completion Request: (#{messages_params}) Client:(#{client.inspect})") if Instruct.logger.sev_threshold <= Logger::INFO
+         _client_response = client.messages(parameters: messages_params)
+       rescue Faraday::Error => e
+         if e.respond_to?(:response_body)
+           Instruct.err_logger.error("#{e.response_body}")
+         else
+           Instruct.err_logger.error("#{e.inspect}")
+         end
+         raise e
+       end
+
+       response
+     end
+
+     def max_tokens_if_not_set(model_name)
+       if model_name.include?("claude-3-5-sonnet")
+         8192
+       else
+         4096
+       end
+     end
+
+     @@warned_about_latest_model = false
+     def self.warned_about_latest_model?
+       @@warned_about_latest_model
+     end
+
+     protected
+
+     def append_default_middleware_if_not_added
+       [Instruct::ChompMiddleware, Instruct::ChatCompletionMiddleware, Instruct::Anthropic::Middleware].each do |middleware|
+         if !@middlewares.any? { |m| m.is_a?(middleware) }
+           @middlewares << middleware.new
+         end
+       end
+     end
+
+     private
+
+     def build_client(req_client_opts = {})
+       if @client
+         raise ArgumentError, "Client options must not be set when initializing with a client" if req_client_opts.any?
+         return @client
+       end
+
+       @cached_clients[req_client_opts.hash] ||= ::Anthropic::Client.new(req_client_opts)
+     end
+
+     def set_access_token_from_env_if_needed
+       access_key = ENV["ANTHROPIC_ACCESS_TOKEN"] || ENV["ANTHROPIC_API_KEY"]
+       @default_request_env[:access_token] = access_key if access_key && @default_request_env[:access_token].nil?
+     end
+
+     def warn_about_latest_model_if_needed(model_name)
+       return if Instruct.suppress_warnings
+       if model_name.end_with?("latest") && !@@warned_about_latest_model
+         puts "Warning: You are using an anthropic model with the 'latest' suffix. This is alright for development, but not recommended for production. See https://docs.anthropic.com/en/docs/about-claude/models for more information."
+         @@warned_about_latest_model = true
+       end
+     end
+   end
+ end
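
The two construction forms, sketched (the access token normally comes from ANTHROPIC_ACCESS_TOKEN / ANTHROPIC_API_KEY; the client constructor below follows the anthropic gem's API and is an assumption here):

    # From a model name; remaining kwargs become the default request env.
    model = Instruct::Anthropic.new("claude-3-5-sonnet-latest", temperature: 0.2)

    # From a pre-built client; the model name must then be passed explicitly.
    client = ::Anthropic::Client.new(access_token: ENV["ANTHROPIC_API_KEY"])
    model = Instruct::Anthropic.new(client, model: "claude-3-5-haiku-20241022")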
data/lib/instruct/llms/anthropic/messages_completion_response.rb
@@ -0,0 +1,35 @@
+ # frozen_string_literal: true
+
+ class Instruct::Anthropic
+   class MessagesCompletionResponse < Instruct::Gen::CompletionResponse
+     # NOTE: assigned in the class body, so this is a class-level ivar; the
+     # per-instance @delta_finish_reason simply defaults to nil anyway.
+     @delta_finish_reason = nil
+
+     def call(chunk)
+       chunk = Instruct::SymbolizeKeys.recursive(chunk)
+       case chunk
+       in { type: "message_start" }
+         # do nothing
+       in { type: "content_block_start", index: _index, content_block: { type: "text", text: text } }
+         append_text_chunk(text)
+       in { type: "content_block_delta", index: _index, delta: { type: "text_delta", text: text } }
+         append_text_chunk(text)
+       in { type: "content_block_stop", index: _index }
+         # do nothing
+       in { type: "message_delta", delta: { stop_reason: } }
+         # arrives just before message_stop and lets us collect the stop reason (and other info like output tokens)
+         @delta_finish_reason = stop_reason
+       in { type: "message_stop" }
+         done(@delta_finish_reason)
+       in { type: "ping" }
+         # do nothing
+       in { error: { message:, type: } }
+         raise RuntimeError, "Anthropic Client Error: (type: #{type}, message: #{message})"
+       else
+         raise RuntimeError, "Unexpected Chunk: #{chunk}"
+       end
+       chunk_processed
+     end
+   end
+ end
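
For reference, a typical abbreviated event sequence from the Messages streaming API, one hash per branch of the matcher above (keys arrive as strings and are symbolized before matching):

    [
      { "type" => "message_start" },
      { "type" => "content_block_start", "index" => 0,
        "content_block" => { "type" => "text", "text" => "" } },
      { "type" => "content_block_delta", "index" => 0,
        "delta" => { "type" => "text_delta", "text" => "Hi" } },
      { "type" => "content_block_stop", "index" => 0 },
      { "type" => "message_delta", "delta" => { "stop_reason" => "end_turn" } },
      { "type" => "message_stop" },
    ]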
data/lib/instruct/llms/anthropic/middleware.rb
@@ -0,0 +1,91 @@
+ class Instruct::Anthropic
+   class Middleware
+     include Instruct::Serializable
+     set_instruct_class_id 201
+
+     CLIENT_PARAMS = %i[access_token anthropic_version api_version extra_headers request_timeout uri_base beta].freeze
+     # TODO: make request params settable at the model level; it's silly to not set temperature in one place
+     REQUEST_PARAMS = %i[metadata max_tokens temperature tools tool_choice top_k top_p stop_sequences system].freeze
+
+     def call(req, _next:)
+       raise Instruct::Todo, "Tools are not supported yet, consider opening a pull request" if req.env[:tools] || req.env[:tool_choice]
+
+       # pull the client options out of the call and put them in anthropic_client_opts
+       client_options = filter_env_keys(req, CLIENT_PARAMS)
+       transform_beta_argument_into_extra_headers(client_options, req.env[:beta])
+
+       req.env[:anthropic_client_opts] = client_options
+
+       # pull out the message request params and put them in anthropic_messages_opts
+       request_options = filter_env_keys(req, REQUEST_PARAMS)
+       if request_options[:system].nil?
+         # TODO: this will probably go back into the chat completion middleware and can be removed
+         request_options[:system] = req.env[:system_from_prompt].to_s
+       end
+       normalize_stop_sequence_arguments(req, request_options)
+
+       req.env[:anthropic_messages_opts] = request_options
+
+       req.add_prompt_transform do |prompt_obj|
+         transform(prompt_obj)
+       end
+
+       _next.call(req)
+     end
+
+     private
+
+     def filter_env_keys(req, keys)
+       req.env.select { |k, _| keys.include?(k) }
+     end
+
+     def normalize_stop_sequence_arguments(req, request_options)
+       # make the stop_chars and stop options consistent with openai
+       if req.env[:stop_chars].is_a?(String)
+         request_options[:stop_sequences] = req.env[:stop_chars].split('')
+       end
+       if req.env[:stop].is_a?(String)
+         request_options[:stop_sequences] = [req.env[:stop]]
+       elsif req.env[:stop].is_a?(Array)
+         request_options[:stop_sequences] = req.env[:stop]
+       end
+     end
+
+     def transform(prompt_obj)
+       raise RuntimeError, "Expected hash with messages, probably missing chat completion middleware" unless prompt_obj.is_a?(Hash) && prompt_obj[:messages].is_a?(Array)
+       remove_system_message(prompt_obj)
+       convert_messages_to_anthropic_format(prompt_obj)
+       prompt_obj
+     end
+
+     def remove_system_message(prompt_obj)
+       prompt_obj[:messages].reject! { |message| message.keys.first == :system }
+     end
+
+     def convert_messages_to_anthropic_format(prompt_obj)
+       prompt_obj[:messages].map! do |message|
+         { role: message.keys.first, content: message.values.first.to_s }
+       end
+     end
+
+     # Takes the beta argument and turns it into an anthropic-beta header for the client.
+     def transform_beta_argument_into_extra_headers(client_options, beta)
+       return unless beta
+       client_options.delete(:beta)
+       client_options[:extra_headers] ||= {}
+
+       if client_options[:extra_headers]['anthropic-beta']
+         raise ArgumentError, "Cannot set anthropic-beta header to #{beta} when it is already set to #{client_options[:extra_headers]['anthropic-beta']}."
+       end
+
+       if beta.is_a?(Array)
+         client_options[:extra_headers]['anthropic-beta'] = beta.join(',')
+       elsif beta.is_a?(String)
+         client_options[:extra_headers]['anthropic-beta'] = beta
+       else
+         raise ArgumentError, "beta must be a string or an array of strings"
+       end
+     end
+   end
+ end
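
Roughly, this middleware splits a flat request env into the two option hashes the model reads; a hedged before/after sketch (values illustrative):

    env = { access_token: "sk-...", beta: "prompt-caching-2024-07-31",
            max_tokens: 512, stop: "\n" }

    # after #call:
    # env[:anthropic_client_opts]
    #   => { access_token: "sk-...",
    #        extra_headers: { 'anthropic-beta' => 'prompt-caching-2024-07-31' } }
    # env[:anthropic_messages_opts]
    #   => { max_tokens: 512, stop_sequences: ["\n"], system: "..." }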
data/lib/instruct/llms/openai/chat_completion_response.rb
@@ -0,0 +1,21 @@
+ class Instruct::OpenAI
+   class ChatCompletionResponse < Instruct::Gen::CompletionResponse
+
+     def call(chunk)
+       case Instruct::SymbolizeKeys.recursive(chunk)
+       # TODO: check if this will break if the content is not text
+       in { choices: [{ delta: {}, finish_reason: }] }
+         done(finish_reason) unless finish_reason.nil?
+       in { choices: [{ delta: { content: new_content }, finish_reason: }] }
+         append_text_chunk(new_content)
+         done(finish_reason) unless finish_reason.nil?
+       in { error: { message: } }
+         raise RuntimeError, "OpenAI Client Error: #{message}"
+       else
+         raise RuntimeError, "Unexpected Chunk: #{chunk}"
+       end
+       chunk_processed
+     end
+   end
+ end
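
The three chunk shapes (post-symbolization) that the matcher above distinguishes, sketched:

    content = { choices: [{ delta: { content: "Hel" }, finish_reason: nil }] }  # appends text
    final   = { choices: [{ delta: {}, finish_reason: "stop" }] }               # signals done
    failure = { error: { message: "rate limited" } }                            # raises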
data/lib/instruct/llms/openai/completion_model.rb
@@ -0,0 +1,129 @@
+ module Instruct
+   class OpenAI
+     include Instruct::Serializable
+     set_instruct_class_id 100
+
+     attr_reader :default_request_env
+
+     def middleware_chain(req)
+       middlewares = @middlewares || []
+       append_default_middleware_if_not_added(req, middlewares)
+       @middleware_chain ||= Instruct::MiddlewareChain.new(middlewares: (middlewares << self))
+     end
+
+     def initialize(client_or_model_name = 'gpt-3.5-turbo-instruct', middlewares: [], **kwargs)
+       @middlewares = middlewares
+       @default_request_env = kwargs
+       @cached_clients = {}
+
+       if client_or_model_name.is_a? ::OpenAI::Client
+         @client = client_or_model_name
+         @model_name = kwargs.delete(:model) if kwargs[:model]
+         raise ArgumentError, "model: keyword argument must be a model name string when initializing with a client" if @model_name.nil? || @model_name.empty?
+       elsif client_or_model_name.is_a? String
+         @model_name = client_or_model_name
+         raise ArgumentError, "Model name must not be blank" if @model_name.empty?
+       else
+         raise ArgumentError, "arg must be a model name string or an instance of OpenAI::Client"
+       end
+
+       set_access_token_from_env_if_needed
+     end
+
+     def call(req, _next:)
+       client = build_client(req.env[:openai_client_opts] || {})
+
+       request_params = req.env[:openai_args] || {}
+       request_params[:model] = @model_name if request_params[:model].nil?
+
+       if is_chat_model?(req)
+         response = request_params[:stream] = Instruct::OpenAI::ChatCompletionResponse.new(**req.response_kwargs)
+         request_params.merge!(req.prompt_object)
+         warn_about_deprecated_args(req.env[:openai_deprecated_args]) if req.env[:openai_deprecated_args]
+         begin
+           Instruct.logger.info("Sending OpenAI Chat Completion Request: (#{request_params}) Client:(#{client.inspect})") if Instruct.logger.sev_threshold <= Logger::INFO
+           _client_response = client.chat(parameters: request_params)
+         rescue Faraday::Error => e
+           if e.respond_to?(:response_body)
+             Instruct.err_logger.error("#{e.response_body}")
+           else
+             Instruct.err_logger.error("#{e.inspect}")
+           end
+           raise e
+         end
+       else
+         warn_about_completions_endpoint
+         response = request_params[:stream] = Instruct::OpenAI::CompletionResponse.new(**req.response_kwargs)
+         request_params.merge!({ prompt: req.prompt_object })
+         begin
+           Instruct.logger.info("Sending OpenAI Completion Request: (#{request_params}) Client:(#{client.inspect})") if Instruct.logger.sev_threshold <= Logger::INFO
+           _client_response = client.completions(parameters: request_params)
+         rescue Faraday::Error => e
+           if e.respond_to?(:response_body)
+             Instruct.err_logger.error("#{e.response_body}")
+           else
+             Instruct.err_logger.error("#{e.inspect}")
+           end
+           raise e
+         end
+       end
+       response
+     end
+
+     protected
+
+     def append_default_middleware_if_not_added(req, middlewares)
+       openai_middlewares = [Instruct::OpenAI::Middleware.new]
+       if is_chat_model?(req)
+         openai_middlewares = [Instruct::ChompMiddleware.new, Instruct::ChatCompletionMiddleware.new] + openai_middlewares
+       end
+       openai_middlewares.each do |middleware|
+         if !middlewares.any? { |m| m.is_a?(middleware.class) }
+           middlewares << middleware
+         end
+       end
+     end
+
+     private
+
+     def is_chat_model?(req)
+       !(req.env[:use_completion_endpoint] || ((req.env[:model] || @model_name) == 'gpt-3.5-turbo-instruct'))
+     end
+
+     def build_client(req_client_opts = {})
+       if @client
+         raise ArgumentError, "Client options must not be set when initializing with a client" if req_client_opts.any?
+         return @client
+       end
+
+       client_opts = @default_request_env.select { |k, _| Instruct::OpenAI::Middleware::CLIENT_PARAMS.include?(k) }
+       client_opts.merge!(req_client_opts)
+
+       @cached_clients[client_opts.hash] ||= ::OpenAI::Client.new(
+         access_token: client_opts[:access_token] || ENV['OPENAI_API_KEY'] || ENV['OPENAI_ACCESS_TOKEN'],
+         uri_base: client_opts[:uri_base],
+         request_timeout: client_opts[:request_timeout],
+         extra_headers: client_opts[:extra_headers]
+       )
+     end
+
+     def set_access_token_from_env_if_needed
+       access_key = ENV['OPENAI_API_KEY'] || ENV['OPENAI_ACCESS_TOKEN']
+       @default_request_env[:access_token] = access_key if access_key && @default_request_env[:access_token].nil?
+     end
+
+     def warn_about_deprecated_args(deprecated_args)
+       return if Instruct.suppress_warnings || @deprecated_arg_warned
+       if deprecated_args && !deprecated_args.empty?
+         puts "Warning: the following args are deprecated by OpenAI and will be removed in the future: #{deprecated_args.keys.join(', ')}"
+         @deprecated_arg_warned = true
+       end
+     end
+
+     def warn_about_completions_endpoint
+       return if Instruct.suppress_warnings || @warned
+       puts "Warning: the completions endpoint is being shut down by OpenAI in Jan 2025."
+       @warned = true
+     end
+   end
+ end
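
A sketch of the endpoint routing (model names illustrative): anything other than gpt-3.5-turbo-instruct is treated as a chat model unless use_completion_endpoint forces the legacy path:

    chat   = Instruct::OpenAI.new("gpt-4o-mini")   # -> client.chat
    legacy = Instruct::OpenAI.new                  # default gpt-3.5-turbo-instruct -> client.completions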
data/lib/instruct/llms/openai/completion_response.rb
@@ -0,0 +1,20 @@
+ class Instruct::OpenAI
+   # The completions API has been deprecated by OpenAI, but some alternative
+   # service providers may still offer it. Leaving it in for now.
+   class CompletionResponse < Instruct::Gen::CompletionResponse
+
+     def call(chunk)
+       case Instruct::SymbolizeKeys.recursive(chunk)
+       in { choices: [{ text: new_content, finish_reason: }] }
+         append_text_chunk(new_content)
+         done(finish_reason) unless finish_reason.nil?
+       in { error: { message: } }
+         raise RuntimeError, "OpenAI Client Error: #{message}"
+       else
+         raise RuntimeError, "Unexpected Chunk: #{chunk}"
+       end
+       chunk_processed
+     end
+   end
+ end