cohere-ruby 0.9.10 → 1.0.1
Sign up to get free protection for your applications and to get access to all the features.
- checksums.yaml +4 -4
- data/.env.example +1 -0
- data/CHANGELOG.md +8 -0
- data/Gemfile +1 -0
- data/Gemfile.lock +3 -1
- data/README.md +65 -39
- data/lib/cohere/client.rb +100 -67
- data/lib/cohere/version.rb +1 -1
- metadata +4 -3
checksums.yaml
CHANGED
@@ -1,7 +1,7 @@
|
|
1
1
|
---
|
2
2
|
SHA256:
|
3
|
-
metadata.gz:
|
4
|
-
data.tar.gz:
|
3
|
+
metadata.gz: 0c65ac65e6fa623709455907a71697f721311d2933e96a10a5c49f907f356f02
|
4
|
+
data.tar.gz: 553fd85c32a2a99b2a8d1a75f1a2aa0a0c9abec290f305ed26caa855515f717b
|
5
5
|
SHA512:
|
6
|
-
metadata.gz:
|
7
|
-
data.tar.gz:
|
6
|
+
metadata.gz: bb7dd10864d789f91225f9370011e622dc6353a55a7039b8ef509610980c1503c3916f210189361cc1dcc20cb1db880d7fd4d80f967867fdcecd6803e85e8bdb
|
7
|
+
data.tar.gz: 89eb351c192abfda2c905810210b21be5d6a5b18aa6e3d5a4ac28efaba9fc23e1a0525510b51a4399c5e7086fab12a854035388806c4401122c3463277941990
|
data/.env.example
ADDED
@@ -0,0 +1 @@
|
|
1
|
+
COHERE_API_KEY=
|
data/CHANGELOG.md
CHANGED
data/Gemfile
CHANGED
data/Gemfile.lock
CHANGED
@@ -1,7 +1,7 @@
|
|
1
1
|
PATH
|
2
2
|
remote: .
|
3
3
|
specs:
|
4
|
-
cohere-ruby (0.
|
4
|
+
cohere-ruby (1.0.1)
|
5
5
|
faraday (>= 2.0.1, < 3.0)
|
6
6
|
|
7
7
|
GEM
|
@@ -9,6 +9,7 @@ GEM
|
|
9
9
|
specs:
|
10
10
|
ast (2.4.2)
|
11
11
|
diff-lcs (1.5.0)
|
12
|
+
dotenv (2.8.1)
|
12
13
|
faraday (2.7.10)
|
13
14
|
faraday-net_http (>= 2.0, < 3.1)
|
14
15
|
ruby2_keywords (>= 0.0.4)
|
@@ -74,6 +75,7 @@ PLATFORMS
|
|
74
75
|
|
75
76
|
DEPENDENCIES
|
76
77
|
cohere-ruby!
|
78
|
+
dotenv (~> 2.8.1)
|
77
79
|
rake (~> 13.0)
|
78
80
|
rspec (~> 3.0)
|
79
81
|
standard (~> 1.28.0)
|
data/README.md
CHANGED
@@ -1,19 +1,19 @@
|
|
1
1
|
# Cohere
|
2
2
|
|
3
3
|
<p>
|
4
|
-
<img alt='
|
4
|
+
<img alt='Cohere logo' src='https://static.wikia.nocookie.net/logopedia/images/d/d4/Cohere_2023.svg/revision/latest?cb=20230419182227' height='50' />
|
5
5
|
+
|
6
6
|
<img alt='Ruby logo' src='https://user-images.githubusercontent.com/541665/230231593-43861278-4550-421d-a543-fd3553aac4f6.png' height='40' />
|
7
7
|
</p>
|
8
8
|
|
9
9
|
Cohere API client for Ruby.
|
10
10
|
|
11
|
-
Part of the [Langchain.rb](https://github.com/
|
11
|
+
Part of the [Langchain.rb](https://github.com/patterns-ai-core/langchainrb) stack.
|
12
12
|
|
13
|
-
![Tests status](https://github.com/
|
13
|
+
![Tests status](https://github.com/patterns-ai-core/cohere-ruby/actions/workflows/ci.yml/badge.svg)
|
14
14
|
[![Gem Version](https://badge.fury.io/rb/cohere-ruby.svg)](https://badge.fury.io/rb/cohere-ruby)
|
15
15
|
[![Docs](http://img.shields.io/badge/yard-docs-blue.svg)](http://rubydoc.info/gems/cohere-ruby)
|
16
|
-
[![License](https://img.shields.io/badge/license-MIT-green.svg)](https://github.com/
|
16
|
+
[![License](https://img.shields.io/badge/license-MIT-green.svg)](https://github.com/patterns-ai-core/cohere-ruby/blob/main/LICENSE.txt)
|
17
17
|
[![](https://dcbadge.vercel.app/api/server/WDARp7J2n8?compact=true&style=flat)](https://discord.gg/WDARp7J2n8)
|
18
18
|
|
19
19
|
## Installation
|
@@ -42,7 +42,7 @@ client = Cohere::Client.new(
|
|
42
42
|
|
43
43
|
```ruby
|
44
44
|
client.generate(
|
45
|
-
|
45
|
+
prompt: "Once upon a time in a magical land called"
|
46
46
|
)
|
47
47
|
```
|
48
48
|
|
@@ -50,14 +50,18 @@ client.generate(
|
|
50
50
|
|
51
51
|
```ruby
|
52
52
|
client.chat(
|
53
|
-
|
53
|
+
model: "command-r-plus-08-2024",
|
54
|
+
messages: [{role:"user", content: "Hey! How are you?"}]
|
54
55
|
)
|
55
56
|
```
|
56
57
|
|
57
58
|
`chat` supports a streaming option. You can pass a block to the `chat` method and it will yield a new chunk as soon as it is received.
|
58
59
|
|
59
60
|
```ruby
|
60
|
-
client.chat(
|
61
|
+
client.chat(
|
62
|
+
model: "command-r-plus-08-2024",
|
63
|
+
messages: [{role:"user", content: "Hey! How are you?"}]
|
64
|
+
) do |chunk, overall_received_bytes|
|
61
65
|
puts "Received #{overall_received_bytes} bytes: #{chunk.force_encoding(Encoding::UTF_8)}"
|
62
66
|
end
|
63
67
|
```
|
@@ -68,35 +72,54 @@ end
|
|
68
72
|
|
69
73
|
```ruby
|
70
74
|
tools = [
|
71
|
-
|
72
|
-
|
73
|
-
|
74
|
-
|
75
|
-
|
76
|
-
|
77
|
-
|
78
|
-
|
79
|
-
|
80
|
-
|
81
|
-
|
75
|
+
{
|
76
|
+
name: "query_daily_sales_report",
|
77
|
+
description: "Connects to a database to retrieve overall sales volumes and sales information for a given day.",
|
78
|
+
parameter_definitions: {
|
79
|
+
day: {
|
80
|
+
description: "Retrieves sales data for this day, formatted as YYYY-MM-DD.",
|
81
|
+
type: "str",
|
82
|
+
required: true
|
83
|
+
}
|
84
|
+
}
|
85
|
+
}
|
82
86
|
]
|
83
87
|
|
84
88
|
message = "Can you provide a sales summary for 29th September 2023, and also give me some details about the products in the 'Electronics' category, for example their prices and stock levels?"
|
85
89
|
|
86
90
|
client.chat(
|
87
91
|
model: model,
|
88
|
-
|
89
|
-
tools: tools
|
92
|
+
messages: [{ role:"user", content: message }],
|
93
|
+
tools: tools
|
90
94
|
)
|
91
95
|
```
|
92
96
|
|
93
|
-
|
94
|
-
|
95
97
|
### Embed
|
96
98
|
|
97
99
|
```ruby
|
98
100
|
client.embed(
|
99
|
-
|
101
|
+
model: "embed-english-v3.0",
|
102
|
+
texts: ["hello", "goodbye"],
|
103
|
+
input_type: "classification",
|
104
|
+
embedding_types: ["float"]
|
105
|
+
)
|
106
|
+
```
|
107
|
+
|
108
|
+
### Rerank
|
109
|
+
|
110
|
+
```ruby
|
111
|
+
docs = [
|
112
|
+
"Carson City is the capital city of the American state of Nevada.",
|
113
|
+
"The Commonwealth of the Northern Mariana Islands is a group of islands in the Pacific Ocean. Its capital is Saipan.",
|
114
|
+
"Capitalization or capitalisation in English grammar is the use of a capital letter at the start of a word. English usage varies from capitalization in other languages.",
|
115
|
+
"Washington, D.C. (also known as simply Washington or D.C., and officially as the District of Columbia) is the capital of the United States. It is a federal district.",
|
116
|
+
"Capital punishment (the death penalty) has existed in the United States since beforethe United States was a country. As of 2017, capital punishment is legal in 30 of the 50 states.",
|
117
|
+
]
|
118
|
+
|
119
|
+
client.rerank(
|
120
|
+
model: "rerank-english-v3.0",
|
121
|
+
query: "What is the capital of the United States?",
|
122
|
+
documents: docs
|
100
123
|
)
|
101
124
|
```
|
102
125
|
|
@@ -104,16 +127,16 @@ client.embed(
|
|
104
127
|
|
105
128
|
```ruby
|
106
129
|
examples = [
|
107
|
-
|
108
|
-
|
109
|
-
|
110
|
-
|
111
|
-
|
112
|
-
|
113
|
-
|
114
|
-
|
115
|
-
|
116
|
-
|
130
|
+
{ text: "Dermatologists don't like her!", label: "Spam" },
|
131
|
+
{ text: "Hello, open to this?", label: "Spam" },
|
132
|
+
{ text: "I need help please wire me $1000 right now", label: "Spam" },
|
133
|
+
{ text: "Nice to know you ;)", label: "Spam" },
|
134
|
+
{ text: "Please help me?", label: "Spam" },
|
135
|
+
{ text: "Your parcel will be delivered today", label: "Not spam" },
|
136
|
+
{ text: "Review changes to our Terms and Conditions", label: "Not spam" },
|
137
|
+
{ text: "Weekly sync notes", label: "Not spam" },
|
138
|
+
{ text: "Re: Follow up from today's meeting", label: "Not spam" },
|
139
|
+
{ text: "Pre-read for tomorrow", label: "Not spam" }
|
117
140
|
]
|
118
141
|
|
119
142
|
inputs = [
|
@@ -122,8 +145,9 @@ inputs = [
|
|
122
145
|
]
|
123
146
|
|
124
147
|
client.classify(
|
125
|
-
|
126
|
-
|
148
|
+
model: "embed-multilingual-v2.0",
|
149
|
+
inputs: inputs,
|
150
|
+
examples: examples
|
127
151
|
)
|
128
152
|
```
|
129
153
|
|
@@ -131,7 +155,8 @@ client.classify(
|
|
131
155
|
|
132
156
|
```ruby
|
133
157
|
client.tokenize(
|
134
|
-
|
158
|
+
model: "command-r-plus-08-2024",
|
159
|
+
text: "Hello, world!"
|
135
160
|
)
|
136
161
|
```
|
137
162
|
|
@@ -139,7 +164,8 @@ client.tokenize(
|
|
139
164
|
|
140
165
|
```ruby
|
141
166
|
client.detokenize(
|
142
|
-
|
167
|
+
model: "command-r-plus-08-2024",
|
168
|
+
tokens: [33555, 1114, 34]
|
143
169
|
)
|
144
170
|
```
|
145
171
|
|
@@ -147,7 +173,7 @@ client.detokenize(
|
|
147
173
|
|
148
174
|
```ruby
|
149
175
|
client.detect_language(
|
150
|
-
|
176
|
+
texts: ["Здравствуй, Мир"]
|
151
177
|
)
|
152
178
|
```
|
153
179
|
|
@@ -155,7 +181,7 @@ client.detect_language(
|
|
155
181
|
|
156
182
|
```ruby
|
157
183
|
client.summarize(
|
158
|
-
|
184
|
+
text: "..."
|
159
185
|
)
|
160
186
|
```
|
161
187
|
|
data/lib/cohere/client.rb
CHANGED
@@ -6,62 +6,56 @@ module Cohere
|
|
6
6
|
class Client
|
7
7
|
attr_reader :api_key, :connection
|
8
8
|
|
9
|
-
ENDPOINT_URL = "https://api.cohere.ai/v1"
|
10
|
-
|
11
9
|
def initialize(api_key:, timeout: nil)
|
12
10
|
@api_key = api_key
|
13
11
|
@timeout = timeout
|
14
12
|
end
|
15
13
|
|
14
|
+
# Generates a text response to a user message and streams it down, token by token
|
16
15
|
def chat(
|
17
|
-
|
18
|
-
|
16
|
+
model:,
|
17
|
+
messages:,
|
19
18
|
stream: false,
|
20
|
-
|
21
|
-
preamble_override: nil,
|
22
|
-
chat_history: [],
|
23
|
-
conversation_id: nil,
|
24
|
-
prompt_truncation: nil,
|
25
|
-
connectors: [],
|
26
|
-
search_queries_only: false,
|
19
|
+
tools: [],
|
27
20
|
documents: [],
|
28
|
-
|
29
|
-
|
21
|
+
citation_options: nil,
|
22
|
+
response_format: nil,
|
23
|
+
safety_mode: nil,
|
30
24
|
max_tokens: nil,
|
31
|
-
|
32
|
-
|
25
|
+
stop_sequences: nil,
|
26
|
+
temperature: nil,
|
33
27
|
seed: nil,
|
34
28
|
frequency_penalty: nil,
|
35
29
|
presence_penalty: nil,
|
36
|
-
|
30
|
+
k: nil,
|
31
|
+
p: nil,
|
32
|
+
logprops: nil,
|
37
33
|
&block
|
38
34
|
)
|
39
|
-
response =
|
35
|
+
response = v2_connection.post("chat") do |req|
|
40
36
|
req.body = {}
|
41
37
|
|
42
|
-
req.body[:
|
43
|
-
req.body[:
|
44
|
-
|
45
|
-
|
46
|
-
|
47
|
-
|
48
|
-
req.body[:
|
49
|
-
req.body[:preamble_override] = preamble_override if preamble_override
|
50
|
-
req.body[:chat_history] = chat_history if chat_history
|
51
|
-
req.body[:conversation_id] = conversation_id if conversation_id
|
52
|
-
req.body[:prompt_truncation] = prompt_truncation if prompt_truncation
|
53
|
-
req.body[:connectors] = connectors if connectors
|
54
|
-
req.body[:search_queries_only] = search_queries_only if search_queries_only
|
55
|
-
req.body[:documents] = documents if documents
|
56
|
-
req.body[:citation_quality] = citation_quality if citation_quality
|
57
|
-
req.body[:temperature] = temperature if temperature
|
38
|
+
req.body[:model] = model
|
39
|
+
req.body[:messages] = messages if messages
|
40
|
+
req.body[:tools] = tools if tools.any?
|
41
|
+
req.body[:documents] = documents if documents.any?
|
42
|
+
req.body[:citation_options] = citation_options if citation_options
|
43
|
+
req.body[:response_format] = response_format if response_format
|
44
|
+
req.body[:safety_mode] = safety_mode if safety_mode
|
58
45
|
req.body[:max_tokens] = max_tokens if max_tokens
|
59
|
-
req.body[:
|
60
|
-
req.body[:
|
46
|
+
req.body[:stop_sequences] = stop_sequences if stop_sequences
|
47
|
+
req.body[:temperature] = temperature if temperature
|
61
48
|
req.body[:seed] = seed if seed
|
62
49
|
req.body[:frequency_penalty] = frequency_penalty if frequency_penalty
|
63
50
|
req.body[:presence_penalty] = presence_penalty if presence_penalty
|
64
|
-
req.body[:
|
51
|
+
req.body[:k] = k if k
|
52
|
+
req.body[:p] = p if p
|
53
|
+
req.body[:logprops] = logprops if logprops
|
54
|
+
|
55
|
+
if stream || block
|
56
|
+
req.body[:stream] = true
|
57
|
+
req.options.on_data = block if block
|
58
|
+
end
|
65
59
|
end
|
66
60
|
response.body
|
67
61
|
end
|
@@ -84,7 +78,7 @@ module Cohere
|
|
84
78
|
logit_bias: nil,
|
85
79
|
truncate: nil
|
86
80
|
)
|
87
|
-
response =
|
81
|
+
response = v1_connection.post("generate") do |req|
|
88
82
|
req.body = {prompt: prompt}
|
89
83
|
req.body[:model] = model if model
|
90
84
|
req.body[:num_generations] = num_generations if num_generations
|
@@ -104,56 +98,90 @@ module Cohere
|
|
104
98
|
response.body
|
105
99
|
end
|
106
100
|
|
101
|
+
# This endpoint returns text embeddings. An embedding is a list of floating point numbers that captures semantic information about the text that it represents.
|
107
102
|
def embed(
|
108
|
-
|
109
|
-
|
110
|
-
|
103
|
+
model:,
|
104
|
+
input_type:,
|
105
|
+
embedding_types:,
|
106
|
+
texts: nil,
|
107
|
+
images: nil,
|
111
108
|
truncate: nil
|
112
109
|
)
|
113
|
-
response =
|
114
|
-
req.body = {
|
115
|
-
|
116
|
-
|
110
|
+
response = v2_connection.post("embed") do |req|
|
111
|
+
req.body = {
|
112
|
+
model: model,
|
113
|
+
input_type: input_type,
|
114
|
+
embedding_types: embedding_types
|
115
|
+
}
|
116
|
+
req.body[:texts] = texts if texts
|
117
|
+
req.body[:images] = images if images
|
117
118
|
req.body[:truncate] = truncate if truncate
|
118
119
|
end
|
119
120
|
response.body
|
120
121
|
end
|
121
122
|
|
123
|
+
# This endpoint takes in a query and a list of texts and produces an ordered array with each text assigned a relevance score.
|
124
|
+
def rerank(
|
125
|
+
model:,
|
126
|
+
query:,
|
127
|
+
documents:,
|
128
|
+
top_n: nil,
|
129
|
+
rank_fields: nil,
|
130
|
+
return_documents: nil,
|
131
|
+
max_chunks_per_doc: nil
|
132
|
+
)
|
133
|
+
response = v2_connection.post("rerank") do |req|
|
134
|
+
req.body = {
|
135
|
+
model: model,
|
136
|
+
query: query,
|
137
|
+
documents: documents
|
138
|
+
}
|
139
|
+
req.body[:top_n] = top_n if top_n
|
140
|
+
req.body[:rank_fields] = rank_fields if rank_fields
|
141
|
+
req.body[:return_documents] = return_documents if return_documents
|
142
|
+
req.body[:max_chunks_per_doc] = max_chunks_per_doc if max_chunks_per_doc
|
143
|
+
end
|
144
|
+
response.body
|
145
|
+
end
|
146
|
+
|
147
|
+
# This endpoint makes a prediction about which label fits the specified text inputs best.
|
122
148
|
def classify(
|
149
|
+
model:,
|
123
150
|
inputs:,
|
124
|
-
examples
|
125
|
-
|
126
|
-
present: nil,
|
151
|
+
examples: nil,
|
152
|
+
preset: nil,
|
127
153
|
truncate: nil
|
128
154
|
)
|
129
|
-
response =
|
155
|
+
response = v1_connection.post("classify") do |req|
|
130
156
|
req.body = {
|
131
|
-
|
132
|
-
|
157
|
+
model: model,
|
158
|
+
inputs: inputs
|
133
159
|
}
|
134
|
-
req.body[:
|
135
|
-
req.body[:
|
160
|
+
req.body[:examples] = examples if examples
|
161
|
+
req.body[:preset] = preset if preset
|
136
162
|
req.body[:truncate] = truncate if truncate
|
137
163
|
end
|
138
164
|
response.body
|
139
165
|
end
|
140
166
|
|
141
|
-
|
142
|
-
|
143
|
-
|
167
|
+
# This endpoint splits input text into smaller units called tokens using byte-pair encoding (BPE).
|
168
|
+
def tokenize(text:, model:)
|
169
|
+
response = v1_connection.post("tokenize") do |req|
|
170
|
+
req.body = {text: text, model: model}
|
144
171
|
end
|
145
172
|
response.body
|
146
173
|
end
|
147
174
|
|
148
|
-
|
149
|
-
|
150
|
-
|
175
|
+
# This endpoint takes tokens using byte-pair encoding and returns their text representation.
|
176
|
+
def detokenize(tokens:, model:)
|
177
|
+
response = v1_connection.post("detokenize") do |req|
|
178
|
+
req.body = {tokens: tokens, model: model}
|
151
179
|
end
|
152
180
|
response.body
|
153
181
|
end
|
154
182
|
|
155
183
|
def detect_language(texts:)
|
156
|
-
response =
|
184
|
+
response = v1_connection.post("detect-language") do |req|
|
157
185
|
req.body = {texts: texts}
|
158
186
|
end
|
159
187
|
response.body
|
@@ -168,7 +196,7 @@ module Cohere
|
|
168
196
|
temperature: nil,
|
169
197
|
additional_command: nil
|
170
198
|
)
|
171
|
-
response =
|
199
|
+
response = v1_connection.post("summarize") do |req|
|
172
200
|
req.body = {text: text}
|
173
201
|
req.body[:length] = length if length
|
174
202
|
req.body[:format] = format if format
|
@@ -182,17 +210,22 @@ module Cohere
|
|
182
210
|
|
183
211
|
private
|
184
212
|
|
185
|
-
|
186
|
-
|
187
|
-
|
188
|
-
|
189
|
-
|
190
|
-
|
213
|
+
def v1_connection
|
214
|
+
@connection ||= Faraday.new(url: "https://api.cohere.ai/v1", request: {timeout: @timeout}) do |faraday|
|
215
|
+
faraday.request :authorization, :Bearer, api_key
|
216
|
+
faraday.request :json
|
217
|
+
faraday.response :json, content_type: /\bjson$/
|
218
|
+
faraday.adapter Faraday.default_adapter
|
219
|
+
end
|
220
|
+
end
|
221
|
+
|
222
|
+
def v2_connection
|
223
|
+
@connection ||= Faraday.new(url: "https://api.cohere.com/v2", request: {timeout: @timeout}) do |faraday|
|
224
|
+
faraday.request :authorization, :Bearer, api_key
|
191
225
|
faraday.request :json
|
192
226
|
faraday.response :json, content_type: /\bjson$/
|
193
227
|
faraday.adapter Faraday.default_adapter
|
194
228
|
end
|
195
229
|
end
|
196
|
-
# standard:enable Lint/DuplicateMethods
|
197
230
|
end
|
198
231
|
end
|
data/lib/cohere/version.rb
CHANGED
metadata
CHANGED
@@ -1,14 +1,14 @@
|
|
1
1
|
--- !ruby/object:Gem::Specification
|
2
2
|
name: cohere-ruby
|
3
3
|
version: !ruby/object:Gem::Version
|
4
|
-
version: 0.
|
4
|
+
version: 1.0.1
|
5
5
|
platform: ruby
|
6
6
|
authors:
|
7
7
|
- Andrei Bondarev
|
8
8
|
autorequire:
|
9
9
|
bindir: exe
|
10
10
|
cert_chain: []
|
11
|
-
date: 2024-
|
11
|
+
date: 2024-11-23 00:00:00.000000000 Z
|
12
12
|
dependencies:
|
13
13
|
- !ruby/object:Gem::Dependency
|
14
14
|
name: faraday
|
@@ -37,6 +37,7 @@ executables: []
|
|
37
37
|
extensions: []
|
38
38
|
extra_rdoc_files: []
|
39
39
|
files:
|
40
|
+
- ".env.example"
|
40
41
|
- ".rspec"
|
41
42
|
- CHANGELOG.md
|
42
43
|
- Gemfile
|
@@ -70,7 +71,7 @@ required_rubygems_version: !ruby/object:Gem::Requirement
|
|
70
71
|
- !ruby/object:Gem::Version
|
71
72
|
version: '0'
|
72
73
|
requirements: []
|
73
|
-
rubygems_version: 3.
|
74
|
+
rubygems_version: 3.5.11
|
74
75
|
signing_key:
|
75
76
|
specification_version: 4
|
76
77
|
summary: Cohere API client for Ruby.
|