stability_sdk 0.1.0
Sign up to get free protection for your applications and to get access to all the features.
- checksums.yaml +7 -0
- data/.gitignore +8 -0
- data/.travis.yml +6 -0
- data/Gemfile +7 -0
- data/Gemfile.lock +34 -0
- data/README.md +89 -0
- data/Rakefile +10 -0
- data/bin/console +14 -0
- data/bin/setup +8 -0
- data/exe/stability-client +43 -0
- data/lib/generation_pb.rb +182 -0
- data/lib/generation_services_pb.rb +22 -0
- data/lib/stability_sdk/cli.rb +34 -0
- data/lib/stability_sdk/client.rb +75 -0
- data/lib/stability_sdk/version.rb +3 -0
- data/lib/stability_sdk.rb +8 -0
- data/proto/generation.proto +179 -0
- data/stability_sdk.gemspec +30 -0
- metadata +106 -0
checksums.yaml
ADDED
@@ -0,0 +1,7 @@
|
|
1
|
+
---
|
2
|
+
SHA256:
|
3
|
+
metadata.gz: dac029de10ccef05e11d7fee6a581f14fec50252e1cf7790ce2e849e5061489c
|
4
|
+
data.tar.gz: 66f8803ff804dae6971cdf3b8473225012255e87d67f00484583d65ecd8afb6a
|
5
|
+
SHA512:
|
6
|
+
metadata.gz: ad988347184eafe784ab654b0022f1742cf56e51c9e2d4d88b3507f97162c6aaa7a165ad8ee5a6d3c34c1a9427fc860a54679192235c154bfea3008765dd8f89
|
7
|
+
data.tar.gz: d4522e6b2644f0ef309dbfb862a0667ed80c77e93c30bc575764997aa511d73789b3f270641d44c83cc2fdc017b7a0d5562e7de79ed7bebf86ff894019ad2642
|
data/.gitignore
ADDED
data/.travis.yml
ADDED
data/Gemfile
ADDED
data/Gemfile.lock
ADDED
@@ -0,0 +1,34 @@
|
|
1
|
+
PATH
|
2
|
+
remote: .
|
3
|
+
specs:
|
4
|
+
stability_sdk (0.1.0)
|
5
|
+
grpc
|
6
|
+
mime-types
|
7
|
+
|
8
|
+
GEM
|
9
|
+
remote: https://rubygems.org/
|
10
|
+
specs:
|
11
|
+
google-protobuf (3.21.5)
|
12
|
+
googleapis-common-protos-types (1.4.0)
|
13
|
+
google-protobuf (~> 3.14)
|
14
|
+
grpc (1.48.0)
|
15
|
+
google-protobuf (~> 3.19)
|
16
|
+
googleapis-common-protos-types (~> 1.0)
|
17
|
+
grpc-tools (1.48.0)
|
18
|
+
mime-types (3.4.1)
|
19
|
+
mime-types-data (~> 3.2015)
|
20
|
+
mime-types-data (3.2022.0105)
|
21
|
+
minitest (5.16.3)
|
22
|
+
rake (12.3.3)
|
23
|
+
|
24
|
+
PLATFORMS
|
25
|
+
arm64-darwin-20
|
26
|
+
|
27
|
+
DEPENDENCIES
|
28
|
+
grpc-tools
|
29
|
+
minitest (~> 5.0)
|
30
|
+
rake (~> 12.0)
|
31
|
+
stability_sdk!
|
32
|
+
|
33
|
+
BUNDLED WITH
|
34
|
+
2.2.2
|
data/README.md
ADDED
@@ -0,0 +1,89 @@
|
|
1
|
+
# StabilitySDK - Ruby client for stability.ai APIs, such as Stable Diffusion
|
2
|
+
|
3
|
+
A Ruby client for [stability.ai](https://stability.ai/) APIs, e.g., [Stable Diffusion](https://stability.ai/blog/stable-diffusion-public-release), based on the official SDK at https://github.com/Stability-AI/stability-sdk.
|
4
|
+
|
5
|
+
## Installation
|
6
|
+
|
7
|
+
Add this line to your application's Gemfile:
|
8
|
+
|
9
|
+
```ruby
|
10
|
+
gem 'stability_sdk'
|
11
|
+
```
|
12
|
+
|
13
|
+
And then execute:
|
14
|
+
|
15
|
+
$ bundle install
|
16
|
+
|
17
|
+
Or install it yourself as:
|
18
|
+
|
19
|
+
$ gem install stability_sdk
|
20
|
+
|
21
|
+
## Usage
|
22
|
+
|
23
|
+
First, create a [DreamStudio](https://beta.dreamstudio.ai/home) account and get an API key for it:
|
24
|
+
|
25
|
+
- Go to [DreamStudio](https://beta.dreamstudio.ai/dream) and create an account if you don't have one yet
|
26
|
+
- Go to the [membership page](https://beta.dreamstudio.ai/membership)
|
27
|
+
- Your API key is shown on the `API Key` tab
|
28
|
+
|
29
|
+
### Command line usage
|
30
|
+
|
31
|
+
```sh
|
32
|
+
STABILITY_SDK_API_KEY=YOUR_API_KEY stability-client 'A night in winter, oil-on-canvas landscape painting, by Vincent van Gogh'
|
33
|
+
```
|
34
|
+
|
35
|
+
This command saves an image like this:
|
36
|
+
|
37
|
+
![3749380973_A_night_in_winter__oil_on_canvas_landscape_painting__by_Vincent_van_Gogh](https://user-images.githubusercontent.com/25668/188884116-0b03494b-0b34-49de-bbbc-89fbc2f6029d.png)
|
38
|
+
|
39
|
+
|
40
|
+
```sh
|
41
|
+
Usage: stability-client [options] YOUR_PROMPT_TEXT
|
42
|
+
|
43
|
+
Options:
|
44
|
+
--api_key=VAL api key of DreamStudio account. You can also specify by a STABILITY_SDK_API_KEY environment variable
|
45
|
+
-H, --height=VAL height of image in pixel. default 512
|
46
|
+
-W, --width=VAL width of image in pixel. default 512
|
47
|
+
-C, --cfg_scale=VAL CFG scale factor. default 7.0
|
48
|
+
-A, --sampler=VAL ddim, plms, k_euler, k_euler_ancestral, k_heun, k_dpm_2, k_dpm_2_ancestral, k_lms. default k_lms
|
49
|
+
-s, --steps=VAL number of steps. default 50
|
50
|
+
-S, --seed=VAL random seed to use in integer
|
51
|
+
-p, --prefix=VAL output prefixes for artifacts. default `generation`
|
52
|
+
--no-store do not write out artifacts
|
53
|
+
-n, --num_samples=VAL number of samples to generate
|
54
|
+
-e, --engine=VAL engine to use for inference. default `stable-diffusion-v1`
|
55
|
+
-v, --verbose
|
56
|
+
```
|
57
|
+
|
58
|
+
### SDK usage
|
59
|
+
|
60
|
+
This sample code saves a generated image as `result.png`.
|
61
|
+
|
62
|
+
```ruby
|
63
|
+
require "stability_sdk"
|
64
|
+
|
65
|
+
client = StabilitySDK::Client.new(api_key: "YOUR_API_KEY")
|
66
|
+
|
67
|
+
prompt = "your prompt here"
|
68
|
+
options = {}
|
69
|
+
|
70
|
+
client.generate(prompt, options) do |answer|
|
71
|
+
answer.artifacts.each do |artifact|
|
72
|
+
if artifact.type == :ARTIFACT_IMAGE
|
73
|
+
File.open("result.png", "wb") do |f|
|
74
|
+
f.write(artifact.binary)
|
75
|
+
end
|
76
|
+
end
|
77
|
+
end
|
78
|
+
end
|
79
|
+
```
|
80
|
+
|
81
|
+
## Development
|
82
|
+
|
83
|
+
After checking out the repo, run `bin/setup` to install dependencies. Then, run `rake test` to run the tests. You can also run `bin/console` for an interactive prompt that will allow you to experiment.
|
84
|
+
|
85
|
+
To install this gem onto your local machine, run `bundle exec rake install`. To release a new version, update the version number in `version.rb`, and then run `bundle exec rake release`, which will create a git tag for the version, push git commits and tags, and push the `.gem` file to [rubygems.org](https://rubygems.org).
|
86
|
+
|
87
|
+
## Contributing
|
88
|
+
|
89
|
+
Bug reports and pull requests are welcome on GitHub at https://github.com/cou929/stability-sdk-ruby.
|
data/Rakefile
ADDED
data/bin/console
ADDED
@@ -0,0 +1,14 @@
|
|
1
|
+
#!/usr/bin/env ruby

# Interactive console with the gem preloaded, for experimenting with the SDK.
require "bundler/setup"
require "stability_sdk"

# Prefer Pry? Add it to the Gemfile and replace the two IRB lines below with:
#   require "pry"
#   Pry.start
require "irb"
IRB.start(__FILE__)
|
data/bin/setup
ADDED
@@ -0,0 +1,43 @@
|
|
1
|
+
#!/usr/bin/env ruby
|
2
|
+
|
3
|
+
require "optparse"
|
4
|
+
require "mime/types"
|
5
|
+
require "logger"
|
6
|
+
require "stability_sdk"
|
7
|
+
|
8
|
+
logger = Logger.new(STDOUT)
logger.level = Logger::WARN

opt = OptionParser.new
# OptionParser#version falls back to a top-level ::Version constant, so this
# is what `stability-client --version` reports.
Version = StabilitySDK::VERSION

options = {}

opt.banner = "Usage: stability-client [options] YOUR_PROMPT_TEXT"
opt.separator ""
opt.separator "Options:"
# Numeric flags are converted and validated by OptionParser, so e.g. `-W abc`
# is rejected with a usage error instead of silently becoming 0.
# DecimalInteger (rather than Integer) keeps "0123" from being read as octal.
opt.on("--api_key=VAL", "api key of DreamStudio account. You can also specify by a STABILITY_SDK_API_KEY environment variable") { |v| options[:api_key] = v }
opt.on("-H", "--height=VAL", OptionParser::DecimalInteger, "height of image in pixel. default 512") { |v| options[:height] = v }
opt.on("-W", "--width=VAL", OptionParser::DecimalInteger, "width of image in pixel. default 512") { |v| options[:width] = v }
opt.on("-C", "--cfg_scale=VAL", Float, "CFG scale factor. default 7.0") { |v| options[:cfg_scale] = v }
opt.on("-A", "--sampler=VAL", "ddim, plms, k_euler, k_euler_ancestral, k_heun, k_dpm_2, k_dpm_2_ancestral, k_lms. default k_lms") { |v| options[:sampler] = v }
opt.on("-s", "--steps=VAL", OptionParser::DecimalInteger, "number of steps. default 50") { |v| options[:steps] = v }
opt.on("-S", "--seed=VAL", OptionParser::DecimalInteger, "random seed to use in integer") { |v| options[:seed] = v }
opt.on("-p", "--prefix=VAL", "output prefixes for artifacts. default `generation`") { |v| options[:prefix] = v }
# Record true explicitly instead of the yielded value, so the flag's effect does
# not depend on how OptionParser treats the "no-" prefix of a switch name.
opt.on("--no-store", "do not write out artifacts") { options[:no_store] = true }
opt.on("-n", "--num_samples=VAL", OptionParser::DecimalInteger, "number of samples to generate") { |v| options[:num_samples] = v }
opt.on("-e", "--engine=VAL", "engine to use for inference. default `stable-diffusion-v1`") { |v| options[:engine_id] = v }
opt.on("-v", "--verbose") { logger.level = Logger::DEBUG }
opt.parse!(ARGV)

# Everything left after option parsing is the prompt text.
prompt = ARGV.join(" ")
raise StabilitySDK::InsufficientParameter, "prompt is required" if prompt.strip.empty?

# An explicit --api_key wins; the environment variable is only a fallback.
options[:api_key] = ENV["STABILITY_SDK_API_KEY"] if options[:api_key].to_s.empty?
raise StabilitySDK::InsufficientParameter, "api key is required" if options[:api_key].to_s.empty?

client = StabilitySDK::Client.new(api_key: options[:api_key])

client.generate(prompt, options) do |answer|
  StabilitySDK::CLI.save_answer(answer, options, logger)
end
|
@@ -0,0 +1,182 @@
|
|
1
|
+
# Generated by the protocol buffer compiler. DO NOT EDIT!
|
2
|
+
# source: generation.proto
|
3
|
+
|
4
|
+
require 'google/protobuf'
|
5
|
+
|
6
|
+
# Registers every message and enum of generation.proto (package "gooseai")
# in the process-wide generated descriptor pool. Field names, numbers and
# types must stay in sync with the .proto file, so regenerate this file from
# it (presumably with grpc-tools, a development dependency) rather than
# editing it by hand.
Google::Protobuf::DescriptorPool.generated_pool.build do
|
7
|
+
add_file("generation.proto", :syntax => :proto3) do
|
8
|
+
add_message "gooseai.Token" do
|
9
|
+
proto3_optional :text, :string, 1
|
10
|
+
optional :id, :uint32, 2
|
11
|
+
end
|
12
|
+
add_message "gooseai.Tokens" do
|
13
|
+
repeated :tokens, :message, 1, "gooseai.Token"
|
14
|
+
proto3_optional :tokenizer_id, :string, 2
|
15
|
+
end
|
16
|
+
add_message "gooseai.Artifact" do
|
17
|
+
optional :id, :uint64, 1
|
18
|
+
optional :type, :enum, 2, "gooseai.ArtifactType"
|
19
|
+
optional :mime, :string, 3
|
20
|
+
proto3_optional :magic, :string, 4
|
21
|
+
optional :index, :uint32, 8
|
22
|
+
optional :finish_reason, :enum, 9, "gooseai.FinishReason"
|
23
|
+
optional :seed, :uint32, 10
|
24
|
+
oneof :data do
|
25
|
+
optional :binary, :bytes, 5
|
26
|
+
optional :text, :string, 6
|
27
|
+
optional :tokens, :message, 7, "gooseai.Tokens"
|
28
|
+
optional :classifier, :message, 11, "gooseai.ClassifierParameters"
|
29
|
+
end
|
30
|
+
end
|
31
|
+
add_message "gooseai.PromptParameters" do
|
32
|
+
proto3_optional :init, :bool, 1
|
33
|
+
proto3_optional :weight, :float, 2
|
34
|
+
end
|
35
|
+
add_message "gooseai.Prompt" do
|
36
|
+
proto3_optional :parameters, :message, 1, "gooseai.PromptParameters"
|
37
|
+
oneof :prompt do
|
38
|
+
optional :text, :string, 2
|
39
|
+
optional :tokens, :message, 3, "gooseai.Tokens"
|
40
|
+
optional :artifact, :message, 4, "gooseai.Artifact"
|
41
|
+
end
|
42
|
+
end
|
43
|
+
add_message "gooseai.AnswerMeta" do
|
44
|
+
proto3_optional :gpu_id, :string, 1
|
45
|
+
proto3_optional :cpu_id, :string, 2
|
46
|
+
proto3_optional :node_id, :string, 3
|
47
|
+
proto3_optional :engine_id, :string, 4
|
48
|
+
end
|
49
|
+
add_message "gooseai.Answer" do
|
50
|
+
optional :answer_id, :string, 1
|
51
|
+
optional :request_id, :string, 2
|
52
|
+
optional :received, :uint64, 3
|
53
|
+
optional :created, :uint64, 4
|
54
|
+
proto3_optional :meta, :message, 6, "gooseai.AnswerMeta"
|
55
|
+
repeated :artifacts, :message, 7, "gooseai.Artifact"
|
56
|
+
end
|
57
|
+
add_message "gooseai.SamplerParameters" do
|
58
|
+
proto3_optional :eta, :float, 1
|
59
|
+
proto3_optional :sampling_steps, :uint64, 2
|
60
|
+
proto3_optional :latent_channels, :uint64, 3
|
61
|
+
proto3_optional :downsampling_factor, :uint64, 4
|
62
|
+
proto3_optional :cfg_scale, :float, 5
|
63
|
+
end
|
64
|
+
add_message "gooseai.ConditionerParameters" do
|
65
|
+
proto3_optional :vector_adjust_prior, :string, 1
|
66
|
+
end
|
67
|
+
add_message "gooseai.StepParameter" do
|
68
|
+
optional :scaled_step, :float, 1
|
69
|
+
proto3_optional :sampler, :message, 2, "gooseai.SamplerParameters"
|
70
|
+
end
|
71
|
+
add_message "gooseai.TransformType" do
|
72
|
+
oneof :type do
|
73
|
+
optional :diffusion, :enum, 1, "gooseai.DiffusionSampler"
|
74
|
+
optional :upscaler, :enum, 2, "gooseai.Upscaler"
|
75
|
+
end
|
76
|
+
end
|
77
|
+
add_message "gooseai.ImageParameters" do
|
78
|
+
proto3_optional :height, :uint64, 1
|
79
|
+
proto3_optional :width, :uint64, 2
|
80
|
+
repeated :seed, :uint32, 3
|
81
|
+
proto3_optional :samples, :uint64, 4
|
82
|
+
proto3_optional :steps, :uint64, 5
|
83
|
+
proto3_optional :transform, :message, 6, "gooseai.TransformType"
|
84
|
+
repeated :parameters, :message, 7, "gooseai.StepParameter"
|
85
|
+
end
|
86
|
+
add_message "gooseai.ClassifierConcept" do
|
87
|
+
optional :concept, :string, 1
|
88
|
+
proto3_optional :threshold, :float, 2
|
89
|
+
end
|
90
|
+
add_message "gooseai.ClassifierCategory" do
|
91
|
+
optional :name, :string, 1
|
92
|
+
repeated :concepts, :message, 2, "gooseai.ClassifierConcept"
|
93
|
+
proto3_optional :adjustment, :float, 3
|
94
|
+
proto3_optional :action, :enum, 4, "gooseai.Action"
|
95
|
+
proto3_optional :classifier_mode, :enum, 5, "gooseai.ClassifierMode"
|
96
|
+
end
|
97
|
+
add_message "gooseai.ClassifierParameters" do
|
98
|
+
repeated :categories, :message, 1, "gooseai.ClassifierCategory"
|
99
|
+
repeated :exceeds, :message, 2, "gooseai.ClassifierCategory"
|
100
|
+
proto3_optional :realized_action, :enum, 3, "gooseai.Action"
|
101
|
+
end
|
102
|
+
add_message "gooseai.Request" do
|
103
|
+
optional :engine_id, :string, 1
|
104
|
+
optional :request_id, :string, 2
|
105
|
+
optional :requested_type, :enum, 3, "gooseai.ArtifactType"
|
106
|
+
repeated :prompt, :message, 4, "gooseai.Prompt"
|
107
|
+
proto3_optional :conditioner, :message, 6, "gooseai.ConditionerParameters"
|
108
|
+
proto3_optional :classifier, :message, 7, "gooseai.ClassifierParameters"
|
109
|
+
oneof :params do
|
110
|
+
optional :image, :message, 5, "gooseai.ImageParameters"
|
111
|
+
end
|
112
|
+
end
|
113
|
+
add_enum "gooseai.FinishReason" do
|
114
|
+
value :NULL, 0
|
115
|
+
value :LENGTH, 1
|
116
|
+
value :STOP, 2
|
117
|
+
value :ERROR, 3
|
118
|
+
value :FILTER, 4
|
119
|
+
end
|
120
|
+
add_enum "gooseai.ArtifactType" do
|
121
|
+
value :ARTIFACT_NONE, 0
|
122
|
+
value :ARTIFACT_IMAGE, 1
|
123
|
+
value :ARTIFACT_VIDEO, 2
|
124
|
+
value :ARTIFACT_TEXT, 3
|
125
|
+
value :ARTIFACT_TOKENS, 4
|
126
|
+
value :ARTIFACT_EMBEDDING, 5
|
127
|
+
value :ARTIFACT_CLASSIFICATIONS, 6
|
128
|
+
end
|
129
|
+
add_enum "gooseai.DiffusionSampler" do
|
130
|
+
value :SAMPLER_DDIM, 0
|
131
|
+
value :SAMPLER_DDPM, 1
|
132
|
+
value :SAMPLER_K_EULER, 2
|
133
|
+
value :SAMPLER_K_EULER_ANCESTRAL, 3
|
134
|
+
value :SAMPLER_K_HEUN, 4
|
135
|
+
value :SAMPLER_K_DPM_2, 5
|
136
|
+
value :SAMPLER_K_DPM_2_ANCESTRAL, 6
|
137
|
+
value :SAMPLER_K_LMS, 7
|
138
|
+
end
|
139
|
+
add_enum "gooseai.Upscaler" do
|
140
|
+
value :UPSCALER_RGB, 0
|
141
|
+
value :UPSCALER_GFPGAN, 1
|
142
|
+
value :UPSCALER_ESRGAN, 2
|
143
|
+
end
|
144
|
+
add_enum "gooseai.Action" do
|
145
|
+
value :ACTION_PASSTHROUGH, 0
|
146
|
+
value :ACTION_REGENERATE_DUPLICATE, 1
|
147
|
+
value :ACTION_REGENERATE, 2
|
148
|
+
value :ACTION_OBFUSCATE_DUPLICATE, 3
|
149
|
+
value :ACTION_OBFUSCATE, 4
|
150
|
+
value :ACTION_DISCARD, 5
|
151
|
+
end
|
152
|
+
add_enum "gooseai.ClassifierMode" do
|
153
|
+
value :CLSFR_MODE_ZEROSHOT, 0
|
154
|
+
value :CLSFR_MODE_MULTICLASS, 1
|
155
|
+
end
|
156
|
+
end
|
157
|
+
end
|
158
|
+
|
159
|
+
# Ruby constants for the descriptors registered above: message classes
# (e.g. Gooseai::Request, Gooseai::Answer) and enum modules
# (e.g. Gooseai::ArtifactType, Gooseai::DiffusionSampler).
module Gooseai
|
160
|
+
Token = ::Google::Protobuf::DescriptorPool.generated_pool.lookup("gooseai.Token").msgclass
|
161
|
+
Tokens = ::Google::Protobuf::DescriptorPool.generated_pool.lookup("gooseai.Tokens").msgclass
|
162
|
+
Artifact = ::Google::Protobuf::DescriptorPool.generated_pool.lookup("gooseai.Artifact").msgclass
|
163
|
+
PromptParameters = ::Google::Protobuf::DescriptorPool.generated_pool.lookup("gooseai.PromptParameters").msgclass
|
164
|
+
Prompt = ::Google::Protobuf::DescriptorPool.generated_pool.lookup("gooseai.Prompt").msgclass
|
165
|
+
AnswerMeta = ::Google::Protobuf::DescriptorPool.generated_pool.lookup("gooseai.AnswerMeta").msgclass
|
166
|
+
Answer = ::Google::Protobuf::DescriptorPool.generated_pool.lookup("gooseai.Answer").msgclass
|
167
|
+
SamplerParameters = ::Google::Protobuf::DescriptorPool.generated_pool.lookup("gooseai.SamplerParameters").msgclass
|
168
|
+
ConditionerParameters = ::Google::Protobuf::DescriptorPool.generated_pool.lookup("gooseai.ConditionerParameters").msgclass
|
169
|
+
StepParameter = ::Google::Protobuf::DescriptorPool.generated_pool.lookup("gooseai.StepParameter").msgclass
|
170
|
+
TransformType = ::Google::Protobuf::DescriptorPool.generated_pool.lookup("gooseai.TransformType").msgclass
|
171
|
+
ImageParameters = ::Google::Protobuf::DescriptorPool.generated_pool.lookup("gooseai.ImageParameters").msgclass
|
172
|
+
ClassifierConcept = ::Google::Protobuf::DescriptorPool.generated_pool.lookup("gooseai.ClassifierConcept").msgclass
|
173
|
+
ClassifierCategory = ::Google::Protobuf::DescriptorPool.generated_pool.lookup("gooseai.ClassifierCategory").msgclass
|
174
|
+
ClassifierParameters = ::Google::Protobuf::DescriptorPool.generated_pool.lookup("gooseai.ClassifierParameters").msgclass
|
175
|
+
Request = ::Google::Protobuf::DescriptorPool.generated_pool.lookup("gooseai.Request").msgclass
|
176
|
+
FinishReason = ::Google::Protobuf::DescriptorPool.generated_pool.lookup("gooseai.FinishReason").enummodule
|
177
|
+
ArtifactType = ::Google::Protobuf::DescriptorPool.generated_pool.lookup("gooseai.ArtifactType").enummodule
|
178
|
+
DiffusionSampler = ::Google::Protobuf::DescriptorPool.generated_pool.lookup("gooseai.DiffusionSampler").enummodule
|
179
|
+
Upscaler = ::Google::Protobuf::DescriptorPool.generated_pool.lookup("gooseai.Upscaler").enummodule
|
180
|
+
Action = ::Google::Protobuf::DescriptorPool.generated_pool.lookup("gooseai.Action").enummodule
|
181
|
+
ClassifierMode = ::Google::Protobuf::DescriptorPool.generated_pool.lookup("gooseai.ClassifierMode").enummodule
|
182
|
+
end
|
@@ -0,0 +1,22 @@
|
|
1
|
+
# Generated by the protocol buffer compiler. DO NOT EDIT!
|
2
|
+
# Source: generation.proto for package 'gooseai'
|
3
|
+
|
4
|
+
require 'grpc'
|
5
|
+
require 'generation_pb'
|
6
|
+
|
7
|
+
# gRPC bindings for gooseai.GenerationService; the message classes come from
# generation_pb (required above).
module Gooseai
|
8
|
+
module GenerationService
|
9
|
+
class Service
|
10
|
+
|
11
|
+
include ::GRPC::GenericService
|
12
|
+
|
13
|
+
self.marshal_class_method = :encode
|
14
|
+
self.unmarshal_class_method = :decode
|
15
|
+
self.service_name = 'gooseai.GenerationService'
|
16
|
+
|
17
|
+
# Server-streaming RPC: one Gooseai::Request in, a stream of Gooseai::Answer out.
rpc :Generate, ::Gooseai::Request, stream(::Gooseai::Answer)
|
18
|
+
end
|
19
|
+
|
20
|
+
# Client stub class; StabilitySDK::Client instantiates it with a host and
# composed channel/call credentials.
Stub = Service.rpc_stub_class
|
21
|
+
end
|
22
|
+
end
|
@@ -0,0 +1,34 @@
|
|
1
|
+
module StabilitySDK
  # Helpers for the `stability-client` executable.
  class CLI
    # Extension used when MIME::Types knows no preferred extension for an
    # image artifact's MIME type.
    FALLBACK_EXTENSION = "bin"

    # Writes the artifacts of one streamed answer into the current directory.
    #
    # Files are named "<prefix>-<request_id>-<answer_id>-<index>.<ext>", where
    # prefix is options[:prefix] (default "generation"). Image artifacts are
    # written as their raw bytes, classification artifacts as the JSON form of
    # their classifier message. Other artifact types are logged and skipped,
    # and so are artifacts with an empty payload.
    #
    # Relies on MIME::Types (mime-types gem) having been required by the caller;
    # the stability-client executable does this.
    #
    # @param answer [Gooseai::Answer]
    # @param options [Hash] reads :prefix and :no_store
    # @param logger [Logger]
    # @return [void]
    def self.save_answer(answer, options, logger)
      # Block form: the answer (which may embed image bytes) is only turned
      # into a string when DEBUG logging is actually enabled.
      logger.debug { "received answer: #{answer}" }
      return if options[:no_store]

      prefix = options[:prefix] || "generation"
      answer.artifacts.each_with_index do |artifact, idx|
        filename_base = "#{prefix}-#{answer.request_id}-#{answer.answer_id}-#{idx}"

        case artifact.type
        when :ARTIFACT_IMAGE
          filename = "#{filename_base}.#{image_extension(artifact.mime)}"
          contents = artifact.binary
        when :ARTIFACT_CLASSIFICATIONS
          filename = "#{filename_base}.pb.json"
          contents = artifact.classifier.to_json
        else
          logger.warn "not implemented for ArtifactType #{artifact.type}"
          next
        end

        next if contents.nil? || contents.empty?

        File.binwrite(filename, contents)
      end
    end

    # Preferred file extension for +mime+, or FALLBACK_EXTENSION when the type
    # is unknown (MIME::Types[] returns an empty list) or lists no extension.
    def self.image_extension(mime)
      MIME::Types[mime].first&.preferred_extension || FALLBACK_EXTENSION
    end
    private_class_method :image_extension
  end
end
|
@@ -0,0 +1,75 @@
|
|
1
|
+
require "grpc"
|
2
|
+
require "generation_services_pb"
|
3
|
+
|
4
|
+
module StabilitySDK
  # gRPC client for the Stability generation API (Gooseai::GenerationService).
  #
  # Builds a Gooseai::Request from a text prompt and a plain options hash and
  # yields every Gooseai::Answer the server streams back.
  class Client
    DEFAULT_API_HOST = "grpc.stability.ai:443"
    DEFAULT_IMAGE_WIDTH = 512
    DEFAULT_IMAGE_HEIGHT = 512
    DEFAULT_SAMPLE_SIZE = 1
    DEFAULT_STEPS = 50
    DEFAULT_ENGINE_ID = "stable-diffusion-v1"
    DEFAULT_CFG_SCALE = 7.0
    DEFAULT_SAMPLER_ALGORITHM = Gooseai::DiffusionSampler::SAMPLER_K_LMS

    # Exclusive upper bound for randomly chosen seeds (ImageParameters.seed is uint32).
    SEED_UPPER_BOUND = 2**32

    # Sampler names accepted in options[:sampler] (e.g. from the `--sampler`
    # CLI flag) mapped to Gooseai::DiffusionSampler values. Keys are Strings
    # because command-line values arrive as Strings; Symbols are converted
    # before lookup.
    SAMPLER_ALGORITHMS = {
      "ddim" => Gooseai::DiffusionSampler::SAMPLER_DDIM,
      "plms" => Gooseai::DiffusionSampler::SAMPLER_DDPM,
      "k_euler" => Gooseai::DiffusionSampler::SAMPLER_K_EULER,
      "k_euler_ancestral" => Gooseai::DiffusionSampler::SAMPLER_K_EULER_ANCESTRAL,
      "k_heun" => Gooseai::DiffusionSampler::SAMPLER_K_HEUN,
      "k_dpm_2" => Gooseai::DiffusionSampler::SAMPLER_K_DPM_2,
      "k_dpm_2_ancestral" => Gooseai::DiffusionSampler::SAMPLER_K_DPM_2_ANCESTRAL,
      "k_lms" => Gooseai::DiffusionSampler::SAMPLER_K_LMS,
    }.freeze

    # @param options [Hash]
    # @option options [String] :api_key sent as "Bearer <api_key>" authorization metadata
    # @option options [String] :api_host "host:port" to connect to (default DEFAULT_API_HOST)
    # @option options [String] :ca_cert PEM root certificates; gRPC's defaults when omitted
    def initialize(options = {})
      host = options[:api_host] || DEFAULT_API_HOST
      channel_creds = options.key?(:ca_cert) ? GRPC::Core::ChannelCredentials.new(options[:ca_cert]) : GRPC::Core::ChannelCredentials.new
      call_creds = GRPC::Core::CallCredentials.new(proc { { "authorization" => "Bearer #{options[:api_key]}" } })

      @stub = Gooseai::GenerationService::Stub.new(host, channel_creds.compose(call_creds))
    end

    # Sends a text-to-image request and yields each streamed answer.
    #
    # @param prompt [String] text prompt
    # @param options [Hash] :engine_id (default DEFAULT_ENGINE_ID) plus every
    #   option understood by #image_param
    # @yieldparam answer [Gooseai::Answer]
    # @return [Enumerator] when no block is given; the RPC is only issued once
    #   the enumerator is iterated
    # @raise [ArgumentError] for an unknown :sampler
    def generate(prompt, options = {}, &block)
      return enum_for(:generate, prompt, options) unless block

      request = Gooseai::Request.new(
        prompt: [Gooseai::Prompt.new(text: prompt)],
        engine_id: options[:engine_id] || DEFAULT_ENGINE_ID,
        image: image_param(options)
      )
      @stub.generate(request).each(&block)
    end

    # Builds the Gooseai::ImageParameters for a request.
    #
    # Numeric options may be Integers/Floats or numeric Strings (such as raw
    # command-line values). A missing or nil option falls back to its default.
    #
    # @option options [Integer, String] :width pixels (default 512)
    # @option options [Integer, String] :height pixels (default 512)
    # @option options [Integer, String] :num_samples images to generate (default 1)
    # @option options [Integer, String] :steps (default 50)
    # @option options [Integer, String] :seed random value below 2**32 when omitted
    # @option options [String, Symbol] :sampler a SAMPLER_ALGORITHMS key (default k_lms)
    # @option options [Float, String] :cfg_scale (default 7.0)
    # @raise [ArgumentError] for an unknown :sampler
    # @return [Gooseai::ImageParameters]
    def image_param(options = {})
      seed = options[:seed].nil? ? rand(SEED_UPPER_BOUND) : options[:seed].to_i

      Gooseai::ImageParameters.new(
        width: (options[:width] || DEFAULT_IMAGE_WIDTH).to_i,
        height: (options[:height] || DEFAULT_IMAGE_HEIGHT).to_i,
        samples: (options[:num_samples] || DEFAULT_SAMPLE_SIZE).to_i,
        steps: (options[:steps] || DEFAULT_STEPS).to_i,
        seed: [seed],
        transform: Gooseai::TransformType.new(diffusion: sampler_algorithm(options[:sampler])),
        parameters: [
          Gooseai::StepParameter.new(
            scaled_step: 0,
            sampler: Gooseai::SamplerParameters.new(
              cfg_scale: (options[:cfg_scale] || DEFAULT_CFG_SCALE).to_f
            )
          ),
        ]
      )
    end

    private

    # Resolves a sampler name (String or Symbol) to its DiffusionSampler value;
    # nil selects DEFAULT_SAMPLER_ALGORITHM.
    def sampler_algorithm(name)
      return DEFAULT_SAMPLER_ALGORITHM if name.nil?

      SAMPLER_ALGORITHMS.fetch(name.to_s) do
        raise ArgumentError, "unknown sampler #{name.inspect} (expected one of: #{SAMPLER_ALGORITHMS.keys.join(', ')})"
      end
    end
  end
end
|
@@ -0,0 +1,179 @@
|
|
1
|
+
syntax = 'proto3';
|
2
|
+
package gooseai;
|
3
|
+
option go_package = "./;generation";
// Source for lib/generation_pb.rb and lib/generation_services_pb.rb (their
// headers name generation.proto as the source); regenerate them after editing.
|
4
|
+
|
5
|
+
enum FinishReason {
|
6
|
+
NULL = 0;
|
7
|
+
LENGTH = 1;
|
8
|
+
STOP = 2;
|
9
|
+
ERROR = 3;
|
10
|
+
FILTER = 4;
|
11
|
+
}
|
12
|
+
|
13
|
+
|
14
|
+
enum ArtifactType {
|
15
|
+
ARTIFACT_NONE = 0;
|
16
|
+
ARTIFACT_IMAGE = 1;
|
17
|
+
ARTIFACT_VIDEO = 2;
|
18
|
+
ARTIFACT_TEXT = 3;
|
19
|
+
ARTIFACT_TOKENS = 4;
|
20
|
+
ARTIFACT_EMBEDDING = 5;
|
21
|
+
ARTIFACT_CLASSIFICATIONS = 6;
|
22
|
+
}
|
23
|
+
|
24
|
+
message Token {
|
25
|
+
optional string text = 1;
|
26
|
+
uint32 id = 2;
|
27
|
+
}
|
28
|
+
|
29
|
+
message Tokens {
|
30
|
+
repeated Token tokens = 1;
|
31
|
+
optional string tokenizer_id = 2;
|
32
|
+
}
|
33
|
+
|
34
|
+
message Artifact {
|
35
|
+
uint64 id = 1;
|
36
|
+
ArtifactType type = 2;
|
37
|
+
string mime = 3;
|
38
|
+
optional string magic = 4;
|
39
|
+
oneof data {
|
40
|
+
bytes binary = 5;
|
41
|
+
string text = 6;
|
42
|
+
Tokens tokens = 7;
|
43
|
+
ClassifierParameters classifier = 11;
|
44
|
+
}
|
45
|
+
uint32 index = 8;
|
46
|
+
FinishReason finish_reason = 9;
|
47
|
+
uint32 seed = 10;
|
48
|
+
}
|
49
|
+
|
50
|
+
message PromptParameters {
|
51
|
+
optional bool init = 1;
|
52
|
+
optional float weight = 2;
|
53
|
+
}
|
54
|
+
|
55
|
+
message Prompt {
|
56
|
+
optional PromptParameters parameters = 1;
|
57
|
+
oneof prompt {
|
58
|
+
string text = 2;
|
59
|
+
Tokens tokens = 3;
|
60
|
+
Artifact artifact = 4;
|
61
|
+
}
|
62
|
+
}
|
63
|
+
|
64
|
+
message AnswerMeta {
|
65
|
+
optional string gpu_id = 1;
|
66
|
+
optional string cpu_id = 2;
|
67
|
+
optional string node_id = 3;
|
68
|
+
optional string engine_id = 4;
|
69
|
+
}
|
70
|
+
|
71
|
+
message Answer {
|
72
|
+
string answer_id = 1;
|
73
|
+
string request_id = 2;
|
74
|
+
uint64 received = 3;
|
75
|
+
uint64 created = 4;
|
76
|
+
optional AnswerMeta meta = 6;
|
77
|
+
repeated Artifact artifacts = 7;
|
78
|
+
}
|
79
|
+
|
80
|
+
// StabilitySDK::Client maps sampler names onto these values; note that the
// name "plms" is mapped to SAMPLER_DDPM.
enum DiffusionSampler {
|
81
|
+
SAMPLER_DDIM = 0;
|
82
|
+
SAMPLER_DDPM = 1;
|
83
|
+
SAMPLER_K_EULER = 2;
|
84
|
+
SAMPLER_K_EULER_ANCESTRAL = 3;
|
85
|
+
SAMPLER_K_HEUN = 4;
|
86
|
+
SAMPLER_K_DPM_2 = 5;
|
87
|
+
SAMPLER_K_DPM_2_ANCESTRAL = 6;
|
88
|
+
SAMPLER_K_LMS = 7;
|
89
|
+
}
|
90
|
+
|
91
|
+
message SamplerParameters {
|
92
|
+
optional float eta = 1;
|
93
|
+
optional uint64 sampling_steps = 2;
|
94
|
+
optional uint64 latent_channels = 3;
|
95
|
+
optional uint64 downsampling_factor = 4;
|
96
|
+
optional float cfg_scale = 5;
|
97
|
+
}
|
98
|
+
|
99
|
+
message ConditionerParameters {
|
100
|
+
optional string vector_adjust_prior = 1;
|
101
|
+
}
|
102
|
+
|
103
|
+
enum Upscaler {
|
104
|
+
UPSCALER_RGB = 0;
|
105
|
+
UPSCALER_GFPGAN = 1;
|
106
|
+
UPSCALER_ESRGAN = 2;
|
107
|
+
}
|
108
|
+
|
109
|
+
message StepParameter {
|
110
|
+
float scaled_step = 1;
|
111
|
+
optional SamplerParameters sampler = 2;
|
112
|
+
}
|
113
|
+
|
114
|
+
message TransformType {
|
115
|
+
oneof type {
|
116
|
+
DiffusionSampler diffusion = 1;
|
117
|
+
Upscaler upscaler = 2;
|
118
|
+
}
|
119
|
+
}
|
120
|
+
|
121
|
+
// StabilitySDK::Client fills height, width, samples, steps, a single seed,
// transform.diffusion and one StepParameter carrying cfg_scale.
message ImageParameters {
|
122
|
+
optional uint64 height = 1;
|
123
|
+
optional uint64 width = 2;
|
124
|
+
repeated uint32 seed = 3;
|
125
|
+
optional uint64 samples = 4;
|
126
|
+
optional uint64 steps = 5;
|
127
|
+
optional TransformType transform = 6;
|
128
|
+
repeated StepParameter parameters = 7;
|
129
|
+
}
|
130
|
+
|
131
|
+
enum Action {
|
132
|
+
ACTION_PASSTHROUGH = 0;
|
133
|
+
ACTION_REGENERATE_DUPLICATE = 1;
|
134
|
+
ACTION_REGENERATE = 2;
|
135
|
+
ACTION_OBFUSCATE_DUPLICATE = 3;
|
136
|
+
ACTION_OBFUSCATE = 4;
|
137
|
+
ACTION_DISCARD = 5;
|
138
|
+
}
|
139
|
+
|
140
|
+
enum ClassifierMode {
|
141
|
+
CLSFR_MODE_ZEROSHOT = 0;
|
142
|
+
CLSFR_MODE_MULTICLASS = 1;
|
143
|
+
/*CLSFR_MODE_ODDSRATIO = 2;*/
|
144
|
+
}
|
145
|
+
|
146
|
+
message ClassifierConcept {
|
147
|
+
string concept = 1;
|
148
|
+
optional float threshold = 2;
|
149
|
+
}
|
150
|
+
|
151
|
+
message ClassifierCategory {
|
152
|
+
string name = 1;
|
153
|
+
repeated ClassifierConcept concepts = 2;
|
154
|
+
optional float adjustment = 3;
|
155
|
+
optional Action action = 4;
|
156
|
+
optional ClassifierMode classifier_mode = 5;
|
157
|
+
}
|
158
|
+
|
159
|
+
message ClassifierParameters {
|
160
|
+
repeated ClassifierCategory categories = 1;
|
161
|
+
repeated ClassifierCategory exceeds = 2;
|
162
|
+
optional Action realized_action = 3;
|
163
|
+
}
|
164
|
+
|
165
|
+
// StabilitySDK::Client sends engine_id, one text Prompt and `image` params.
message Request {
|
166
|
+
string engine_id = 1;
|
167
|
+
string request_id = 2;
|
168
|
+
ArtifactType requested_type = 3;
|
169
|
+
repeated Prompt prompt = 4;
|
170
|
+
oneof params {
|
171
|
+
ImageParameters image = 5;
|
172
|
+
}
|
173
|
+
optional ConditionerParameters conditioner = 6;
|
174
|
+
optional ClassifierParameters classifier = 7;
|
175
|
+
}
|
176
|
+
|
177
|
+
// Server-streaming RPC: one Request yields a stream of Answer messages.
service GenerationService {
|
178
|
+
rpc Generate (Request) returns (stream Answer) {};
|
179
|
+
}
|
@@ -0,0 +1,30 @@
|
|
1
|
+
require_relative 'lib/stability_sdk/version'
|
2
|
+
|
3
|
+
# Gem specification for stability_sdk; StabilitySDK::VERSION is defined in
# lib/stability_sdk/version.rb.
Gem::Specification.new do |s|
  s.name    = "stability_sdk"
  s.version = StabilitySDK::VERSION
  s.authors = ["Kosei Moriyama"]
  s.email   = ["cou929@gmail.com"]

  s.summary     = "Ruby client for interacting with stability.ai APIs (e.g. stable diffusion inference)"
  s.description = "Ruby client of https://github.com/Stability-AI/stability-sdk"
  s.homepage    = "https://github.com/cou929/stability-sdk-ruby"
  s.required_ruby_version = Gem::Requirement.new(">= 2.3.0")

  s.metadata["homepage_uri"]    = s.homepage
  s.metadata["source_code_uri"] = "https://github.com/cou929/stability-sdk-ruby"
  s.metadata["changelog_uri"]   = "https://github.com/cou929/stability-sdk-ruby"

  # Package every file git tracks, except the test/spec/features trees.
  gem_root = File.expand_path('..', __FILE__)
  s.files = Dir.chdir(gem_root) do
    `git ls-files -z`.split("\x0").reject { |path| path =~ %r{^(test|spec|features)/} }
  end
  s.bindir        = "exe"
  s.executables   = s.files.grep(%r{^exe/}).map { |path| File.basename(path) }
  s.require_paths = ["lib"]

  %w[grpc mime-types].each { |gem_name| s.add_dependency gem_name }
  s.add_development_dependency "grpc-tools"
end
|
metadata
ADDED
@@ -0,0 +1,106 @@
|
|
1
|
+
--- !ruby/object:Gem::Specification
|
2
|
+
name: stability_sdk
|
3
|
+
version: !ruby/object:Gem::Version
|
4
|
+
version: 0.1.0
|
5
|
+
platform: ruby
|
6
|
+
authors:
|
7
|
+
- Kosei Moriyama
|
8
|
+
autorequire:
|
9
|
+
bindir: exe
|
10
|
+
cert_chain: []
|
11
|
+
date: 2022-09-07 00:00:00.000000000 Z
|
12
|
+
dependencies:
|
13
|
+
- !ruby/object:Gem::Dependency
|
14
|
+
name: grpc
|
15
|
+
requirement: !ruby/object:Gem::Requirement
|
16
|
+
requirements:
|
17
|
+
- - ">="
|
18
|
+
- !ruby/object:Gem::Version
|
19
|
+
version: '0'
|
20
|
+
type: :runtime
|
21
|
+
prerelease: false
|
22
|
+
version_requirements: !ruby/object:Gem::Requirement
|
23
|
+
requirements:
|
24
|
+
- - ">="
|
25
|
+
- !ruby/object:Gem::Version
|
26
|
+
version: '0'
|
27
|
+
- !ruby/object:Gem::Dependency
|
28
|
+
name: mime-types
|
29
|
+
requirement: !ruby/object:Gem::Requirement
|
30
|
+
requirements:
|
31
|
+
- - ">="
|
32
|
+
- !ruby/object:Gem::Version
|
33
|
+
version: '0'
|
34
|
+
type: :runtime
|
35
|
+
prerelease: false
|
36
|
+
version_requirements: !ruby/object:Gem::Requirement
|
37
|
+
requirements:
|
38
|
+
- - ">="
|
39
|
+
- !ruby/object:Gem::Version
|
40
|
+
version: '0'
|
41
|
+
- !ruby/object:Gem::Dependency
|
42
|
+
name: grpc-tools
|
43
|
+
requirement: !ruby/object:Gem::Requirement
|
44
|
+
requirements:
|
45
|
+
- - ">="
|
46
|
+
- !ruby/object:Gem::Version
|
47
|
+
version: '0'
|
48
|
+
type: :development
|
49
|
+
prerelease: false
|
50
|
+
version_requirements: !ruby/object:Gem::Requirement
|
51
|
+
requirements:
|
52
|
+
- - ">="
|
53
|
+
- !ruby/object:Gem::Version
|
54
|
+
version: '0'
|
55
|
+
description: Ruby client of https://github.com/Stability-AI/stability-sdk
|
56
|
+
email:
|
57
|
+
- cou929@gmail.com
|
58
|
+
executables:
|
59
|
+
- stability-client
|
60
|
+
extensions: []
|
61
|
+
extra_rdoc_files: []
|
62
|
+
files:
|
63
|
+
- ".gitignore"
|
64
|
+
- ".travis.yml"
|
65
|
+
- Gemfile
|
66
|
+
- Gemfile.lock
|
67
|
+
- README.md
|
68
|
+
- Rakefile
|
69
|
+
- bin/console
|
70
|
+
- bin/setup
|
71
|
+
- exe/stability-client
|
72
|
+
- lib/generation_pb.rb
|
73
|
+
- lib/generation_services_pb.rb
|
74
|
+
- lib/stability_sdk.rb
|
75
|
+
- lib/stability_sdk/cli.rb
|
76
|
+
- lib/stability_sdk/client.rb
|
77
|
+
- lib/stability_sdk/version.rb
|
78
|
+
- proto/generation.proto
|
79
|
+
- stability_sdk.gemspec
|
80
|
+
homepage: https://github.com/cou929/stability-sdk-ruby
|
81
|
+
licenses: []
|
82
|
+
metadata:
|
83
|
+
homepage_uri: https://github.com/cou929/stability-sdk-ruby
|
84
|
+
source_code_uri: https://github.com/cou929/stability-sdk-ruby
|
85
|
+
changelog_uri: https://github.com/cou929/stability-sdk-ruby
|
86
|
+
post_install_message:
|
87
|
+
rdoc_options: []
|
88
|
+
require_paths:
|
89
|
+
- lib
|
90
|
+
required_ruby_version: !ruby/object:Gem::Requirement
|
91
|
+
requirements:
|
92
|
+
- - ">="
|
93
|
+
- !ruby/object:Gem::Version
|
94
|
+
version: 2.3.0
|
95
|
+
required_rubygems_version: !ruby/object:Gem::Requirement
|
96
|
+
requirements:
|
97
|
+
- - ">="
|
98
|
+
- !ruby/object:Gem::Version
|
99
|
+
version: '0'
|
100
|
+
requirements: []
|
101
|
+
rubygems_version: 3.1.4
|
102
|
+
signing_key:
|
103
|
+
specification_version: 4
|
104
|
+
summary: Ruby client for interacting with stability.ai APIs (e.g. stable diffusion
|
105
|
+
inference)
|
106
|
+
test_files: []
|