cloudflare-ai 0.4.1 → 0.6.0

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
checksums.yaml CHANGED
@@ -1,7 +1,7 @@
  ---
  SHA256:
- metadata.gz: 4de1d8c20767aab8d40b96fa01a447e1ab83a1586cad6c4a3d7597331cabc5bd
- data.tar.gz: 7b912c4bf7bb23ec4f2befca92c966656ac4ab0527b03857adaf14b8e8b87a34
+ metadata.gz: 8f1d5e2776ba81382b1cb02355ac23348bee4a661d141c7c144136f1b9547eed
+ data.tar.gz: b651a6faed4d18a9b00cfed3f2217bc3cdc47ff9a2f9371cf50ab9c1f6bfab6f
  SHA512:
- metadata.gz: c87c773d129790a865a524a56c1ab545a51f3d8a54825b4e02d092282b733f5dbee9a59d1536fd8df096aac47c7d94ae4466e6d6f94aa90115ec4cebdd4d4b0e
- data.tar.gz: 45ead04646ae3756d05d0d4805bfbfcd509d421b08bfec63a53ff3736abd16887786d17741df542e4596500779881d35e9d8af6ffdb07c2bd2ffdfc632e17883
+ metadata.gz: 2fe5ec654b60e662113f79d2e66adf0483c17c0b2ffd2ea483c8f3b45c0d2ae747605b831e5cf6eed9ea8cd7ae614e6992199cf3fcad41870cb8ca8e269cb048
+ data.tar.gz: 3f4eac6cea6d8cdd00a00f71c3eca95ccfb61b2de8b2f45a4dfa46a7d81b45978eb5c92883569c8fd4398874005d0bad456096baa01cb7c82e22587b9700dc2a
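Each entry in checksums.yaml is a digest of one of the two archives packed inside the `.gem` file (`metadata.gz` and `data.tar.gz`), so a new release normally changes all four values. A minimal sketch of checking one entry with Ruby's standard `digest` library, assuming the archive member has already been extracted to disk; `digest_matches?` is a hypothetical helper, not part of RubyGems.

```ruby
require "digest"

# Hypothetical helper: compare a file's digest with the hex string recorded
# in checksums.yaml. Only the two algorithms listed in that file are supported.
def digest_matches?(path, expected_hex, algorithm: :SHA256)
  digest_class = { SHA256: Digest::SHA256, SHA512: Digest::SHA512 }.fetch(algorithm)
  # Digest::Class.file reads the file in fixed-size chunks, not all at once
  actual = digest_class.file(path).hexdigest
  actual == expected_hex.downcase
end
```

For example, `digest_matches?("data.tar.gz", "b651a6faed4d18a9b00cfed3f2217bc3cdc47ff9a2f9371cf50ab9c1f6bfab6f")` should return true only for the unmodified 0.6.0 archive.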
data/README.md CHANGED
@@ -21,9 +21,9 @@ It's still early days, and here are my immediate priorities:
  * [ ] Support for more AI model categories
  * [x] [Text Generation](https://developers.cloudflare.com/workers-ai/models/text-generation/)
  * [x] [Text Embeddings](https://developers.cloudflare.com/workers-ai/models/text-embeddings/)
- * [ ] [Text Classification](https://developers.cloudflare.com/workers-ai/models/text-classification/)
- * [ ] [Image Classification](https://developers.cloudflare.com/workers-ai/models/image-classification/)
- * [ ] [Translation](https://developers.cloudflare.com/workers-ai/models/translation/)
+ * [x] [Text Classification](https://developers.cloudflare.com/workers-ai/models/text-classification/)
+ * [x] [Translation](https://developers.cloudflare.com/workers-ai/models/translation/)
+ * [x] [Image Classification](https://developers.cloudflare.com/workers-ai/models/image-classification/)
  * [ ] [Text-to-Image](https://developers.cloudflare.com/workers-ai/models/text-to-image/)
  * [ ] [Automatic Speech Recognition](https://developers.cloudflare.com/workers-ai/models/speech-recognition/)
 
@@ -75,7 +75,8 @@ The full list of supported models is available here: [models.rb](lib/cloudflare/
  More information is available [in the cloudflare documentation](https://developers.cloudflare.com/workers-ai/models/).
  The default model used is the first enumerated model in the applicable set in [models.rb](lib/cloudflare/ai/models.rb).
 
- ### Text generation (chat / scoped prompt)
+ ### Text generation
+ #### (chat / scoped prompt)
  ```ruby
  messages = [
  Cloudflare::AI::Message.new(role: "system", content: "You are a big fan of Cloudflare and Ruby."),
@@ -86,6 +87,11 @@ messages = [
  result = client.chat(messages: messages)
  puts result.response # => "Yes, I love Cloudflare!"
  ```
+ #### (string prompt)
+ ```ruby
+ result = client.complete(prompt: "What is your name?", max_tokens: 512)
+ puts result.response # => "My name is Jonas."
+ ```
 
  #### Streaming responses
  Responses will be streamed back to the client using Server Side Events (SSE) if a block is passed to the `chat` or `complete` method.
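Streaming here means Server-Sent Events: the response arrives as `data:` lines separated by blank lines, and the README's example output ends with a `[DONE]` sentinel. The client itself relies on the `event_stream_parser` gem (required at the top of `client.rb`). The sketch below is not that gem's API; it is a minimal stdlib illustration, assuming each event carries a JSON payload with a `response` field and that the stream ends with `data: [DONE]`. Both method names are hypothetical.

```ruby
require "json"

# Hypothetical: split a raw SSE body into the payload of each event.
# Events are separated by a blank line; multiple "data:" lines in one event
# are joined with "\n". Leading whitespace after "data:" is stripped
# (a simplification of the spec, which strips a single space).
def sse_data_fields(raw)
  raw.split(/\r?\n\r?\n/).filter_map do |event|
    data = event.lines.map(&:chomp)
                .select { |line| line.start_with?("data:") }
                .map { |line| line.delete_prefix("data:").lstrip }
    data.join("\n") unless data.empty?
  end
end

# Hypothetical: concatenate the assumed "response" fields until [DONE].
def collect_stream(raw)
  sse_data_fields(raw)
    .take_while { |data| data != "[DONE]" }
    .map { |data| JSON.parse(data).fetch("response", "") }
    .join
end
```

A real client would feed bytes to the parser incrementally as they arrive; this sketch works on a complete body only to keep the parsing rules visible.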
@@ -99,6 +105,9 @@ result = client.complete(prompt: "Hi!") { |data| puts data}
  # [DONE]
 
  ```
+ #### Token limits
+ Invocations of the `prompt` and `chat` can take an optional `max_tokens` argument that defaults to 256.
+
  #### Result object
  All invocations of the `prompt` and `chat` methods return a `Cloudflare::AI::Results::TextGeneration` object. This object's serializable JSON output is
  based on the raw response from the Cloudflare API.
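The new Token limits paragraph says the generation calls take an optional `max_tokens` that defaults to 256, and the string-prompt example raises it to 512. In `client.rb` the method is `complete(prompt:, model_name: default_text_generation_model_name, max_tokens: default_max_tokens, &block)`; the README's references to a `prompt` method appear to mean `complete`. A minimal sketch of the keyword-default pattern feeding a JSON body; `build_completion_payload`, the payload keys, and the positive-integer guard are assumptions for illustration, not the gem's internals.

```ruby
require "json"

DEFAULT_MAX_TOKENS = 256 # default stated in the README

# Hypothetical payload builder mirroring `max_tokens: default_max_tokens`.
def build_completion_payload(prompt:, max_tokens: DEFAULT_MAX_TOKENS)
  # Illustrative guard only; the diff does not show the gem validating this.
  unless max_tokens.is_a?(Integer) && max_tokens.positive?
    raise ArgumentError, "max_tokens must be a positive Integer"
  end

  { prompt: prompt, max_tokens: max_tokens }.to_json
end
```

With these assumptions, `build_completion_payload(prompt: "Hi")` returns `{"prompt":"Hi","max_tokens":256}`.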
@@ -119,7 +128,6 @@ puts result.failure? # => true
  puts result.to_json # => {"result":null,"success":false,"errors":[{"code":7009,"message":"Upstream service unavailable"}],"messages":[]}
  ```
 
-
  ### Text embedding
  ```ruby
  result = client.embed(text: "Hello")
@@ -144,6 +152,29 @@ p result.result # => [{"label"=>"NEGATIVE", "score"=>0.6647962927818298}, {"labe
  #### Result object
  All invocations of the `classify` methods return a `Cloudflare::AI::Results::TextClassification`.
 
+ ### Image classification
+ The image classification endpoint accepts either a path to a file or a file stream.
+
+ ```ruby
+ result = client.classify(image: "/path/to/cat.jpg")
+ p result.result # => {"result":[{"label":"TABBY","score":0.6159140467643738},{"label":"TIGER CAT","score":0.12016300112009048},{"label":"EGYPTIAN CAT","score":0.07523812353610992},{"label":"DOORMAT","score":0.018854796886444092},{"label":"ASHCAN","score":0.01314085815101862}],"success":true,"errors":[],"messages":[]}
+
+ result = client.classify(image: File.open("/path/to/cat.jpg"))
+ p result.result # => {"result":[{"label":"TABBY","score":0.6159140467643738},{"label":"TIGER CAT","score":0.12016300112009048},{"label":"EGYPTIAN CAT","score":0.07523812353610992},{"label":"DOORMAT","score":0.018854796886444092},{"label":"ASHCAN","score":0.01314085815101862}],"success":true,"errors":[],"messages":[]}
+ ```
+
+ #### Result object
+ All invocations of the `classify` methods return a `Cloudflare::AI::Results::TextClassification`.
+
+ ### Translation
+ ```ruby
+ result = client.translate(text: "Hello Jello", source_lang: "en", target_lang: "fr")
+ p result.translated_text # => Hola Jello
+ ```
+
+ #### Result object
+ All invocations of the `translate` methods return a `Cloudflare::AI::Results::Translate`.
+
  # Logging
 
  This gem uses standard logging mechanisms and defaults to `:warn` level. Most messages are at info level, but we will add debug or warn statements as needed.
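The image classification section says `classify(image:)` takes either a path or an open file; in `client.rb` a String is converted with `File.open(image)`, and that handle is not closed in the code shown. A minimal sketch of the same path-or-IO normalization that closes only the files it opens itself; `with_image_io` is a hypothetical helper. Note also that the README's new result-object line for images names `Cloudflare::AI::Results::TextClassification`, whereas `classify` in this release wraps image responses in `Cloudflare::AI::Results::ImageClassification`.

```ruby
# Hypothetical helper: yields a readable IO for either a path (String) or an
# IO-like object. A file opened from a path is closed when the block returns;
# an IO supplied by the caller is left open for the caller to manage.
def with_image_io(image)
  if image.is_a?(String)
    File.open(image, "rb") { |file| yield file }
  elsif image.respond_to?(:read)
    yield image
  else
    raise ArgumentError, "image must be a path or an IO, got #{image.class}"
  end
end
```

Opening in `"rb"` keeps the image bytes untouched by newline or encoding conversion.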
data/lib/cloudflare/ai/client.rb CHANGED
@@ -2,6 +2,7 @@ require "event_stream_parser"
  require "faraday"
 
  class Cloudflare::AI::Client
+ include Cloudflare::AI::Clients::ImageHelpers
  include Cloudflare::AI::Clients::TextGenerationHelpers
 
  attr_reader :url, :account_id, :api_token
@@ -18,11 +19,19 @@ class Cloudflare::AI::Client
  post_streamable_request(url, payload, &block)
  end
 
- def classify(text:, model_name: Cloudflare::AI::Models.text_classification.first)
+ def classify(text: nil, image: nil, model_name: nil)
+ raise ArgumentError, "Must provide either text or image (and not both)" if [text, image].compact.size != 1
+
+ model_name ||= text ? Cloudflare::AI::Models.text_classification.first : Cloudflare::AI::Models.image_classification.first
  url = service_url_for(account_id: account_id, model_name: model_name)
- payload = {text: text}.to_json
 
- Cloudflare::AI::Results::TextClassification.new(connection.post(url, payload).body)
+ if text
+ payload = {text: text}.to_json
+ Cloudflare::AI::Results::TextClassification.new(connection.post(url, payload).body)
+ else
+ image = File.open(image) if image.is_a?(String)
+ Cloudflare::AI::Results::ImageClassification.new(post_request_with_binary_file(url, image).body)
+ end
  end
 
  def complete(prompt:, model_name: default_text_generation_model_name, max_tokens: default_max_tokens, &block)
@@ -39,6 +48,12 @@ class Cloudflare::AI::Client
  Cloudflare::AI::Results::TextEmbedding.new(connection.post(url, payload).body)
  end
 
+ def translate(text:, target_lang:, source_lang: "en", model_name: Cloudflare::AI::Models.translation.first)
+ url = service_url_for(account_id: account_id, model_name: model_name)
+ payload = {text: text, target_lang: target_lang, source_lang: source_lang}.to_json
+ Cloudflare::AI::Results::Translation.new(connection.post(url, payload).body)
+ end
+
  private
 
  def connection
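The reworked `classify` accepts exactly one of `text:` or `image:`; the guard `[text, image].compact.size != 1` rejects both the no-input and the two-input case. When `model_name` is nil it falls back to the first model in the matching list from `models.rb`, which this release also corrects (the image list previously held `@cf/huggingface/distilbert-sst-2-int8`, a text model). A stdlib-only sketch of that dispatch; `resolve_classification` and the hard-coded `MODELS` table are hypothetical stand-ins for `Cloudflare::AI::Models`, and the text-classification entry is an assumption because this diff does not show that list.

```ruby
# Stand-in model lists. The image entry matches models.rb in 0.6.0;
# the text entry is assumed for illustration.
MODELS = {
  text_classification: %w[@cf/huggingface/distilbert-sst-2-int8],
  image_classification: %w[@cf/microsoft/resnet-50]
}.freeze

# Mirrors the argument handling of the new `classify`: exactly one input,
# and a default model chosen by input type when model_name is nil.
def resolve_classification(text: nil, image: nil, model_name: nil)
  if [text, image].compact.size != 1
    raise ArgumentError, "Must provide either text or image (and not both)"
  end

  kind = text ? :text : :image
  model_name ||= MODELS.fetch(:"#{kind}_classification").first
  [kind, model_name]
end
```

An explicit `model_name:` always wins, so callers can still target a different model of either type.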
data/lib/cloudflare/ai/clients/image_helpers.rb ADDED
@@ -0,0 +1,20 @@
+ require "faraday/multipart"
+
+ module Cloudflare
+ module AI
+ module Clients
+ module ImageHelpers
+ private
+
+ def post_request_with_binary_file(url, file)
+ connection.post do |req|
+ req.url url
+ req.headers["Transfer-Encoding"] = "chunked"
+ req.headers["Content-Type"] = "multipart/form-data"
+ req.body = ::Faraday::UploadIO.new(file, "octet/stream")
+ end
+ end
+ end
+ end
+ end
+ end
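`post_request_with_binary_file` streams the image with `Transfer-Encoding: chunked`, labels the request `multipart/form-data`, and tags the upload with the MIME type `octet/stream` (the registered type is `application/octet-stream`). For comparison, a minimal sketch that builds, without sending, an equivalent streamed POST using only `Net::HTTP`, assuming the endpoint accepts the raw image bytes as the body. The URL layout `https://api.cloudflare.com/client/v4/accounts/<account_id>/ai/run/<model>` is an assumption here, since `service_url_for` is not part of this diff, and `build_image_request` is hypothetical.

```ruby
require "net/http"
require "uri"

# Hypothetical: build (but do not send) a streamed binary POST for an image model.
def build_image_request(account_id:, api_token:, model_name:, io:)
  # Assumed URL layout; the gem computes its URL in service_url_for.
  uri = URI("https://api.cloudflare.com/client/v4/accounts/#{account_id}/ai/run/#{model_name}")
  request = Net::HTTP::Post.new(uri)
  request["Authorization"] = "Bearer #{api_token}"
  request["Content-Type"] = "application/octet-stream"
  # Without Content-Length, Net::HTTP requires chunked encoding for a body stream.
  request["Transfer-Encoding"] = "chunked"
  request.body_stream = io # read piece by piece when the request is sent
  request
end
```

Sending it would then be `Net::HTTP.start(uri.host, uri.port, use_ssl: true) { |http| http.request(request) }`, which this sketch deliberately leaves out.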
data/lib/cloudflare/ai/models.rb CHANGED
@@ -17,7 +17,7 @@ class Cloudflare::AI::Models
  end
 
  def image_classification
- %w[@cf/huggingface/distilbert-sst-2-int8]
+ %w[@cf/microsoft/resnet-50]
  end
 
  def text_to_image
data/lib/cloudflare/ai/results/image_classification.rb ADDED
@@ -0,0 +1,3 @@
+ class Cloudflare::AI::Results::ImageClassification < Cloudflare::AI::Result
+ # Empty seam kept for consistency with other result objects that have more complexity.
+ end
data/lib/cloudflare/ai/results/translation.rb ADDED
@@ -0,0 +1,5 @@
+ class Cloudflare::AI::Results::Translation < Cloudflare::AI::Result
+ def translated_text
+ result&.dig(:translated_text) # nil if no shape
+ end
+ end
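`Translation#translated_text` uses safe navigation plus `dig`, so a failed call whose `result` is null yields nil instead of raising. A stdlib sketch of that access pattern on a parsed response body, assuming a body shaped like `{"result":{"translated_text":...},"success":true,...}` and symbolized keys (the diff does not show how `Cloudflare::AI::Result` stores `result`); `TranslationResult` is a hypothetical stand-in. Two README details do not match this code: the result class is named `Cloudflare::AI::Results::Translate` there, and the sample output `Hola Jello` is Spanish although the call passes `target_lang: "fr"`.

```ruby
require "json"

# Hypothetical stand-in for the gem's result wrapper.
class TranslationResult
  def initialize(json)
    @body = JSON.parse(json, symbolize_names: true)
  end

  def result
    @body[:result]
  end

  def success?
    @body[:success] == true
  end

  # Same access pattern as Results::Translation#translated_text:
  # nil when `result` is null rather than a NoMethodError.
  def translated_text
    result&.dig(:translated_text)
  end
end
```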
data/lib/cloudflare/ai/version.rb CHANGED
@@ -2,6 +2,6 @@
 
  module Cloudflare
  module AI
- VERSION = "0.4.1"
+ VERSION = "0.6.0"
  end
  end
metadata CHANGED
@@ -1,14 +1,14 @@
  --- !ruby/object:Gem::Specification
  name: cloudflare-ai
  version: !ruby/object:Gem::Version
- version: 0.4.1
+ version: 0.6.0
  platform: ruby
  authors:
  - Ajay Krishnan
  autorequire:
  bindir: exe
  cert_chain: []
- date: 2024-01-22 00:00:00.000000000 Z
+ date: 2024-01-24 00:00:00.000000000 Z
  dependencies:
  - !ruby/object:Gem::Dependency
  name: activemodel
@@ -66,6 +66,20 @@ dependencies:
  - - "~>"
  - !ruby/object:Gem::Version
  version: '2.0'
+ - !ruby/object:Gem::Dependency
+ name: faraday-multipart
+ requirement: !ruby/object:Gem::Requirement
+ requirements:
+ - - "~>"
+ - !ruby/object:Gem::Version
+ version: '1.0'
+ type: :runtime
+ prerelease: false
+ version_requirements: !ruby/object:Gem::Requirement
+ requirements:
+ - - "~>"
+ - !ruby/object:Gem::Version
+ version: '1.0'
  - !ruby/object:Gem::Dependency
  name: zeitwerk
  requirement: !ruby/object:Gem::Requirement
@@ -93,14 +107,17 @@ files:
  - README.md
  - lib/cloudflare/ai.rb
  - lib/cloudflare/ai/client.rb
+ - lib/cloudflare/ai/clients/image_helpers.rb
  - lib/cloudflare/ai/clients/text_generation_helpers.rb
  - lib/cloudflare/ai/contextual_logger.rb
  - lib/cloudflare/ai/message.rb
  - lib/cloudflare/ai/models.rb
  - lib/cloudflare/ai/result.rb
+ - lib/cloudflare/ai/results/image_classification.rb
  - lib/cloudflare/ai/results/text_classification.rb
  - lib/cloudflare/ai/results/text_embedding.rb
  - lib/cloudflare/ai/results/text_generation.rb
+ - lib/cloudflare/ai/results/translation.rb
  - lib/cloudflare/ai/version.rb
  homepage: https://rubygems.org/gems/cloudflare-ai
  licenses:
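The new runtime dependency is declared as `faraday-multipart` `~> 1.0`, RubyGems' pessimistic constraint, which accepts any 1.x release from 1.0 up but excludes 2.0. The same operator matters for applications depending on this gem: because the release goes from 0.4.1 to 0.6.0, a pin written as `~> 0.4.1` will not pick it up, while `~> 0.4` will. A quick check using `Gem::Requirement` and `Gem::Version`, which ship with RubyGems; `allowed?` is a hypothetical helper and the sample versions are illustrative.

```ruby
require "rubygems"

# Returns true when `version` satisfies the constraint string `requirement`.
def allowed?(requirement, version)
  Gem::Requirement.new(requirement).satisfied_by?(Gem::Version.new(version))
end

allowed?("~> 1.0", "1.0.4")   # => true  (any 1.x at or above 1.0)
allowed?("~> 1.0", "2.0.0")   # => false (next major excluded)
allowed?("~> 0.4.1", "0.6.0") # => false (only 0.4.x from 0.4.1)
allowed?("~> 0.4", "0.6.0")   # => true  (any 0.x from 0.4 up)
```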