langchainrb 0.12.1 → 0.13.0
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- checksums.yaml +4 -4
- data/CHANGELOG.md +5 -0
- data/README.md +3 -2
- data/lib/langchain/assistants/assistant.rb +75 -20
- data/lib/langchain/assistants/messages/base.rb +16 -0
- data/lib/langchain/assistants/messages/google_gemini_message.rb +90 -0
- data/lib/langchain/assistants/messages/openai_message.rb +74 -0
- data/lib/langchain/assistants/thread.rb +5 -5
- data/lib/langchain/llm/base.rb +2 -1
- data/lib/langchain/llm/google_gemini.rb +67 -0
- data/lib/langchain/llm/google_vertex_ai.rb +68 -112
- data/lib/langchain/llm/response/google_gemini_response.rb +45 -0
- data/lib/langchain/llm/response/openai_response.rb +5 -1
- data/lib/langchain/tool/base.rb +11 -1
- data/lib/langchain/tool/calculator/calculator.json +1 -1
- data/lib/langchain/tool/database/database.json +3 -3
- data/lib/langchain/tool/file_system/file_system.json +3 -3
- data/lib/langchain/tool/news_retriever/news_retriever.json +121 -0
- data/lib/langchain/tool/news_retriever/news_retriever.rb +132 -0
- data/lib/langchain/tool/ruby_code_interpreter/ruby_code_interpreter.json +1 -1
- data/lib/langchain/tool/vectorsearch/vectorsearch.json +1 -1
- data/lib/langchain/tool/weather/weather.json +1 -1
- data/lib/langchain/tool/wikipedia/wikipedia.json +1 -1
- data/lib/langchain/tool/wikipedia/wikipedia.rb +2 -2
- data/lib/langchain/version.rb +1 -1
- data/lib/langchain.rb +3 -0
- metadata +14 -9
- data/lib/langchain/assistants/message.rb +0 -58
- data/lib/langchain/llm/response/google_vertex_ai_response.rb +0 -33
@@ -2,150 +2,106 @@
|
|
2
2
|
|
3
3
|
module Langchain::LLM
|
4
4
|
#
|
5
|
-
# Wrapper around the Google Vertex AI APIs: https://cloud.google.com/vertex-ai
|
5
|
+
# Wrapper around the Google Vertex AI APIs: https://cloud.google.com/vertex-ai
|
6
6
|
#
|
7
7
|
# Gem requirements:
|
8
|
-
# gem "
|
8
|
+
# gem "googleauth"
|
9
9
|
#
|
10
10
|
# Usage:
|
11
|
-
#
|
11
|
+
# llm = Langchain::LLM::GoogleVertexAI.new(project_id: ENV["GOOGLE_VERTEX_AI_PROJECT_ID"], region: "us-central1")
|
12
12
|
#
|
13
|
-
class
|
13
|
+
class GoogleVertexAI < Base
|
14
14
|
DEFAULTS = {
|
15
|
-
temperature: 0.1,
|
15
|
+
temperature: 0.1,
|
16
16
|
max_output_tokens: 1000,
|
17
17
|
top_p: 0.8,
|
18
18
|
top_k: 40,
|
19
19
|
dimensions: 768,
|
20
|
-
|
21
|
-
|
20
|
+
embeddings_model_name: "textembedding-gecko",
|
21
|
+
chat_completion_model_name: "gemini-1.0-pro"
|
22
22
|
}.freeze
|
23
23
|
|
24
|
-
# TODO: Implement token length validation
|
25
|
-
# LENGTH_VALIDATOR = Langchain::Utils::TokenLength::...
|
26
|
-
|
27
24
|
# Google Cloud has a project id and a specific region of deployment.
|
28
25
|
# For GenAI-related things, a safe choice is us-central1.
|
29
|
-
attr_reader :
|
30
|
-
|
31
|
-
def initialize(project_id:, default_options: {})
|
32
|
-
depends_on "google-apis-aiplatform_v1"
|
26
|
+
attr_reader :defaults, :url, :authorizer
|
33
27
|
|
34
|
-
|
35
|
-
|
28
|
+
def initialize(project_id:, region:, default_options: {})
|
29
|
+
depends_on "googleauth"
|
36
30
|
|
37
|
-
@
|
38
|
-
|
39
|
-
|
40
|
-
# For the moment only us-central1 available so no big deal.
|
41
|
-
@client.root_url = "https://#{@region}-aiplatform.googleapis.com/"
|
42
|
-
@client.authorization = Google::Auth.get_application_default
|
31
|
+
@authorizer = ::Google::Auth.get_application_default
|
32
|
+
proj_id = project_id || @authorizer.project_id || @authorizer.quota_project_id
|
33
|
+
@url = "https://#{region}-aiplatform.googleapis.com/v1/projects/#{proj_id}/locations/#{region}/publishers/google/models/"
|
43
34
|
|
44
35
|
@defaults = DEFAULTS.merge(default_options)
|
36
|
+
|
37
|
+
chat_parameters.update(
|
38
|
+
model: {default: @defaults[:chat_completion_model_name]},
|
39
|
+
temperature: {default: @defaults[:temperature]}
|
40
|
+
)
|
41
|
+
chat_parameters.remap(
|
42
|
+
messages: :contents,
|
43
|
+
system: :system_instruction,
|
44
|
+
tool_choice: :tool_config
|
45
|
+
)
|
45
46
|
end
|
46
47
|
|
47
48
|
#
|
48
49
|
# Generate an embedding for a given text
|
49
50
|
#
|
50
51
|
# @param text [String] The text to generate an embedding for
|
51
|
-
# @
|
52
|
+
# @param model [String] ID of the model to use
|
53
|
+
# @return [Langchain::LLM::GoogleGeminiResponse] Response object
|
52
54
|
#
|
53
|
-
def embed(
|
54
|
-
|
55
|
-
|
56
|
-
|
57
|
-
|
58
|
-
|
59
|
-
|
60
|
-
|
61
|
-
|
55
|
+
def embed(
|
56
|
+
text:,
|
57
|
+
model: @defaults[:embeddings_model_name]
|
58
|
+
)
|
59
|
+
params = {instances: [{content: text}]}
|
60
|
+
|
61
|
+
response = HTTParty.post(
|
62
|
+
"#{url}#{model}:predict",
|
63
|
+
body: params.to_json,
|
64
|
+
headers: {
|
65
|
+
"Content-Type" => "application/json",
|
66
|
+
"Authorization" => "Bearer #{@authorizer.fetch_access_token!["access_token"]}"
|
67
|
+
}
|
68
|
+
)
|
62
69
|
|
63
|
-
Langchain::LLM::
|
70
|
+
Langchain::LLM::GoogleGeminiResponse.new(response, model: model)
|
64
71
|
end
|
65
72
|
|
73
|
+
# Generate a chat completion for given messages
|
66
74
|
#
|
67
|
-
#
|
68
|
-
#
|
69
|
-
# @param
|
70
|
-
# @param
|
71
|
-
# @
|
72
|
-
#
|
73
|
-
def
|
74
|
-
|
75
|
-
|
76
|
-
|
77
|
-
|
78
|
-
|
79
|
-
|
80
|
-
|
81
|
-
}
|
82
|
-
|
83
|
-
|
84
|
-
|
75
|
+
# @param messages [Array<Hash>] Input messages
|
76
|
+
# @param model [String] The model that will complete your prompt
|
77
|
+
# @param tools [Array<Hash>] The tools to use
|
78
|
+
# @param tool_choice [String] The tool choice to use
|
79
|
+
# @param system [String] The system instruction to use
|
80
|
+
# @return [Langchain::LLM::GoogleGeminiResponse] Response object
|
81
|
+
def chat(params = {})
|
82
|
+
params[:system] = {parts: [{text: params[:system]}]} if params[:system]
|
83
|
+
params[:tools] = {function_declarations: params[:tools]} if params[:tools]
|
84
|
+
params[:tool_choice] = {function_calling_config: {mode: params[:tool_choice].upcase}} if params[:tool_choice]
|
85
|
+
|
86
|
+
raise ArgumentError.new("messages argument is required") if Array(params[:messages]).empty?
|
87
|
+
|
88
|
+
parameters = chat_parameters.to_params(params)
|
89
|
+
parameters[:generation_config] = {temperature: parameters.delete(:temperature)} if parameters[:temperature]
|
90
|
+
|
91
|
+
uri = URI("#{url}#{parameters[:model]}:generateContent")
|
92
|
+
|
93
|
+
request = Net::HTTP::Post.new(uri)
|
94
|
+
request.content_type = "application/json"
|
95
|
+
request["Authorization"] = "Bearer #{@authorizer.fetch_access_token!["access_token"]}"
|
96
|
+
request.body = parameters.to_json
|
97
|
+
|
98
|
+
response = Net::HTTP.start(uri.hostname, uri.port, use_ssl: uri.scheme == "https") do |http|
|
99
|
+
http.request(request)
|
85
100
|
end
|
86
101
|
|
87
|
-
|
88
|
-
default_params[:max_output_tokens] = params.delete(:max_output_tokens)
|
89
|
-
end
|
90
|
-
|
91
|
-
# to be tested
|
92
|
-
temperature = params.delete(:temperature) || @defaults[:temperature]
|
93
|
-
max_output_tokens = default_params.fetch(:max_output_tokens, @defaults[:max_output_tokens])
|
94
|
-
|
95
|
-
default_params.merge!(params)
|
96
|
-
|
97
|
-
# response = client.generate_text(**default_params)
|
98
|
-
request = Google::Apis::AiplatformV1::GoogleCloudAiplatformV1PredictRequest.new \
|
99
|
-
instances: [{
|
100
|
-
prompt: prompt # key used to be :content, changed to :prompt
|
101
|
-
}],
|
102
|
-
parameters: {
|
103
|
-
temperature: temperature,
|
104
|
-
maxOutputTokens: max_output_tokens,
|
105
|
-
topP: 0.8,
|
106
|
-
topK: 40
|
107
|
-
}
|
108
|
-
|
109
|
-
response = client.predict_project_location_publisher_model \
|
110
|
-
"projects/#{project_id}/locations/us-central1/publishers/google/models/#{@defaults[:completion_model_name]}",
|
111
|
-
request
|
112
|
-
|
113
|
-
Langchain::LLM::GoogleVertexAiResponse.new(response, model: default_params[:model])
|
114
|
-
end
|
102
|
+
parsed_response = JSON.parse(response.body)
|
115
103
|
|
116
|
-
|
117
|
-
# Generate a summarization for a given text
|
118
|
-
#
|
119
|
-
# @param text [String] The text to generate a summarization for
|
120
|
-
# @return [String] The summarization
|
121
|
-
#
|
122
|
-
# TODO(ricc): add params for Temp, topP, topK, MaxTokens and have it default to these 4 values.
|
123
|
-
def summarize(text:)
|
124
|
-
prompt_template = Langchain::Prompt.load_from_path(
|
125
|
-
file_path: Langchain.root.join("langchain/llm/prompts/summarize_template.yaml")
|
126
|
-
)
|
127
|
-
prompt = prompt_template.format(text: text)
|
128
|
-
|
129
|
-
complete(
|
130
|
-
prompt: prompt,
|
131
|
-
# For best temperature, topP, topK, MaxTokens for summarization: see
|
132
|
-
# https://cloud.google.com/vertex-ai/docs/samples/aiplatform-sdk-summarization
|
133
|
-
temperature: 0.2,
|
134
|
-
top_p: 0.95,
|
135
|
-
top_k: 40,
|
136
|
-
# Most models have a context length of 2048 tokens (except for the newest models, which support 4096).
|
137
|
-
max_output_tokens: 256
|
138
|
-
)
|
104
|
+
Langchain::LLM::GoogleGeminiResponse.new(parsed_response, model: parameters[:model])
|
139
105
|
end
|
140
|
-
|
141
|
-
# def chat(...)
|
142
|
-
# https://cloud.google.com/vertex-ai/docs/samples/aiplatform-sdk-chathat
|
143
|
-
# Chat params: https://cloud.google.com/vertex-ai/docs/samples/aiplatform-sdk-chat
|
144
|
-
# \"temperature\": 0.3,\n"
|
145
|
-
# + " \"maxDecodeSteps\": 200,\n"
|
146
|
-
# + " \"topP\": 0.8,\n"
|
147
|
-
# + " \"topK\": 40\n"
|
148
|
-
# + "}";
|
149
|
-
# end
|
150
106
|
end
|
151
107
|
end
|
@@ -0,0 +1,45 @@
|
|
1
|
+
# frozen_string_literal: true
|
2
|
+
|
3
|
+
module Langchain::LLM
|
4
|
+
class GoogleGeminiResponse < BaseResponse
|
5
|
+
def initialize(raw_response, model: nil)
|
6
|
+
super(raw_response, model: model)
|
7
|
+
end
|
8
|
+
|
9
|
+
def chat_completion
|
10
|
+
raw_response.dig("candidates", 0, "content", "parts", 0, "text")
|
11
|
+
end
|
12
|
+
|
13
|
+
def role
|
14
|
+
raw_response.dig("candidates", 0, "content", "role")
|
15
|
+
end
|
16
|
+
|
17
|
+
def tool_calls
|
18
|
+
if raw_response.dig("candidates", 0, "content") && raw_response.dig("candidates", 0, "content", "parts", 0).has_key?("functionCall")
|
19
|
+
raw_response.dig("candidates", 0, "content", "parts")
|
20
|
+
else
|
21
|
+
[]
|
22
|
+
end
|
23
|
+
end
|
24
|
+
|
25
|
+
def embedding
|
26
|
+
embeddings.first
|
27
|
+
end
|
28
|
+
|
29
|
+
def embeddings
|
30
|
+
[raw_response.dig("predictions", 0, "embeddings", "values")]
|
31
|
+
end
|
32
|
+
|
33
|
+
def prompt_tokens
|
34
|
+
raw_response.dig("usageMetadata", "promptTokenCount")
|
35
|
+
end
|
36
|
+
|
37
|
+
def completion_tokens
|
38
|
+
raw_response.dig("usageMetadata", "candidatesTokenCount")
|
39
|
+
end
|
40
|
+
|
41
|
+
def total_tokens
|
42
|
+
raw_response.dig("usageMetadata", "totalTokenCount")
|
43
|
+
end
|
44
|
+
end
|
45
|
+
end
|
@@ -25,7 +25,11 @@ module Langchain::LLM
|
|
25
25
|
end
|
26
26
|
|
27
27
|
def tool_calls
|
28
|
-
chat_completions
|
28
|
+
if chat_completions.dig(0, "message").has_key?("tool_calls")
|
29
|
+
chat_completions.dig(0, "message", "tool_calls")
|
30
|
+
else
|
31
|
+
[]
|
32
|
+
end
|
29
33
|
end
|
30
34
|
|
31
35
|
def embedding
|
data/lib/langchain/tool/base.rb
CHANGED
@@ -66,11 +66,21 @@ module Langchain::Tool
|
|
66
66
|
|
67
67
|
# Returns the tool as a list of OpenAI formatted functions
|
68
68
|
#
|
69
|
-
# @return [Hash] tool as
|
69
|
+
# @return [Array<Hash>] List of hashes representing the tool as OpenAI formatted functions
|
70
70
|
def to_openai_tools
|
71
71
|
method_annotations
|
72
72
|
end
|
73
73
|
|
74
|
+
# Returns the tool as a list of Google Gemini formatted functions
|
75
|
+
#
|
76
|
+
# @return [Array<Hash>] List of hashes representing the tool as Google Gemini formatted functions
|
77
|
+
def to_google_gemini_tools
|
78
|
+
method_annotations.map do |annotation|
|
79
|
+
# Slice out only the content of the "function" key
|
80
|
+
annotation["function"]
|
81
|
+
end
|
82
|
+
end
|
83
|
+
|
74
84
|
# Return tool's method annotations as JSON
|
75
85
|
#
|
76
86
|
# @return [Hash] Tool's method annotations
|
@@ -2,7 +2,7 @@
|
|
2
2
|
{
|
3
3
|
"type": "function",
|
4
4
|
"function": {
|
5
|
-
"name": "
|
5
|
+
"name": "calculator__execute",
|
6
6
|
"description": "Evaluates a pure math expression or if equation contains non-math characters (e.g.: \"12F in Celsius\") then it uses the google search calculator to evaluate the expression",
|
7
7
|
"parameters": {
|
8
8
|
"type": "object",
|
@@ -2,7 +2,7 @@
|
|
2
2
|
{
|
3
3
|
"type": "function",
|
4
4
|
"function": {
|
5
|
-
"name": "
|
5
|
+
"name": "database__describe_tables",
|
6
6
|
"description": "Database Tool: Returns the schema for a list of tables",
|
7
7
|
"parameters": {
|
8
8
|
"type": "object",
|
@@ -18,7 +18,7 @@
|
|
18
18
|
}, {
|
19
19
|
"type": "function",
|
20
20
|
"function": {
|
21
|
-
"name": "
|
21
|
+
"name": "database__list_tables",
|
22
22
|
"description": "Database Tool: Returns a list of tables in the database",
|
23
23
|
"parameters": {
|
24
24
|
"type": "object",
|
@@ -29,7 +29,7 @@
|
|
29
29
|
}, {
|
30
30
|
"type": "function",
|
31
31
|
"function": {
|
32
|
-
"name": "
|
32
|
+
"name": "database__execute",
|
33
33
|
"description": "Database Tool: Executes a SQL query and returns the results",
|
34
34
|
"parameters": {
|
35
35
|
"type": "object",
|
@@ -2,7 +2,7 @@
|
|
2
2
|
{
|
3
3
|
"type": "function",
|
4
4
|
"function": {
|
5
|
-
"name": "
|
5
|
+
"name": "file_system__list_directory",
|
6
6
|
"description": "File System Tool: Lists out the content of a specified directory",
|
7
7
|
"parameters": {
|
8
8
|
"type": "object",
|
@@ -19,7 +19,7 @@
|
|
19
19
|
{
|
20
20
|
"type": "function",
|
21
21
|
"function": {
|
22
|
-
"name": "
|
22
|
+
"name": "file_system__read_file",
|
23
23
|
"description": "File System Tool: Reads the contents of a file",
|
24
24
|
"parameters": {
|
25
25
|
"type": "object",
|
@@ -36,7 +36,7 @@
|
|
36
36
|
{
|
37
37
|
"type": "function",
|
38
38
|
"function": {
|
39
|
-
"name": "
|
39
|
+
"name": "file_system__write_to_file",
|
40
40
|
"description": "File System Tool: Write content to a file",
|
41
41
|
"parameters": {
|
42
42
|
"type": "object",
|
@@ -0,0 +1,121 @@
|
|
1
|
+
[
|
2
|
+
{
|
3
|
+
"type": "function",
|
4
|
+
"function": {
|
5
|
+
"name": "news_retriever__get_everything",
|
6
|
+
"description": "News Retriever: Search through millions of articles from over 150,000 large and small news sources and blogs.",
|
7
|
+
"parameters": {
|
8
|
+
"type": "object",
|
9
|
+
"properties": {
|
10
|
+
"q": {
|
11
|
+
"type": "string",
|
12
|
+
"description": "Keywords or phrases to search for in the article title and body. Surround phrases with quotes (\") for exact match. Alternatively you can use the AND / OR / NOT keywords, and optionally group these with parenthesis. Must be URL-encoded."
|
13
|
+
},
|
14
|
+
"search_in": {
|
15
|
+
"type": "string",
|
16
|
+
"description": "The fields to restrict your q search to.",
|
17
|
+
"enum": ["title", "description", "content"]
|
18
|
+
},
|
19
|
+
"sources": {
|
20
|
+
"type": "string",
|
21
|
+
"description": "A comma-separated string of identifiers (maximum 20) for the news sources or blogs you want headlines from. Use the /sources endpoint to locate these programmatically or look at the sources index."
|
22
|
+
},
|
23
|
+
"domains": {
|
24
|
+
"type": "string",
|
25
|
+
"description": "A comma-separated string of domains (eg bbc.co.uk, techcrunch.com, engadget.com) to restrict the search to."
|
26
|
+
},
|
27
|
+
"exclude_domains": {
|
28
|
+
"type": "string",
|
29
|
+
"description": "A comma-separated string of domains (eg bbc.co.uk, techcrunch.com, engadget.com) to remove from the results."
|
30
|
+
},
|
31
|
+
"from": {
|
32
|
+
"type": "string",
|
33
|
+
"description": "A date and optional time for the oldest article allowed. This should be in ISO 8601 format."
|
34
|
+
},
|
35
|
+
"to": {
|
36
|
+
"type": "string",
|
37
|
+
"description": "A date and optional time for the newest article allowed. This should be in ISO 8601 format."
|
38
|
+
},
|
39
|
+
"language": {
|
40
|
+
"type": "string",
|
41
|
+
"description": "The 2-letter ISO-639-1 code of the language you want to get headlines for.",
|
42
|
+
"enum": ["ar", "de", "en", "es", "fr", "he", "it", "nl", "no", "pt", "ru", "sv", "ud", "zh"]
|
43
|
+
},
|
44
|
+
"sort_by": {
|
45
|
+
"type": "string",
|
46
|
+
"description": "The order to sort the articles in.",
|
47
|
+
"enum": ["relevancy", "popularity", "publishedAt"]
|
48
|
+
},
|
49
|
+
"page_size": {
|
50
|
+
"type": "integer",
|
51
|
+
"description": "The number of results to return per page (request). 5 is the default, 100 is the maximum."
|
52
|
+
},
|
53
|
+
"page": {
|
54
|
+
"type": "integer",
|
55
|
+
"description": "Use this to page through the results if the total results found is greater than the page size."
|
56
|
+
}
|
57
|
+
}
|
58
|
+
}
|
59
|
+
}
|
60
|
+
},
|
61
|
+
{
|
62
|
+
"type": "function",
|
63
|
+
"function": {
|
64
|
+
"name": "news_retriever__get_top_headlines",
|
65
|
+
"description": "News Retriever: Provides live top and breaking headlines for a country, specific category in a country, single source, or multiple sources. You can also search with keywords. Articles are sorted by the earliest date published first.",
|
66
|
+
"parameters": {
|
67
|
+
"type": "object",
|
68
|
+
"properties": {
|
69
|
+
"country": {
|
70
|
+
"type": "string",
|
71
|
+
"description": "The 2-letter ISO 3166-1 code of the country you want to get headlines for."
|
72
|
+
},
|
73
|
+
"category": {
|
74
|
+
"type": "string",
|
75
|
+
"description": "The category you want to get headlines for.",
|
76
|
+
"enum": ["business", "entertainment", "general", "health", "science", "sports", "technology"]
|
77
|
+
},
|
78
|
+
"q": {
|
79
|
+
"type": "string",
|
80
|
+
"description": "Keywords or a phrase to search for."
|
81
|
+
},
|
82
|
+
"page_size": {
|
83
|
+
"type": "integer",
|
84
|
+
"description": "The number of results to return per page (request). 5 is the default, 100 is the maximum."
|
85
|
+
},
|
86
|
+
"page": {
|
87
|
+
"type": "integer",
|
88
|
+
"description": "Use this to page through the results if the total results found is greater than the page size."
|
89
|
+
}
|
90
|
+
}
|
91
|
+
}
|
92
|
+
}
|
93
|
+
},
|
94
|
+
{
|
95
|
+
"type": "function",
|
96
|
+
"function": {
|
97
|
+
"name": "news_retriever__get_sources",
|
98
|
+
"description": "News Retriever: This endpoint returns the subset of news publishers that top headlines (/v2/top-headlines) are available from. It's mainly a convenience endpoint that you can use to keep track of the publishers available on the API, and you can pipe it straight through to your users.",
|
99
|
+
"parameters": {
|
100
|
+
"type": "object",
|
101
|
+
"properties": {
|
102
|
+
"country": {
|
103
|
+
"type": "string",
|
104
|
+
"description": "The 2-letter ISO 3166-1 code of the country you want to get headlines for. Default: all countries.",
|
105
|
+
"enum": ["ae", "ar", "at", "au", "be", "bg", "br", "ca", "ch", "cn", "co", "cu", "cz", "de", "eg", "fr", "gb", "gr", "hk", "hu", "id", "ie", "il", "in", "it", "jp", "kr", "lt", "lv", "ma", "mx", "my", "ng", "nl", "no", "nz", "ph", "pl", "pt", "ro", "rs", "ru", "sa", "se", "sg", "si", "sk", "th", "tr", "tw", "ua", "us", "ve", "za"]
|
106
|
+
},
|
107
|
+
"category": {
|
108
|
+
"type": "string",
|
109
|
+
"description": "The category you want to get headlines for. Default: all categories.",
|
110
|
+
"enum": ["business", "entertainment", "general", "health", "science", "sports", "technology"]
|
111
|
+
},
|
112
|
+
"language": {
|
113
|
+
"type": "string",
|
114
|
+
"description": "The 2-letter ISO-639-1 code of the language you want to get headlines for.",
|
115
|
+
"enum": ["ar", "de", "en", "es", "fr", "he", "it", "nl", "no", "pt", "ru", "sv", "ud", "zh"]
|
116
|
+
}
|
117
|
+
}
|
118
|
+
}
|
119
|
+
}
|
120
|
+
}
|
121
|
+
]
|
@@ -0,0 +1,132 @@
|
|
1
|
+
# frozen_string_literal: true
|
2
|
+
|
3
|
+
module Langchain::Tool
|
4
|
+
class NewsRetriever < Base
|
5
|
+
#
|
6
|
+
# A tool that retrieves latest news from various sources via https://newsapi.org/.
|
7
|
+
# An API key needs to be obtained from https://newsapi.org/ to use this tool.
|
8
|
+
#
|
9
|
+
# Usage:
|
10
|
+
# news_retriever = Langchain::Tool::NewsRetriever.new(api_key: ENV["NEWS_API_KEY"])
|
11
|
+
#
|
12
|
+
NAME = "news_retriever"
|
13
|
+
ANNOTATIONS_PATH = Langchain.root.join("./langchain/tool/#{NAME}/#{NAME}.json").to_path
|
14
|
+
|
15
|
+
def initialize(api_key: ENV["NEWS_API_KEY"])
|
16
|
+
@api_key = api_key
|
17
|
+
end
|
18
|
+
|
19
|
+
# Retrieve all news
|
20
|
+
#
|
21
|
+
# @param q [String] Keywords or phrases to search for in the article title and body.
|
22
|
+
# @param search_in [String] The fields to restrict your q search to. The possible options are: title, description, content.
|
23
|
+
# @param sources [String] A comma-separated string of identifiers (maximum 20) for the news sources or blogs you want headlines from. Use the /sources endpoint to locate these programmatically or look at the sources index.
|
24
|
+
# @param domains [String] A comma-separated string of domains (eg bbc.co.uk, techcrunch.com, engadget.com) to restrict the search to.
|
25
|
+
# @param exclude_domains [String] A comma-separated string of domains (eg bbc.co.uk, techcrunch.com, engadget.com) to remove from the results.
|
26
|
+
# @param from [String] A date and optional time for the oldest article allowed. This should be in ISO 8601 format.
|
27
|
+
# @param to [String] A date and optional time for the newest article allowed. This should be in ISO 8601 format.
|
28
|
+
# @param language [String] The 2-letter ISO-639-1 code of the language you want to get headlines for. Possible options: ar, de, en, es, fr, he, it, nl, no, pt, ru, sv, ud, zh.
|
29
|
+
# @param sort_by [String] The order to sort the articles in. Possible options: relevancy, popularity, publishedAt.
|
30
|
+
# @param page_size [Integer] The number of results to return per page. 20 is the API's default, 100 is the maximum. Our default is 5.
|
31
|
+
# @param page [Integer] Use this to page through the results.
|
32
|
+
#
|
33
|
+
# @return [String] JSON response
|
34
|
+
def get_everything(
|
35
|
+
q: nil,
|
36
|
+
search_in: nil,
|
37
|
+
sources: nil,
|
38
|
+
domains: nil,
|
39
|
+
exclude_domains: nil,
|
40
|
+
from: nil,
|
41
|
+
to: nil,
|
42
|
+
language: nil,
|
43
|
+
sort_by: nil,
|
44
|
+
page_size: 5, # The API default is 20 but that's too many.
|
45
|
+
page: nil
|
46
|
+
)
|
47
|
+
Langchain.logger.info("Retrieving all news", for: self.class)
|
48
|
+
|
49
|
+
params = {apiKey: @api_key}
|
50
|
+
params[:q] = q if q
|
51
|
+
params[:searchIn] = search_in if search_in
|
52
|
+
params[:sources] = sources if sources
|
53
|
+
params[:domains] = domains if domains
|
54
|
+
params[:excludeDomains] = exclude_domains if exclude_domains
|
55
|
+
params[:from] = from if from
|
56
|
+
params[:to] = to if to
|
57
|
+
params[:language] = language if language
|
58
|
+
params[:sortBy] = sort_by if sort_by
|
59
|
+
params[:pageSize] = page_size if page_size
|
60
|
+
params[:page] = page if page
|
61
|
+
|
62
|
+
send_request(path: "everything", params: params)
|
63
|
+
end
|
64
|
+
|
65
|
+
# Retrieve top headlines
|
66
|
+
#
|
67
|
+
# @param country [String] The 2-letter ISO 3166-1 code of the country you want to get headlines for. Possible options: ae, ar, at, au, be, bg, br, ca, ch, cn, co, cu, cz, de, eg, fr, gb, gr, hk, hu, id, ie, il, in, it, jp, kr, lt, lv, ma, mx, my, ng, nl, no, nz, ph, pl, pt, ro, rs, ru, sa, se, sg, si, sk, th, tr, tw, ua, us, ve, za.
|
68
|
+
# @param category [String] The category you want to get headlines for. Possible options: business, entertainment, general, health, science, sports, technology.
|
69
|
+
# @param sources [String] A comma-separated string of identifiers for the news sources or blogs you want headlines from. Use the /sources endpoint to locate these programmatically.
|
70
|
+
# @param q [String] Keywords or a phrase to search for.
|
71
|
+
# @param page_size [Integer] The number of results to return per page. 20 is the API's default, 100 is the maximum. Our default is 5.
|
72
|
+
# @param page [Integer] Use this to page through the results.
|
73
|
+
#
|
74
|
+
# @return [String] JSON response
|
75
|
+
def get_top_headlines(
|
76
|
+
country: nil,
|
77
|
+
category: nil,
|
78
|
+
sources: nil,
|
79
|
+
q: nil,
|
80
|
+
page_size: 5,
|
81
|
+
page: nil
|
82
|
+
)
|
83
|
+
Langchain.logger.info("Retrieving top news headlines", for: self.class)
|
84
|
+
|
85
|
+
params = {apiKey: @api_key}
|
86
|
+
params[:country] = country if country
|
87
|
+
params[:category] = category if category
|
88
|
+
params[:sources] = sources if sources
|
89
|
+
params[:q] = q if q
|
90
|
+
params[:pageSize] = page_size if page_size
|
91
|
+
params[:page] = page if page
|
92
|
+
|
93
|
+
send_request(path: "top-headlines", params: params)
|
94
|
+
end
|
95
|
+
|
96
|
+
# Retrieve news sources
|
97
|
+
#
|
98
|
+
# @param category [String] The category you want to get headlines for. Possible options: business, entertainment, general, health, science, sports, technology.
|
99
|
+
# @param language [String] The 2-letter ISO-639-1 code of the language you want to get headlines for. Possible options: ar, de, en, es, fr, he, it, nl, no, pt, ru, sv, ud, zh.
|
100
|
+
# @param country [String] The 2-letter ISO 3166-1 code of the country you want to get headlines for. Possible options: ae, ar, at, au, be, bg, br, ca, ch, cn, co, cu, cz, de, eg, fr, gb, gr, hk, hu, id, ie, il, in, it, jp, kr, lt, lv, ma, mx, my, ng, nl, no, nz, ph, pl, pt, ro, rs, ru, sa, se, sg, si, sk, th, tr, tw, ua, us, ve, za.
|
101
|
+
#
|
102
|
+
# @return [String] JSON response
|
103
|
+
def get_sources(
|
104
|
+
category: nil,
|
105
|
+
language: nil,
|
106
|
+
country: nil
|
107
|
+
)
|
108
|
+
Langchain.logger.info("Retrieving news sources", for: self.class)
|
109
|
+
|
110
|
+
params = {apiKey: @api_key}
|
111
|
+
params[:country] = country if country
|
112
|
+
params[:category] = category if category
|
113
|
+
params[:language] = language if language
|
114
|
+
|
115
|
+
send_request(path: "top-headlines/sources", params: params)
|
116
|
+
end
|
117
|
+
|
118
|
+
private
|
119
|
+
|
120
|
+
def send_request(path:, params:)
|
121
|
+
uri = URI.parse("https://newsapi.org/v2/#{path}?#{URI.encode_www_form(params)}")
|
122
|
+
http = Net::HTTP.new(uri.host, uri.port)
|
123
|
+
http.use_ssl = true
|
124
|
+
|
125
|
+
request = Net::HTTP::Get.new(uri.request_uri)
|
126
|
+
request["Content-Type"] = "application/json"
|
127
|
+
|
128
|
+
response = http.request(request)
|
129
|
+
response.body
|
130
|
+
end
|
131
|
+
end
|
132
|
+
end
|