langchainrb 0.12.1 → 0.13.0
Sign up to get free protection for your applications and to get access to all the features.
- checksums.yaml +4 -4
- data/CHANGELOG.md +5 -0
- data/README.md +3 -2
- data/lib/langchain/assistants/assistant.rb +75 -20
- data/lib/langchain/assistants/messages/base.rb +16 -0
- data/lib/langchain/assistants/messages/google_gemini_message.rb +90 -0
- data/lib/langchain/assistants/messages/openai_message.rb +74 -0
- data/lib/langchain/assistants/thread.rb +5 -5
- data/lib/langchain/llm/base.rb +2 -1
- data/lib/langchain/llm/google_gemini.rb +67 -0
- data/lib/langchain/llm/google_vertex_ai.rb +68 -112
- data/lib/langchain/llm/response/google_gemini_response.rb +45 -0
- data/lib/langchain/llm/response/openai_response.rb +5 -1
- data/lib/langchain/tool/base.rb +11 -1
- data/lib/langchain/tool/calculator/calculator.json +1 -1
- data/lib/langchain/tool/database/database.json +3 -3
- data/lib/langchain/tool/file_system/file_system.json +3 -3
- data/lib/langchain/tool/news_retriever/news_retriever.json +121 -0
- data/lib/langchain/tool/news_retriever/news_retriever.rb +132 -0
- data/lib/langchain/tool/ruby_code_interpreter/ruby_code_interpreter.json +1 -1
- data/lib/langchain/tool/vectorsearch/vectorsearch.json +1 -1
- data/lib/langchain/tool/weather/weather.json +1 -1
- data/lib/langchain/tool/wikipedia/wikipedia.json +1 -1
- data/lib/langchain/tool/wikipedia/wikipedia.rb +2 -2
- data/lib/langchain/version.rb +1 -1
- data/lib/langchain.rb +3 -0
- metadata +14 -9
- data/lib/langchain/assistants/message.rb +0 -58
- data/lib/langchain/llm/response/google_vertex_ai_response.rb +0 -33
@@ -2,150 +2,106 @@
|
|
2
2
|
|
3
3
|
module Langchain::LLM
|
4
4
|
#
|
5
|
-
# Wrapper around the Google Vertex AI APIs: https://cloud.google.com/vertex-ai
|
5
|
+
# Wrapper around the Google Vertex AI APIs: https://cloud.google.com/vertex-ai
|
6
6
|
#
|
7
7
|
# Gem requirements:
|
8
|
-
# gem "
|
8
|
+
# gem "googleauth"
|
9
9
|
#
|
10
10
|
# Usage:
|
11
|
-
#
|
11
|
+
# llm = Langchain::LLM::GoogleVertexAI.new(project_id: ENV["GOOGLE_VERTEX_AI_PROJECT_ID"], region: "us-central1")
|
12
12
|
#
|
13
|
-
class
|
13
|
+
class GoogleVertexAI < Base
|
14
14
|
DEFAULTS = {
|
15
|
-
temperature: 0.1,
|
15
|
+
temperature: 0.1,
|
16
16
|
max_output_tokens: 1000,
|
17
17
|
top_p: 0.8,
|
18
18
|
top_k: 40,
|
19
19
|
dimensions: 768,
|
20
|
-
|
21
|
-
|
20
|
+
embeddings_model_name: "textembedding-gecko",
|
21
|
+
chat_completion_model_name: "gemini-1.0-pro"
|
22
22
|
}.freeze
|
23
23
|
|
24
|
-
# TODO: Implement token length validation
|
25
|
-
# LENGTH_VALIDATOR = Langchain::Utils::TokenLength::...
|
26
|
-
|
27
24
|
# Google Cloud has a project id and a specific region of deployment.
|
28
25
|
# For GenAI-related things, a safe choice is us-central1.
|
29
|
-
attr_reader :
|
30
|
-
|
31
|
-
def initialize(project_id:, default_options: {})
|
32
|
-
depends_on "google-apis-aiplatform_v1"
|
26
|
+
attr_reader :defaults, :url, :authorizer
|
33
27
|
|
34
|
-
|
35
|
-
|
28
|
+
def initialize(project_id:, region:, default_options: {})
|
29
|
+
depends_on "googleauth"
|
36
30
|
|
37
|
-
@
|
38
|
-
|
39
|
-
|
40
|
-
# For the moment only us-central1 available so no big deal.
|
41
|
-
@client.root_url = "https://#{@region}-aiplatform.googleapis.com/"
|
42
|
-
@client.authorization = Google::Auth.get_application_default
|
31
|
+
@authorizer = ::Google::Auth.get_application_default
|
32
|
+
proj_id = project_id || @authorizer.project_id || @authorizer.quota_project_id
|
33
|
+
@url = "https://#{region}-aiplatform.googleapis.com/v1/projects/#{proj_id}/locations/#{region}/publishers/google/models/"
|
43
34
|
|
44
35
|
@defaults = DEFAULTS.merge(default_options)
|
36
|
+
|
37
|
+
chat_parameters.update(
|
38
|
+
model: {default: @defaults[:chat_completion_model_name]},
|
39
|
+
temperature: {default: @defaults[:temperature]}
|
40
|
+
)
|
41
|
+
chat_parameters.remap(
|
42
|
+
messages: :contents,
|
43
|
+
system: :system_instruction,
|
44
|
+
tool_choice: :tool_config
|
45
|
+
)
|
45
46
|
end
|
46
47
|
|
47
48
|
#
|
48
49
|
# Generate an embedding for a given text
|
49
50
|
#
|
50
51
|
# @param text [String] The text to generate an embedding for
|
51
|
-
# @
|
52
|
+
# @param model [String] ID of the model to use
|
53
|
+
# @return [Langchain::LLM::GoogleGeminiResponse] Response object
|
52
54
|
#
|
53
|
-
def embed(
|
54
|
-
|
55
|
-
|
56
|
-
|
57
|
-
|
58
|
-
|
59
|
-
|
60
|
-
|
61
|
-
|
55
|
+
def embed(
|
56
|
+
text:,
|
57
|
+
model: @defaults[:embeddings_model_name]
|
58
|
+
)
|
59
|
+
params = {instances: [{content: text}]}
|
60
|
+
|
61
|
+
response = HTTParty.post(
|
62
|
+
"#{url}#{model}:predict",
|
63
|
+
body: params.to_json,
|
64
|
+
headers: {
|
65
|
+
"Content-Type" => "application/json",
|
66
|
+
"Authorization" => "Bearer #{@authorizer.fetch_access_token!["access_token"]}"
|
67
|
+
}
|
68
|
+
)
|
62
69
|
|
63
|
-
Langchain::LLM::
|
70
|
+
Langchain::LLM::GoogleGeminiResponse.new(response, model: model)
|
64
71
|
end
|
65
72
|
|
73
|
+
# Generate a chat completion for given messages
|
66
74
|
#
|
67
|
-
#
|
68
|
-
#
|
69
|
-
# @param
|
70
|
-
# @param
|
71
|
-
# @
|
72
|
-
#
|
73
|
-
def
|
74
|
-
|
75
|
-
|
76
|
-
|
77
|
-
|
78
|
-
|
79
|
-
|
80
|
-
|
81
|
-
}
|
82
|
-
|
83
|
-
|
84
|
-
|
75
|
+
# @param messages [Array<Hash>] Input messages
|
76
|
+
# @param model [String] The model that will complete your prompt
|
77
|
+
# @param tools [Array<Hash>] The tools to use
|
78
|
+
# @param tool_choice [String] The tool choice to use
|
79
|
+
# @param system [String] The system instruction to use
|
80
|
+
# @return [Langchain::LLM::GoogleGeminiResponse] Response object
|
81
|
+
def chat(params = {})
|
82
|
+
params[:system] = {parts: [{text: params[:system]}]} if params[:system]
|
83
|
+
params[:tools] = {function_declarations: params[:tools]} if params[:tools]
|
84
|
+
params[:tool_choice] = {function_calling_config: {mode: params[:tool_choice].upcase}} if params[:tool_choice]
|
85
|
+
|
86
|
+
raise ArgumentError.new("messages argument is required") if Array(params[:messages]).empty?
|
87
|
+
|
88
|
+
parameters = chat_parameters.to_params(params)
|
89
|
+
parameters[:generation_config] = {temperature: parameters.delete(:temperature)} if parameters[:temperature]
|
90
|
+
|
91
|
+
uri = URI("#{url}#{parameters[:model]}:generateContent")
|
92
|
+
|
93
|
+
request = Net::HTTP::Post.new(uri)
|
94
|
+
request.content_type = "application/json"
|
95
|
+
request["Authorization"] = "Bearer #{@authorizer.fetch_access_token!["access_token"]}"
|
96
|
+
request.body = parameters.to_json
|
97
|
+
|
98
|
+
response = Net::HTTP.start(uri.hostname, uri.port, use_ssl: uri.scheme == "https") do |http|
|
99
|
+
http.request(request)
|
85
100
|
end
|
86
101
|
|
87
|
-
|
88
|
-
default_params[:max_output_tokens] = params.delete(:max_output_tokens)
|
89
|
-
end
|
90
|
-
|
91
|
-
# to be tested
|
92
|
-
temperature = params.delete(:temperature) || @defaults[:temperature]
|
93
|
-
max_output_tokens = default_params.fetch(:max_output_tokens, @defaults[:max_output_tokens])
|
94
|
-
|
95
|
-
default_params.merge!(params)
|
96
|
-
|
97
|
-
# response = client.generate_text(**default_params)
|
98
|
-
request = Google::Apis::AiplatformV1::GoogleCloudAiplatformV1PredictRequest.new \
|
99
|
-
instances: [{
|
100
|
-
prompt: prompt # key used to be :content, changed to :prompt
|
101
|
-
}],
|
102
|
-
parameters: {
|
103
|
-
temperature: temperature,
|
104
|
-
maxOutputTokens: max_output_tokens,
|
105
|
-
topP: 0.8,
|
106
|
-
topK: 40
|
107
|
-
}
|
108
|
-
|
109
|
-
response = client.predict_project_location_publisher_model \
|
110
|
-
"projects/#{project_id}/locations/us-central1/publishers/google/models/#{@defaults[:completion_model_name]}",
|
111
|
-
request
|
112
|
-
|
113
|
-
Langchain::LLM::GoogleVertexAiResponse.new(response, model: default_params[:model])
|
114
|
-
end
|
102
|
+
parsed_response = JSON.parse(response.body)
|
115
103
|
|
116
|
-
|
117
|
-
# Generate a summarization for a given text
|
118
|
-
#
|
119
|
-
# @param text [String] The text to generate a summarization for
|
120
|
-
# @return [String] The summarization
|
121
|
-
#
|
122
|
-
# TODO(ricc): add params for Temp, topP, topK, MaxTokens and have it default to these 4 values.
|
123
|
-
def summarize(text:)
|
124
|
-
prompt_template = Langchain::Prompt.load_from_path(
|
125
|
-
file_path: Langchain.root.join("langchain/llm/prompts/summarize_template.yaml")
|
126
|
-
)
|
127
|
-
prompt = prompt_template.format(text: text)
|
128
|
-
|
129
|
-
complete(
|
130
|
-
prompt: prompt,
|
131
|
-
# For best temperature, topP, topK, MaxTokens for summarization: see
|
132
|
-
# https://cloud.google.com/vertex-ai/docs/samples/aiplatform-sdk-summarization
|
133
|
-
temperature: 0.2,
|
134
|
-
top_p: 0.95,
|
135
|
-
top_k: 40,
|
136
|
-
# Most models have a context length of 2048 tokens (except for the newest models, which support 4096).
|
137
|
-
max_output_tokens: 256
|
138
|
-
)
|
104
|
+
Langchain::LLM::GoogleGeminiResponse.new(parsed_response, model: parameters[:model])
|
139
105
|
end
|
140
|
-
|
141
|
-
# def chat(...)
|
142
|
-
# https://cloud.google.com/vertex-ai/docs/samples/aiplatform-sdk-chathat
|
143
|
-
# Chat params: https://cloud.google.com/vertex-ai/docs/samples/aiplatform-sdk-chat
|
144
|
-
# \"temperature\": 0.3,\n"
|
145
|
-
# + " \"maxDecodeSteps\": 200,\n"
|
146
|
-
# + " \"topP\": 0.8,\n"
|
147
|
-
# + " \"topK\": 40\n"
|
148
|
-
# + "}";
|
149
|
-
# end
|
150
106
|
end
|
151
107
|
end
|
@@ -0,0 +1,45 @@
|
|
1
|
+
# frozen_string_literal: true
|
2
|
+
|
3
|
+
module Langchain::LLM
|
4
|
+
class GoogleGeminiResponse < BaseResponse
|
5
|
+
def initialize(raw_response, model: nil)
|
6
|
+
super(raw_response, model: model)
|
7
|
+
end
|
8
|
+
|
9
|
+
def chat_completion
|
10
|
+
raw_response.dig("candidates", 0, "content", "parts", 0, "text")
|
11
|
+
end
|
12
|
+
|
13
|
+
def role
|
14
|
+
raw_response.dig("candidates", 0, "content", "role")
|
15
|
+
end
|
16
|
+
|
17
|
+
def tool_calls
|
18
|
+
if raw_response.dig("candidates", 0, "content") && raw_response.dig("candidates", 0, "content", "parts", 0).has_key?("functionCall")
|
19
|
+
raw_response.dig("candidates", 0, "content", "parts")
|
20
|
+
else
|
21
|
+
[]
|
22
|
+
end
|
23
|
+
end
|
24
|
+
|
25
|
+
def embedding
|
26
|
+
embeddings.first
|
27
|
+
end
|
28
|
+
|
29
|
+
def embeddings
|
30
|
+
[raw_response.dig("predictions", 0, "embeddings", "values")]
|
31
|
+
end
|
32
|
+
|
33
|
+
def prompt_tokens
|
34
|
+
raw_response.dig("usageMetadata", "promptTokenCount")
|
35
|
+
end
|
36
|
+
|
37
|
+
def completion_tokens
|
38
|
+
raw_response.dig("usageMetadata", "candidatesTokenCount")
|
39
|
+
end
|
40
|
+
|
41
|
+
def total_tokens
|
42
|
+
raw_response.dig("usageMetadata", "totalTokenCount")
|
43
|
+
end
|
44
|
+
end
|
45
|
+
end
|
@@ -25,7 +25,11 @@ module Langchain::LLM
|
|
25
25
|
end
|
26
26
|
|
27
27
|
def tool_calls
|
28
|
-
chat_completions
|
28
|
+
if chat_completions.dig(0, "message").has_key?("tool_calls")
|
29
|
+
chat_completions.dig(0, "message", "tool_calls")
|
30
|
+
else
|
31
|
+
[]
|
32
|
+
end
|
29
33
|
end
|
30
34
|
|
31
35
|
def embedding
|
data/lib/langchain/tool/base.rb
CHANGED
@@ -66,11 +66,21 @@ module Langchain::Tool
|
|
66
66
|
|
67
67
|
# Returns the tool as a list of OpenAI formatted functions
|
68
68
|
#
|
69
|
-
# @return [Hash] tool as
|
69
|
+
# @return [Array<Hash>] List of hashes representing the tool as OpenAI formatted functions
|
70
70
|
def to_openai_tools
|
71
71
|
method_annotations
|
72
72
|
end
|
73
73
|
|
74
|
+
# Returns the tool as a list of Google Gemini formatted functions
|
75
|
+
#
|
76
|
+
# @return [Array<Hash>] List of hashes representing the tool as Google Gemini formatted functions
|
77
|
+
def to_google_gemini_tools
|
78
|
+
method_annotations.map do |annotation|
|
79
|
+
# Slice out only the content of the "function" key
|
80
|
+
annotation["function"]
|
81
|
+
end
|
82
|
+
end
|
83
|
+
|
74
84
|
# Return tool's method annotations as JSON
|
75
85
|
#
|
76
86
|
# @return [Hash] Tool's method annotations
|
@@ -2,7 +2,7 @@
|
|
2
2
|
{
|
3
3
|
"type": "function",
|
4
4
|
"function": {
|
5
|
-
"name": "
|
5
|
+
"name": "calculator__execute",
|
6
6
|
"description": "Evaluates a pure math expression or if equation contains non-math characters (e.g.: \"12F in Celsius\") then it uses the google search calculator to evaluate the expression",
|
7
7
|
"parameters": {
|
8
8
|
"type": "object",
|
@@ -2,7 +2,7 @@
|
|
2
2
|
{
|
3
3
|
"type": "function",
|
4
4
|
"function": {
|
5
|
-
"name": "
|
5
|
+
"name": "database__describe_tables",
|
6
6
|
"description": "Database Tool: Returns the schema for a list of tables",
|
7
7
|
"parameters": {
|
8
8
|
"type": "object",
|
@@ -18,7 +18,7 @@
|
|
18
18
|
}, {
|
19
19
|
"type": "function",
|
20
20
|
"function": {
|
21
|
-
"name": "
|
21
|
+
"name": "database__list_tables",
|
22
22
|
"description": "Database Tool: Returns a list of tables in the database",
|
23
23
|
"parameters": {
|
24
24
|
"type": "object",
|
@@ -29,7 +29,7 @@
|
|
29
29
|
}, {
|
30
30
|
"type": "function",
|
31
31
|
"function": {
|
32
|
-
"name": "
|
32
|
+
"name": "database__execute",
|
33
33
|
"description": "Database Tool: Executes a SQL query and returns the results",
|
34
34
|
"parameters": {
|
35
35
|
"type": "object",
|
@@ -2,7 +2,7 @@
|
|
2
2
|
{
|
3
3
|
"type": "function",
|
4
4
|
"function": {
|
5
|
-
"name": "
|
5
|
+
"name": "file_system__list_directory",
|
6
6
|
"description": "File System Tool: Lists out the content of a specified directory",
|
7
7
|
"parameters": {
|
8
8
|
"type": "object",
|
@@ -19,7 +19,7 @@
|
|
19
19
|
{
|
20
20
|
"type": "function",
|
21
21
|
"function": {
|
22
|
-
"name": "
|
22
|
+
"name": "file_system__read_file",
|
23
23
|
"description": "File System Tool: Reads the contents of a file",
|
24
24
|
"parameters": {
|
25
25
|
"type": "object",
|
@@ -36,7 +36,7 @@
|
|
36
36
|
{
|
37
37
|
"type": "function",
|
38
38
|
"function": {
|
39
|
-
"name": "
|
39
|
+
"name": "file_system__write_to_file",
|
40
40
|
"description": "File System Tool: Write content to a file",
|
41
41
|
"parameters": {
|
42
42
|
"type": "object",
|
@@ -0,0 +1,121 @@
|
|
1
|
+
[
|
2
|
+
{
|
3
|
+
"type": "function",
|
4
|
+
"function": {
|
5
|
+
"name": "news_retriever__get_everything",
|
6
|
+
"description": "News Retriever: Search through millions of articles from over 150,000 large and small news sources and blogs.",
|
7
|
+
"parameters": {
|
8
|
+
"type": "object",
|
9
|
+
"properties": {
|
10
|
+
"q": {
|
11
|
+
"type": "string",
|
12
|
+
"description": "Keywords or phrases to search for in the article title and body. Surround phrases with quotes (\") for exact match. Alternatively you can use the AND / OR / NOT keywords, and optionally group these with parenthesis. Must be URL-encoded."
|
13
|
+
},
|
14
|
+
"search_in": {
|
15
|
+
"type": "string",
|
16
|
+
"description": "The fields to restrict your q search to.",
|
17
|
+
"enum": ["title", "description", "content"]
|
18
|
+
},
|
19
|
+
"sources": {
|
20
|
+
"type": "string",
|
21
|
+
"description": "A comma-seperated string of identifiers (maximum 20) for the news sources or blogs you want headlines from. Use the /sources endpoint to locate these programmatically or look at the sources index."
|
22
|
+
},
|
23
|
+
"domains": {
|
24
|
+
"type": "string",
|
25
|
+
"description": "A comma-seperated string of domains (eg bbc.co.uk, techcrunch.com, engadget.com) to restrict the search to."
|
26
|
+
},
|
27
|
+
"exclude_domains": {
|
28
|
+
"type": "string",
|
29
|
+
"description": "A comma-seperated string of domains (eg bbc.co.uk, techcrunch.com, engadget.com) to remove from the results."
|
30
|
+
},
|
31
|
+
"from": {
|
32
|
+
"type": "string",
|
33
|
+
"description": "A date and optional time for the oldest article allowed. This should be in ISO 8601 format."
|
34
|
+
},
|
35
|
+
"to": {
|
36
|
+
"type": "string",
|
37
|
+
"description": "A date and optional time for the newest article allowed. This should be in ISO 8601 format."
|
38
|
+
},
|
39
|
+
"language": {
|
40
|
+
"type": "string",
|
41
|
+
"description": "The 2-letter ISO-639-1 code of the language you want to get headlines for.",
|
42
|
+
"enum": ["ar", "de", "en", "es", "fr", "he", "it", "nl", "no", "pt", "ru", "sv", "ud", "zh"]
|
43
|
+
},
|
44
|
+
"sort_by": {
|
45
|
+
"type": "string",
|
46
|
+
"description": "The order to sort the articles in.",
|
47
|
+
"enum": ["relevancy", "popularity", "publishedAt"]
|
48
|
+
},
|
49
|
+
"page_size": {
|
50
|
+
"type": "integer",
|
51
|
+
"description": "The number of results to return per page (request). 5 is the default, 100 is the maximum."
|
52
|
+
},
|
53
|
+
"page": {
|
54
|
+
"type": "integer",
|
55
|
+
"description": "Use this to page through the results if the total results found is greater than the page size."
|
56
|
+
}
|
57
|
+
}
|
58
|
+
}
|
59
|
+
}
|
60
|
+
},
|
61
|
+
{
|
62
|
+
"type": "function",
|
63
|
+
"function": {
|
64
|
+
"name": "news_retriever__get_top_headlines",
|
65
|
+
"description": "News Retriever: Provides live top and breaking headlines for a country, specific category in a country, single source, or multiple sources. You can also search with keywords. Articles are sorted by the earliest date published first.",
|
66
|
+
"parameters": {
|
67
|
+
"type": "object",
|
68
|
+
"properties": {
|
69
|
+
"country": {
|
70
|
+
"type": "string",
|
71
|
+
"description": "The 2-letter ISO 3166-1 code of the country you want to get headlines for."
|
72
|
+
},
|
73
|
+
"category": {
|
74
|
+
"type": "string",
|
75
|
+
"description": "The category you want to get headlines for.",
|
76
|
+
"enum": ["business", "entertainment", "general", "health", "science", "sports", "technology"]
|
77
|
+
},
|
78
|
+
"q": {
|
79
|
+
"type": "string",
|
80
|
+
"description": "Keywords or a phrase to search for."
|
81
|
+
},
|
82
|
+
"page_size": {
|
83
|
+
"type": "integer",
|
84
|
+
"description": "The number of results to return per page (request). 5 is the default, 100 is the maximum."
|
85
|
+
},
|
86
|
+
"page": {
|
87
|
+
"type": "integer",
|
88
|
+
"description": "Use this to page through the results if the total results found is greater than the page size."
|
89
|
+
}
|
90
|
+
}
|
91
|
+
}
|
92
|
+
}
|
93
|
+
},
|
94
|
+
{
|
95
|
+
"type": "function",
|
96
|
+
"function": {
|
97
|
+
"name": "news_retriever__get_sources",
|
98
|
+
"description": "News Retriever: This endpoint returns the subset of news publishers that top headlines (/v2/top-headlines) are available from. It's mainly a convenience endpoint that you can use to keep track of the publishers available on the API, and you can pipe it straight through to your users.",
|
99
|
+
"parameters": {
|
100
|
+
"type": "object",
|
101
|
+
"properties": {
|
102
|
+
"country": {
|
103
|
+
"type": "string",
|
104
|
+
"description": "The 2-letter ISO 3166-1 code of the country you want to get headlines for. Default: all countries.",
|
105
|
+
"enum": ["ae", "ar", "at", "au", "be", "bg", "br", "ca", "ch", "cn", "co", "cu", "cz", "de", "eg", "fr", "gb", "gr", "hk", "hu", "id", "ie", "il", "in", "it", "jp", "kr", "lt", "lv", "ma", "mx", "my", "ng", "nl", "no", "nz", "ph", "pl", "pt", "ro", "rs", "ru", "sa", "se", "sg", "si", "sk", "th", "tr", "tw", "ua", "us", "ve", "za"]
|
106
|
+
},
|
107
|
+
"category": {
|
108
|
+
"type": "string",
|
109
|
+
"description": "The category you want to get headlines for. Default: all categories.",
|
110
|
+
"enum": ["business", "entertainment", "general", "health", "science", "sports", "technology"]
|
111
|
+
},
|
112
|
+
"language": {
|
113
|
+
"type": "string",
|
114
|
+
"description": "The 2-letter ISO-639-1 code of the language you want to get headlines for.",
|
115
|
+
"enum": ["ar", "de", "en", "es", "fr", "he", "it", "nl", "no", "pt", "ru", "sv", "ud", "zh"]
|
116
|
+
}
|
117
|
+
}
|
118
|
+
}
|
119
|
+
}
|
120
|
+
}
|
121
|
+
]
|
@@ -0,0 +1,132 @@
|
|
1
|
+
# frozen_string_literal: true
|
2
|
+
|
3
|
+
module Langchain::Tool
|
4
|
+
class NewsRetriever < Base
|
5
|
+
#
|
6
|
+
# A tool that retrieves latest news from various sources via https://newsapi.org/.
|
7
|
+
# An API key needs to be obtained from https://newsapi.org/ to use this tool.
|
8
|
+
#
|
9
|
+
# Usage:
|
10
|
+
# news_retriever = Langchain::Tool::NewsRetriever.new(api_key: ENV["NEWS_API_KEY"])
|
11
|
+
#
|
12
|
+
NAME = "news_retriever"
|
13
|
+
ANNOTATIONS_PATH = Langchain.root.join("./langchain/tool/#{NAME}/#{NAME}.json").to_path
|
14
|
+
|
15
|
+
def initialize(api_key: ENV["NEWS_API_KEY"])
|
16
|
+
@api_key = api_key
|
17
|
+
end
|
18
|
+
|
19
|
+
# Retrieve all news
|
20
|
+
#
|
21
|
+
# @param q [String] Keywords or phrases to search for in the article title and body.
|
22
|
+
# @param search_in [String] The fields to restrict your q search to. The possible options are: title, description, content.
|
23
|
+
# @param sources [String] A comma-separated string of identifiers (maximum 20) for the news sources or blogs you want headlines from. Use the /sources endpoint to locate these programmatically or look at the sources index.
|
24
|
+
# @param domains [String] A comma-separated string of domains (eg bbc.co.uk, techcrunch.com, engadget.com) to restrict the search to.
|
25
|
+
# @param exclude_domains [String] A comma-separated string of domains (eg bbc.co.uk, techcrunch.com, engadget.com) to remove from the results.
|
26
|
+
# @param from [String] A date and optional time for the oldest article allowed. This should be in ISO 8601 format.
|
27
|
+
# @param to [String] A date and optional time for the newest article allowed. This should be in ISO 8601 format.
|
28
|
+
# @param language [String] The 2-letter ISO-639-1 code of the language you want to get headlines for. Possible options: ar, de, en, es, fr, he, it, nl, no, pt, ru, sv, ud, zh.
|
29
|
+
# @param sort_by [String] The order to sort the articles in. Possible options: relevancy, popularity, publishedAt.
|
30
|
+
# @param page_size [Integer] The number of results to return per page. 20 is the API's default, 100 is the maximum. Our default is 5.
|
31
|
+
# @param page [Integer] Use this to page through the results.
|
32
|
+
#
|
33
|
+
# @return [String] JSON response
|
34
|
+
def get_everything(
|
35
|
+
q: nil,
|
36
|
+
search_in: nil,
|
37
|
+
sources: nil,
|
38
|
+
domains: nil,
|
39
|
+
exclude_domains: nil,
|
40
|
+
from: nil,
|
41
|
+
to: nil,
|
42
|
+
language: nil,
|
43
|
+
sort_by: nil,
|
44
|
+
page_size: 5, # The API default is 20 but that's too many.
|
45
|
+
page: nil
|
46
|
+
)
|
47
|
+
Langchain.logger.info("Retrieving all news", for: self.class)
|
48
|
+
|
49
|
+
params = {apiKey: @api_key}
|
50
|
+
params[:q] = q if q
|
51
|
+
params[:searchIn] = search_in if search_in
|
52
|
+
params[:sources] = sources if sources
|
53
|
+
params[:domains] = domains if domains
|
54
|
+
params[:excludeDomains] = exclude_domains if exclude_domains
|
55
|
+
params[:from] = from if from
|
56
|
+
params[:to] = to if to
|
57
|
+
params[:language] = language if language
|
58
|
+
params[:sortBy] = sort_by if sort_by
|
59
|
+
params[:pageSize] = page_size if page_size
|
60
|
+
params[:page] = page if page
|
61
|
+
|
62
|
+
send_request(path: "everything", params: params)
|
63
|
+
end
|
64
|
+
|
65
|
+
# Retrieve top headlines
|
66
|
+
#
|
67
|
+
# @param country [String] The 2-letter ISO 3166-1 code of the country you want to get headlines for. Possible options: ae, ar, at, au, be, bg, br, ca, ch, cn, co, cu, cz, de, eg, fr, gb, gr, hk, hu, id, ie, il, in, it, jp, kr, lt, lv, ma, mx, my, ng, nl, no, nz, ph, pl, pt, ro, rs, ru, sa, se, sg, si, sk, th, tr, tw, ua, us, ve, za.
|
68
|
+
# @param category [String] The category you want to get headlines for. Possible options: business, entertainment, general, health, science, sports, technology.
|
69
|
+
# @param sources [String] A comma-separated string of identifiers for the news sources or blogs you want headlines from. Use the /sources endpoint to locate these programmatically.
|
70
|
+
# @param q [String] Keywords or a phrase to search for.
|
71
|
+
# @param page_size [Integer] The number of results to return per page. 20 is the API's default, 100 is the maximum. Our default is 5.
|
72
|
+
# @param page [Integer] Use this to page through the results.
|
73
|
+
#
|
74
|
+
# @return [String] JSON response
|
75
|
+
def get_top_headlines(
|
76
|
+
country: nil,
|
77
|
+
category: nil,
|
78
|
+
sources: nil,
|
79
|
+
q: nil,
|
80
|
+
page_size: 5,
|
81
|
+
page: nil
|
82
|
+
)
|
83
|
+
Langchain.logger.info("Retrieving top news headlines", for: self.class)
|
84
|
+
|
85
|
+
params = {apiKey: @api_key}
|
86
|
+
params[:country] = country if country
|
87
|
+
params[:category] = category if category
|
88
|
+
params[:sources] = sources if sources
|
89
|
+
params[:q] = q if q
|
90
|
+
params[:pageSize] = page_size if page_size
|
91
|
+
params[:page] = page if page
|
92
|
+
|
93
|
+
send_request(path: "top-headlines", params: params)
|
94
|
+
end
|
95
|
+
|
96
|
+
# Retrieve news sources
|
97
|
+
#
|
98
|
+
# @param category [String] The category you want to get headlines for. Possible options: business, entertainment, general, health, science, sports, technology.
|
99
|
+
# @param language [String] The 2-letter ISO-639-1 code of the language you want to get headlines for. Possible options: ar, de, en, es, fr, he, it, nl, no, pt, ru, sv, ud, zh.
|
100
|
+
# @param country [String] The 2-letter ISO 3166-1 code of the country you want to get headlines for. Possible options: ae, ar, at, au, be, bg, br, ca, ch, cn, co, cu, cz, de, eg, fr, gb, gr, hk, hu, id, ie, il, in, it, jp, kr, lt, lv, ma, mx, my, ng, nl, no, nz, ph, pl, pt, ro, rs, ru, sa, se, sg, si, sk, th, tr, tw, ua, us, ve, za.
|
101
|
+
#
|
102
|
+
# @return [String] JSON response
|
103
|
+
def get_sources(
|
104
|
+
category: nil,
|
105
|
+
language: nil,
|
106
|
+
country: nil
|
107
|
+
)
|
108
|
+
Langchain.logger.info("Retrieving news sources", for: self.class)
|
109
|
+
|
110
|
+
params = {apiKey: @api_key}
|
111
|
+
params[:country] = country if country
|
112
|
+
params[:category] = category if category
|
113
|
+
params[:language] = language if language
|
114
|
+
|
115
|
+
send_request(path: "top-headlines/sources", params: params)
|
116
|
+
end
|
117
|
+
|
118
|
+
private
|
119
|
+
|
120
|
+
def send_request(path:, params:)
|
121
|
+
uri = URI.parse("https://newsapi.org/v2/#{path}?#{URI.encode_www_form(params)}")
|
122
|
+
http = Net::HTTP.new(uri.host, uri.port)
|
123
|
+
http.use_ssl = true
|
124
|
+
|
125
|
+
request = Net::HTTP::Get.new(uri.request_uri)
|
126
|
+
request["Content-Type"] = "application/json"
|
127
|
+
|
128
|
+
response = http.request(request)
|
129
|
+
response.body
|
130
|
+
end
|
131
|
+
end
|
132
|
+
end
|