stability_sdk 0.1.0
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- checksums.yaml +7 -0
- data/.gitignore +8 -0
- data/.travis.yml +6 -0
- data/Gemfile +7 -0
- data/Gemfile.lock +34 -0
- data/README.md +89 -0
- data/Rakefile +10 -0
- data/bin/console +14 -0
- data/bin/setup +8 -0
- data/exe/stability-client +43 -0
- data/lib/generation_pb.rb +182 -0
- data/lib/generation_services_pb.rb +22 -0
- data/lib/stability_sdk/cli.rb +34 -0
- data/lib/stability_sdk/client.rb +75 -0
- data/lib/stability_sdk/version.rb +3 -0
- data/lib/stability_sdk.rb +8 -0
- data/proto/generation.proto +179 -0
- data/stability_sdk.gemspec +30 -0
- metadata +106 -0
checksums.yaml
ADDED
@@ -0,0 +1,7 @@
|
|
1
|
+
---
|
2
|
+
SHA256:
|
3
|
+
metadata.gz: dac029de10ccef05e11d7fee6a581f14fec50252e1cf7790ce2e849e5061489c
|
4
|
+
data.tar.gz: 66f8803ff804dae6971cdf3b8473225012255e87d67f00484583d65ecd8afb6a
|
5
|
+
SHA512:
|
6
|
+
metadata.gz: ad988347184eafe784ab654b0022f1742cf56e51c9e2d4d88b3507f97162c6aaa7a165ad8ee5a6d3c34c1a9427fc860a54679192235c154bfea3008765dd8f89
|
7
|
+
data.tar.gz: d4522e6b2644f0ef309dbfb862a0667ed80c77e93c30bc575764997aa511d73789b3f270641d44c83cc2fdc017b7a0d5562e7de79ed7bebf86ff894019ad2642
|
data/.gitignore
ADDED
data/.travis.yml
ADDED
data/Gemfile
ADDED
data/Gemfile.lock
ADDED
@@ -0,0 +1,34 @@
|
|
1
|
+
PATH
|
2
|
+
remote: .
|
3
|
+
specs:
|
4
|
+
stability_sdk (0.1.0)
|
5
|
+
grpc
|
6
|
+
mime-types
|
7
|
+
|
8
|
+
GEM
|
9
|
+
remote: https://rubygems.org/
|
10
|
+
specs:
|
11
|
+
google-protobuf (3.21.5)
|
12
|
+
googleapis-common-protos-types (1.4.0)
|
13
|
+
google-protobuf (~> 3.14)
|
14
|
+
grpc (1.48.0)
|
15
|
+
google-protobuf (~> 3.19)
|
16
|
+
googleapis-common-protos-types (~> 1.0)
|
17
|
+
grpc-tools (1.48.0)
|
18
|
+
mime-types (3.4.1)
|
19
|
+
mime-types-data (~> 3.2015)
|
20
|
+
mime-types-data (3.2022.0105)
|
21
|
+
minitest (5.16.3)
|
22
|
+
rake (12.3.3)
|
23
|
+
|
24
|
+
PLATFORMS
|
25
|
+
arm64-darwin-20
|
26
|
+
|
27
|
+
DEPENDENCIES
|
28
|
+
grpc-tools
|
29
|
+
minitest (~> 5.0)
|
30
|
+
rake (~> 12.0)
|
31
|
+
stability_sdk!
|
32
|
+
|
33
|
+
BUNDLED WITH
|
34
|
+
2.2.2
|
data/README.md
ADDED
@@ -0,0 +1,89 @@
|
|
1
|
+
# StabilitySDK - Ruby client for stability.ai APIs, such as Stable Diffusion
|
2
|
+
|
3
|
+
A Ruby client for [stability.ai](https://stability.ai/) APIs, e.g., [Stable Diffusion](https://stability.ai/blog/stable-diffusion-public-release). It is modeled on the official client at https://github.com/Stability-AI/stability-sdk.
|
4
|
+
|
5
|
+
## Installation
|
6
|
+
|
7
|
+
Add this line to your application's Gemfile:
|
8
|
+
|
9
|
+
```ruby
|
10
|
+
gem 'stability_sdk'
|
11
|
+
```
|
12
|
+
|
13
|
+
And then execute:
|
14
|
+
|
15
|
+
$ bundle install
|
16
|
+
|
17
|
+
Or install it yourself as:
|
18
|
+
|
19
|
+
$ gem install stability_sdk
|
20
|
+
|
21
|
+
## Usage
|
22
|
+
|
23
|
+
First, create a [DreamStudio](https://beta.dreamstudio.ai/home) account and get an API key for it.
|
24
|
+
|
25
|
+
- Go to [DreamStudio](https://beta.dreamstudio.ai/dream) and create an account if you do not have one yet
|
26
|
+
- Go to the [membership page](https://beta.dreamstudio.ai/membership)
|
27
|
+
- Your API key is shown on the `API Key` tab
|
28
|
+
|
29
|
+
### Command line usage
|
30
|
+
|
31
|
+
```sh
|
32
|
+
STABILITY_SDK_API_KEY=YOUR_API_KEY stability-client 'A night in winter, oil-on-canvas landscape painting, by Vincent van Gogh'
|
33
|
+
```
|
34
|
+
|
35
|
+
This command saves an image like this:
|
36
|
+
|
37
|
+

|
38
|
+
|
39
|
+
|
40
|
+
```sh
|
41
|
+
Usage: stability-client [options] YOUR_PROMPT_TEXT
|
42
|
+
|
43
|
+
Options:
|
44
|
+
--api_key=VAL api key of DreamStudio account. You can also specify by a STABILITY_SDK_API_KEY environment variable
|
45
|
+
-H, --height=VAL height of image in pixel. default 512
|
46
|
+
-W, --width=VAL width of image in pixel. default 512
|
47
|
+
-C, --cfg_scale=VAL CFG scale factor. default 7.0
|
48
|
+
-A, --sampler=VAL ddim, plms, k_euler, k_euler_ancestral, k_heun, k_dpm_2, k_dpm_2_ancestral, k_lms. default k_lms
|
49
|
+
-s, --steps=VAL number of steps. default 50
|
50
|
+
-S, --seed=VAL random seed to use in integer
|
51
|
+
-p, --prefix=VAL output prefixes for artifacts. default `generation`
|
52
|
+
--no-store do not write out artifacts
|
53
|
+
-n, --num_samples=VAL number of samples to generate
|
54
|
+
-e, --engine=VAL engine to use for inference. default `stable-diffusion-v1`
|
55
|
+
-v, --verbose
|
56
|
+
```
|
57
|
+
|
58
|
+
### SDK usage
|
59
|
+
|
60
|
+
This sample code saves a generated image as `result.png`.
|
61
|
+
|
62
|
+
```ruby
|
63
|
+
require "stability_sdk"
|
64
|
+
|
65
|
+
client = StabilitySDK::Client.new(api_key: "YOUR_API_KEY")
|
66
|
+
|
67
|
+
prompt = "your prompt here"
|
68
|
+
options = {}
|
69
|
+
|
70
|
+
client.generate(prompt, options) do |answer|
|
71
|
+
answer.artifacts.each do |artifact|
|
72
|
+
if artifact.type == :ARTIFACT_IMAGE
|
73
|
+
File.open("result.png", "wb") do |f|
|
74
|
+
f.write(artifact.binary)
|
75
|
+
end
|
76
|
+
end
|
77
|
+
end
|
78
|
+
end
|
79
|
+
```
|
80
|
+
|
81
|
+
## Development
|
82
|
+
|
83
|
+
After checking out the repo, run `bin/setup` to install dependencies. Then, run `rake test` to run the tests. You can also run `bin/console` for an interactive prompt that will allow you to experiment.
|
84
|
+
|
85
|
+
To install this gem onto your local machine, run `bundle exec rake install`. To release a new version, update the version number in `version.rb`, and then run `bundle exec rake release`, which will create a git tag for the version, push git commits and tags, and push the `.gem` file to [rubygems.org](https://rubygems.org).
|
86
|
+
|
87
|
+
## Contributing
|
88
|
+
|
89
|
+
Bug reports and pull requests are welcome on GitHub at https://github.com/cou929/stability-sdk-ruby.
|
data/Rakefile
ADDED
data/bin/console
ADDED
@@ -0,0 +1,14 @@
|
|
1
|
+
#!/usr/bin/env ruby
|
2
|
+
|
3
|
+
require "bundler/setup"
|
4
|
+
require "stability_sdk"
|
5
|
+
|
6
|
+
# You can add fixtures and/or initialization code here to make experimenting
|
7
|
+
# with your gem easier. You can also use a different console, if you like.
|
8
|
+
|
9
|
+
# (If you use this, don't forget to add pry to your Gemfile!)
|
10
|
+
# require "pry"
|
11
|
+
# Pry.start
|
12
|
+
|
13
|
+
require "irb"
|
14
|
+
# Start an interactive session; bundler/setup and stability_sdk are already required above.
IRB.start(__FILE__)
|
data/bin/setup
ADDED
@@ -0,0 +1,43 @@
|
|
1
|
+
#!/usr/bin/env ruby
|
2
|
+
|
3
|
+
require "optparse"
|
4
|
+
require "mime/types"
|
5
|
+
require "logger"
|
6
|
+
require "stability_sdk"
|
7
|
+
|
8
|
+
# Command-line entry point: parse options, resolve the API key, stream answers for
# the prompt given as positional arguments, and save their artifacts via
# StabilitySDK::CLI.save_answer.
logger = Logger.new(STDOUT)
logger.level = Logger::WARN

opt = OptionParser.new
# OptionParser#version falls back to a top-level `Version` constant, which backs
# the built-in `--version` switch.
Version = StabilitySDK::VERSION

options = {}

opt.banner = "Usage: stability-client [options] YOUR_PROMPT_TEXT"
opt.separator ""
opt.separator "Options:"
opt.on("--api_key=VAL", "api key of DreamStudio account. You can also specify by a STABILITY_SDK_API_KEY environment variable") {|v| options[:api_key] = v }
opt.on("-H", "--height=VAL", "height of image in pixel. default 512") {|v| options[:height] = v }
opt.on("-W", "--width=VAL", "width of image in pixel. default 512") {|v| options[:width] = v }
opt.on("-C", "--cfg_scale=VAL", "CFG scale factor. default 7.0") {|v| options[:cfg_scale] = v }
opt.on("-A", "--sampler=VAL", "ddim, plms, k_euler, k_euler_ancestral, k_heun, k_dpm_2, k_dpm_2_ancestral, k_lms. default k_lms") {|v| options[:sampler] = v }
opt.on("-s", "--steps=VAL", "number of steps. default 50") {|v| options[:steps] = v }
opt.on("-S", "--seed=VAL", "random seed to use in integer") {|v| options[:seed] = v }
opt.on("-p", "--prefix=VAL", "output prefixes for artifacts. default `generation`") {|v| options[:prefix] = v }
# The flag's presence alone disables writing; store an explicit true instead of
# relying on the value OptionParser yields for a `--no-` switch.
opt.on("--no-store", "do not write out artifacts") { options[:no_store] = true }
opt.on("-n", "--num_samples=VAL", "number of samples to generate") {|v| options[:num_samples] = v }
opt.on("-e", "--engine=VAL", "engine to use for inference. default `stable-diffusion-v1`") {|v| options[:engine_id] = v }
opt.on("-v", "--verbose") { logger.level = Logger::DEBUG }
opt.parse!(ARGV)

# Everything left after option parsing is the prompt.
prompt = ARGV.join(" ")
raise StabilitySDK::InsufficientParameter, "prompt is required" if prompt.empty?

# An explicit --api_key wins; the environment variable is only a fallback.
options[:api_key] = ENV["STABILITY_SDK_API_KEY"] if options[:api_key].to_s.empty?
raise StabilitySDK::InsufficientParameter, "api key is required" if options[:api_key].to_s.empty?

client = StabilitySDK::Client.new(api_key: options[:api_key])

client.generate(prompt, options) do |answer|
  StabilitySDK::CLI.save_answer(answer, options, logger)
end
|
@@ -0,0 +1,182 @@
|
|
1
|
+
# Generated by the protocol buffer compiler. DO NOT EDIT!
|
2
|
+
# source: generation.proto
|
3
|
+
|
4
|
+
require 'google/protobuf'
|
5
|
+
|
6
|
+
# Registers every gooseai message and enum from generation.proto in the global
# descriptor pool. Generated code: change proto/generation.proto and regenerate
# instead of editing here. `proto3_optional` corresponds to fields declared with
# `optional` in the proto3 source.
Google::Protobuf::DescriptorPool.generated_pool.build do
  add_file("generation.proto", :syntax => :proto3) do
    add_message "gooseai.Token" do
      proto3_optional :text, :string, 1
      optional :id, :uint32, 2
    end
    add_message "gooseai.Tokens" do
      repeated :tokens, :message, 1, "gooseai.Token"
      proto3_optional :tokenizer_id, :string, 2
    end
    add_message "gooseai.Artifact" do
      optional :id, :uint64, 1
      optional :type, :enum, 2, "gooseai.ArtifactType"
      optional :mime, :string, 3
      proto3_optional :magic, :string, 4
      optional :index, :uint32, 8
      optional :finish_reason, :enum, 9, "gooseai.FinishReason"
      optional :seed, :uint32, 10
      # At most one payload field is set per artifact.
      oneof :data do
        optional :binary, :bytes, 5
        optional :text, :string, 6
        optional :tokens, :message, 7, "gooseai.Tokens"
        optional :classifier, :message, 11, "gooseai.ClassifierParameters"
      end
    end
    add_message "gooseai.PromptParameters" do
      proto3_optional :init, :bool, 1
      proto3_optional :weight, :float, 2
    end
    add_message "gooseai.Prompt" do
      proto3_optional :parameters, :message, 1, "gooseai.PromptParameters"
      oneof :prompt do
        optional :text, :string, 2
        optional :tokens, :message, 3, "gooseai.Tokens"
        optional :artifact, :message, 4, "gooseai.Artifact"
      end
    end
    add_message "gooseai.AnswerMeta" do
      proto3_optional :gpu_id, :string, 1
      proto3_optional :cpu_id, :string, 2
      proto3_optional :node_id, :string, 3
      proto3_optional :engine_id, :string, 4
    end
    add_message "gooseai.Answer" do
      optional :answer_id, :string, 1
      optional :request_id, :string, 2
      optional :received, :uint64, 3
      optional :created, :uint64, 4
      proto3_optional :meta, :message, 6, "gooseai.AnswerMeta"
      repeated :artifacts, :message, 7, "gooseai.Artifact"
    end
    add_message "gooseai.SamplerParameters" do
      proto3_optional :eta, :float, 1
      proto3_optional :sampling_steps, :uint64, 2
      proto3_optional :latent_channels, :uint64, 3
      proto3_optional :downsampling_factor, :uint64, 4
      proto3_optional :cfg_scale, :float, 5
    end
    add_message "gooseai.ConditionerParameters" do
      proto3_optional :vector_adjust_prior, :string, 1
    end
    add_message "gooseai.StepParameter" do
      optional :scaled_step, :float, 1
      proto3_optional :sampler, :message, 2, "gooseai.SamplerParameters"
    end
    add_message "gooseai.TransformType" do
      # Either a diffusion sampler or an upscaler.
      oneof :type do
        optional :diffusion, :enum, 1, "gooseai.DiffusionSampler"
        optional :upscaler, :enum, 2, "gooseai.Upscaler"
      end
    end
    add_message "gooseai.ImageParameters" do
      proto3_optional :height, :uint64, 1
      proto3_optional :width, :uint64, 2
      repeated :seed, :uint32, 3
      proto3_optional :samples, :uint64, 4
      proto3_optional :steps, :uint64, 5
      proto3_optional :transform, :message, 6, "gooseai.TransformType"
      repeated :parameters, :message, 7, "gooseai.StepParameter"
    end
    add_message "gooseai.ClassifierConcept" do
      optional :concept, :string, 1
      proto3_optional :threshold, :float, 2
    end
    add_message "gooseai.ClassifierCategory" do
      optional :name, :string, 1
      repeated :concepts, :message, 2, "gooseai.ClassifierConcept"
      proto3_optional :adjustment, :float, 3
      proto3_optional :action, :enum, 4, "gooseai.Action"
      proto3_optional :classifier_mode, :enum, 5, "gooseai.ClassifierMode"
    end
    add_message "gooseai.ClassifierParameters" do
      repeated :categories, :message, 1, "gooseai.ClassifierCategory"
      repeated :exceeds, :message, 2, "gooseai.ClassifierCategory"
      proto3_optional :realized_action, :enum, 3, "gooseai.Action"
    end
    add_message "gooseai.Request" do
      optional :engine_id, :string, 1
      optional :request_id, :string, 2
      optional :requested_type, :enum, 3, "gooseai.ArtifactType"
      repeated :prompt, :message, 4, "gooseai.Prompt"
      proto3_optional :conditioner, :message, 6, "gooseai.ConditionerParameters"
      proto3_optional :classifier, :message, 7, "gooseai.ClassifierParameters"
      oneof :params do
        optional :image, :message, 5, "gooseai.ImageParameters"
      end
    end
    add_enum "gooseai.FinishReason" do
      value :NULL, 0
      value :LENGTH, 1
      value :STOP, 2
      value :ERROR, 3
      value :FILTER, 4
    end
    add_enum "gooseai.ArtifactType" do
      value :ARTIFACT_NONE, 0
      value :ARTIFACT_IMAGE, 1
      value :ARTIFACT_VIDEO, 2
      value :ARTIFACT_TEXT, 3
      value :ARTIFACT_TOKENS, 4
      value :ARTIFACT_EMBEDDING, 5
      value :ARTIFACT_CLASSIFICATIONS, 6
    end
    add_enum "gooseai.DiffusionSampler" do
      value :SAMPLER_DDIM, 0
      value :SAMPLER_DDPM, 1
      value :SAMPLER_K_EULER, 2
      value :SAMPLER_K_EULER_ANCESTRAL, 3
      value :SAMPLER_K_HEUN, 4
      value :SAMPLER_K_DPM_2, 5
      value :SAMPLER_K_DPM_2_ANCESTRAL, 6
      value :SAMPLER_K_LMS, 7
    end
    add_enum "gooseai.Upscaler" do
      value :UPSCALER_RGB, 0
      value :UPSCALER_GFPGAN, 1
      value :UPSCALER_ESRGAN, 2
    end
    add_enum "gooseai.Action" do
      value :ACTION_PASSTHROUGH, 0
      value :ACTION_REGENERATE_DUPLICATE, 1
      value :ACTION_REGENERATE, 2
      value :ACTION_OBFUSCATE_DUPLICATE, 3
      value :ACTION_OBFUSCATE, 4
      value :ACTION_DISCARD, 5
    end
    add_enum "gooseai.ClassifierMode" do
      value :CLSFR_MODE_ZEROSHOT, 0
      value :CLSFR_MODE_MULTICLASS, 1
    end
  end
end
|
158
|
+
|
159
|
+
# Ruby constants for the descriptors registered in the generated pool:
# message classes come from #msgclass, enum modules from #enummodule.
# Generated code; do not edit by hand.
module Gooseai
  Token = ::Google::Protobuf::DescriptorPool.generated_pool.lookup("gooseai.Token").msgclass
  Tokens = ::Google::Protobuf::DescriptorPool.generated_pool.lookup("gooseai.Tokens").msgclass
  Artifact = ::Google::Protobuf::DescriptorPool.generated_pool.lookup("gooseai.Artifact").msgclass
  PromptParameters = ::Google::Protobuf::DescriptorPool.generated_pool.lookup("gooseai.PromptParameters").msgclass
  Prompt = ::Google::Protobuf::DescriptorPool.generated_pool.lookup("gooseai.Prompt").msgclass
  AnswerMeta = ::Google::Protobuf::DescriptorPool.generated_pool.lookup("gooseai.AnswerMeta").msgclass
  Answer = ::Google::Protobuf::DescriptorPool.generated_pool.lookup("gooseai.Answer").msgclass
  SamplerParameters = ::Google::Protobuf::DescriptorPool.generated_pool.lookup("gooseai.SamplerParameters").msgclass
  ConditionerParameters = ::Google::Protobuf::DescriptorPool.generated_pool.lookup("gooseai.ConditionerParameters").msgclass
  StepParameter = ::Google::Protobuf::DescriptorPool.generated_pool.lookup("gooseai.StepParameter").msgclass
  TransformType = ::Google::Protobuf::DescriptorPool.generated_pool.lookup("gooseai.TransformType").msgclass
  ImageParameters = ::Google::Protobuf::DescriptorPool.generated_pool.lookup("gooseai.ImageParameters").msgclass
  ClassifierConcept = ::Google::Protobuf::DescriptorPool.generated_pool.lookup("gooseai.ClassifierConcept").msgclass
  ClassifierCategory = ::Google::Protobuf::DescriptorPool.generated_pool.lookup("gooseai.ClassifierCategory").msgclass
  ClassifierParameters = ::Google::Protobuf::DescriptorPool.generated_pool.lookup("gooseai.ClassifierParameters").msgclass
  Request = ::Google::Protobuf::DescriptorPool.generated_pool.lookup("gooseai.Request").msgclass
  # Enum modules expose each value as a constant, e.g. DiffusionSampler::SAMPLER_K_LMS.
  FinishReason = ::Google::Protobuf::DescriptorPool.generated_pool.lookup("gooseai.FinishReason").enummodule
  ArtifactType = ::Google::Protobuf::DescriptorPool.generated_pool.lookup("gooseai.ArtifactType").enummodule
  DiffusionSampler = ::Google::Protobuf::DescriptorPool.generated_pool.lookup("gooseai.DiffusionSampler").enummodule
  Upscaler = ::Google::Protobuf::DescriptorPool.generated_pool.lookup("gooseai.Upscaler").enummodule
  Action = ::Google::Protobuf::DescriptorPool.generated_pool.lookup("gooseai.Action").enummodule
  ClassifierMode = ::Google::Protobuf::DescriptorPool.generated_pool.lookup("gooseai.ClassifierMode").enummodule
end
|
@@ -0,0 +1,22 @@
|
|
1
|
+
# Generated by the protocol buffer compiler. DO NOT EDIT!
|
2
|
+
# Source: generation.proto for package 'gooseai'
|
3
|
+
|
4
|
+
require 'grpc'
|
5
|
+
require 'generation_pb'
|
6
|
+
|
7
|
+
# gRPC binding for gooseai.GenerationService (generated from generation.proto;
# change the proto and regenerate rather than editing here).
module Gooseai
  module GenerationService
    class Service

      include ::GRPC::GenericService

      # Messages are (de)serialized with the protobuf classes' encode/decode.
      self.marshal_class_method = :encode
      self.unmarshal_class_method = :decode
      self.service_name = 'gooseai.GenerationService'

      # Server-streaming RPC: one Request in, a stream of Answer messages out.
      rpc :Generate, ::Gooseai::Request, stream(::Gooseai::Answer)
    end

    # Client stub; StabilitySDK::Client constructs it as Stub.new(host, creds).
    Stub = Service.rpc_stub_class
  end
end
|
@@ -0,0 +1,34 @@
|
|
1
|
+
module StabilitySDK
  # Helpers used by the `stability-client` executable.
  class CLI
    # Extension used when an image artifact's MIME type is not known to mime-types.
    FALLBACK_EXTENSION = "bin"

    # Writes each supported artifact of +answer+ to a file in the current directory.
    #
    # Files are named "<prefix>-<request_id>-<answer_id>-<index>.<ext>", where the
    # prefix defaults to "generation". Image artifacts are written as raw bytes with
    # an extension derived from their MIME type (FALLBACK_EXTENSION if unknown).
    # Classification artifacts are written as JSON with a "pb.json" extension.
    # Other artifact types are skipped with a warning.
    #
    # @param answer [Gooseai::Answer] one streamed answer
    # @param options [Hash] CLI options; reads :no_store and :prefix
    # @param logger [Logger] receives the debug dump and warnings
    # @return [void]
    def self.save_answer(answer, options, logger)
      logger.debug "received answer: #{answer}"
      return if options[:no_store]

      prefix = options[:prefix] || "generation"

      answer.artifacts.each_with_index do |artifact, idx|
        filename_base = "#{prefix}-#{answer.request_id}-#{answer.answer_id}-#{idx}"

        case artifact.type
        when :ARTIFACT_IMAGE
          filename = "#{filename_base}.#{image_extension(artifact.mime, logger)}"
          contents = artifact.binary
        when :ARTIFACT_CLASSIFICATIONS
          filename = "#{filename_base}.pb.json"
          contents = artifact.classifier.to_json
        else
          logger.warn "not implemented for ArtifactType #{artifact.type}"
          next
        end

        # Skip artifacts that carry no payload.
        next if contents.nil? || contents.empty?

        File.binwrite(filename, contents)
      end
    end

    # Preferred file extension for +mime+, or FALLBACK_EXTENSION (with a warning)
    # when mime-types has no entry for it. For example, the server may send an empty
    # string, and an unguarded lookup would raise NoMethodError on nil.
    def self.image_extension(mime, logger)
      type = MIME::Types[mime.to_s].first
      ext = type && type.preferred_extension
      return ext if ext

      logger.warn "unknown MIME type #{mime.inspect}; saving with .#{FALLBACK_EXTENSION}"
      FALLBACK_EXTENSION
    end
    private_class_method :image_extension
  end
end
|
@@ -0,0 +1,75 @@
|
|
1
|
+
require "grpc"
|
2
|
+
require "generation_services_pb"
|
3
|
+
|
4
|
+
module StabilitySDK
  # gRPC client for the stability.ai generation API (gooseai.GenerationService).
  class Client
    DEFAULT_API_HOST = "grpc.stability.ai:443"
    DEFAULT_IMAGE_WIDTH = 512
    DEFAULT_IMAGE_HEIGHT = 512
    DEFAULT_SAMPLE_SIZE = 1
    DEFAULT_STEPS = 50
    DEFAULT_ENGINE_ID = "stable-diffusion-v1"
    DEFAULT_CFG_SCALE = 7.0
    DEFAULT_SAMPLER_ALGORITHM = Gooseai::DiffusionSampler::SAMPLER_K_LMS

    # Random seeds are drawn from [0, SEED_RANGE), the full uint32 range of
    # ImageParameters.seed.
    SEED_RANGE = 2**32

    # Sampler names accepted via options[:sampler] (the CLI's --sampler flag).
    # Keys are strings; lookups use #to_s, so Symbols such as :k_lms work too.
    SAMPLER_ALGORITHMS = {
      "ddim" => Gooseai::DiffusionSampler::SAMPLER_DDIM,
      "plms" => Gooseai::DiffusionSampler::SAMPLER_DDPM,
      "k_euler" => Gooseai::DiffusionSampler::SAMPLER_K_EULER,
      "k_euler_ancestral" => Gooseai::DiffusionSampler::SAMPLER_K_EULER_ANCESTRAL,
      "k_heun" => Gooseai::DiffusionSampler::SAMPLER_K_HEUN,
      "k_dpm_2" => Gooseai::DiffusionSampler::SAMPLER_K_DPM_2,
      "k_dpm_2_ancestral" => Gooseai::DiffusionSampler::SAMPLER_K_DPM_2_ANCESTRAL,
      "k_lms" => Gooseai::DiffusionSampler::SAMPLER_K_LMS,
    }.freeze

    # @param options [Hash]
    #   :api_key  - sent as "Bearer <api_key>" in the authorization metadata of every call
    #   :api_host - gRPC endpoint "host:port" (default DEFAULT_API_HOST)
    #   :ca_cert  - passed to GRPC::Core::ChannelCredentials.new (presumably PEM root certs)
    def initialize(options={})
      host = options[:api_host] || DEFAULT_API_HOST
      channel_creds = options.key?(:ca_cert) ? GRPC::Core::ChannelCredentials.new(options[:ca_cert]) : GRPC::Core::ChannelCredentials.new
      call_creds = GRPC::Core::CallCredentials.new(proc { { "authorization" => "Bearer #{options[:api_key]}" } })
      creds = channel_creds.compose(call_creds)

      @stub = Gooseai::GenerationService::Stub.new(host, creds)
    end

    # Sends a text prompt and yields each streamed Gooseai::Answer.
    #
    # @param prompt [String] the text prompt
    # @param options [Hash] :engine_id (default DEFAULT_ENGINE_ID) plus every key
    #   #image_param reads
    # @yieldparam answer [Gooseai::Answer]
    # @return [Enumerator] when called without a block; otherwise the stream's #each result
    # @raise [ArgumentError] for an unknown :sampler name
    def generate(prompt, options = {}, &block)
      return enum_for(:generate, prompt, options) unless block

      req = Gooseai::Request.new(
        prompt: [Gooseai::Prompt.new(text: prompt)],
        engine_id: options[:engine_id] || DEFAULT_ENGINE_ID,
        image: image_param(options),
      )

      @stub.generate(req).each(&block)
    end

    # Builds Gooseai::ImageParameters from loosely typed options.
    # CLI values arrive as Strings, so numeric fields are converted here.
    #
    # Reads :width, :height, :num_samples and :steps (Integer-like, via #to_i),
    # :cfg_scale (Float-like, via #to_f), :seed (one value or an Array of values)
    # and :sampler (a SAMPLER_ALGORITHMS key, String or Symbol). Missing keys fall
    # back to the DEFAULT_* constants and a random seed.
    #
    # @raise [ArgumentError] if :sampler is not a known sampler name
    def image_param(options={})
      seeds = options.key?(:seed) ? Array(options[:seed]).map(&:to_i) : [rand(SEED_RANGE)]

      transform = Gooseai::TransformType.new(
        diffusion: sampler_algorithm(options[:sampler]),
      )
      # A single step entry at scaled_step 0 carries the sampler settings.
      parameters = [Gooseai::StepParameter.new(
        scaled_step: 0,
        sampler: Gooseai::SamplerParameters.new(
          cfg_scale: options.fetch(:cfg_scale, DEFAULT_CFG_SCALE).to_f,
        ),
      )]

      Gooseai::ImageParameters.new(
        width: options.fetch(:width, DEFAULT_IMAGE_WIDTH).to_i,
        height: options.fetch(:height, DEFAULT_IMAGE_HEIGHT).to_i,
        samples: options.fetch(:num_samples, DEFAULT_SAMPLE_SIZE).to_i,
        steps: options.fetch(:steps, DEFAULT_STEPS).to_i,
        seed: seeds,
        transform: transform,
        parameters: parameters,
      )
    end

    private

    # Resolves a sampler name to its DiffusionSampler value. nil selects the default.
    def sampler_algorithm(name)
      return DEFAULT_SAMPLER_ALGORITHM if name.nil?

      SAMPLER_ALGORITHMS.fetch(name.to_s) do
        raise ArgumentError, "unknown sampler #{name.inspect} (expected one of: #{SAMPLER_ALGORITHMS.keys.join(', ')})"
      end
    end
  end
end
|
@@ -0,0 +1,179 @@
|
|
1
|
+
syntax = 'proto3';
|
2
|
+
package gooseai;
|
3
|
+
option go_package = "./;generation";
|
4
|
+
|
5
|
+
// Why a generation finished; reported on each Artifact as `finish_reason`.
enum FinishReason {
  NULL = 0;
  LENGTH = 1;
  STOP = 2;
  ERROR = 3;
  FILTER = 4;
}


// Payload kind of an Artifact; Request.requested_type uses the same enum.
enum ArtifactType {
  ARTIFACT_NONE = 0;
  ARTIFACT_IMAGE = 1;
  ARTIFACT_VIDEO = 2;
  ARTIFACT_TEXT = 3;
  ARTIFACT_TOKENS = 4;
  ARTIFACT_EMBEDDING = 5;
  ARTIFACT_CLASSIFICATIONS = 6;
}

// A single token id, optionally with its text.
message Token {
  optional string text = 1;
  uint32 id = 2;
}

// A token sequence, optionally tagged with the tokenizer that produced it.
message Tokens {
  repeated Token tokens = 1;
  optional string tokenizer_id = 2;
}

// One item of output (inside Answer.artifacts) or input (Prompt.artifact).
// At most one `data` field is set. The Ruby CLI saves `binary` for
// ARTIFACT_IMAGE and `classifier` (as JSON) for ARTIFACT_CLASSIFICATIONS.
message Artifact {
  uint64 id = 1;
  ArtifactType type = 2;
  string mime = 3;
  optional string magic = 4;
  oneof data {
    bytes binary = 5;
    string text = 6;
    Tokens tokens = 7;
    ClassifierParameters classifier = 11;
  }
  uint32 index = 8;
  FinishReason finish_reason = 9;
  uint32 seed = 10;
}

message PromptParameters {
  optional bool init = 1;
  optional float weight = 2;
}

// A prompt given as text, tokens, or an artifact (at most one, via `prompt`).
message Prompt {
  optional PromptParameters parameters = 1;
  oneof prompt {
    string text = 2;
    Tokens tokens = 3;
    Artifact artifact = 4;
  }
}

// Optional identifiers (gpu/cpu/node/engine) attached to an Answer.
message AnswerMeta {
  optional string gpu_id = 1;
  optional string cpu_id = 2;
  optional string node_id = 3;
  optional string engine_id = 4;
}

// One streamed response to a Request; GenerationService.Generate returns a
// stream of these. NOTE(review): `received`/`created` look like timestamps,
// but their units are not specified here.
message Answer {
  string answer_id = 1;
  string request_id = 2;
  uint64 received = 3;
  uint64 created = 4;
  optional AnswerMeta meta = 6;
  repeated Artifact artifacts = 7;
}
|
79
|
+
|
80
|
+
// Diffusion sampler choice. The Ruby client maps the name "plms" to SAMPLER_DDPM.
enum DiffusionSampler {
  SAMPLER_DDIM = 0;
  SAMPLER_DDPM = 1;
  SAMPLER_K_EULER = 2;
  SAMPLER_K_EULER_ANCESTRAL = 3;
  SAMPLER_K_HEUN = 4;
  SAMPLER_K_DPM_2 = 5;
  SAMPLER_K_DPM_2_ANCESTRAL = 6;
  SAMPLER_K_LMS = 7;
}

// Sampler settings carried by a StepParameter; the Ruby client sets only cfg_scale.
message SamplerParameters {
  optional float eta = 1;
  optional uint64 sampling_steps = 2;
  optional uint64 latent_channels = 3;
  optional uint64 downsampling_factor = 4;
  optional float cfg_scale = 5;
}

message ConditionerParameters {
  optional string vector_adjust_prior = 1;
}

// Upscaler choice, selected through TransformType.upscaler.
enum Upscaler {
  UPSCALER_RGB = 0;
  UPSCALER_GFPGAN = 1;
  UPSCALER_ESRGAN = 2;
}

// Sampler parameters tied to `scaled_step`; the Ruby client sends one entry at 0.
message StepParameter {
  float scaled_step = 1;
  optional SamplerParameters sampler = 2;
}

// Either a diffusion sampler or an upscaler (oneof `type`).
message TransformType {
  oneof type {
    DiffusionSampler diffusion = 1;
    Upscaler upscaler = 2;
  }
}

// Image generation parameters (Request.image). `seed` is repeated.
message ImageParameters {
  optional uint64 height = 1;
  optional uint64 width = 2;
  repeated uint32 seed = 3;
  optional uint64 samples = 4;
  optional uint64 steps = 5;
  optional TransformType transform = 6;
  repeated StepParameter parameters = 7;
}
|
130
|
+
|
131
|
+
// Action associated with a classifier category (ClassifierCategory.action) or
// the action actually taken (ClassifierParameters.realized_action).
enum Action {
  ACTION_PASSTHROUGH = 0;
  ACTION_REGENERATE_DUPLICATE = 1;
  ACTION_REGENERATE = 2;
  ACTION_OBFUSCATE_DUPLICATE = 3;
  ACTION_OBFUSCATE = 4;
  ACTION_DISCARD = 5;
}

enum ClassifierMode {
  CLSFR_MODE_ZEROSHOT = 0;
  CLSFR_MODE_MULTICLASS = 1;
  /*CLSFR_MODE_ODDSRATIO = 2;*/
}

// A named concept with an optional threshold.
message ClassifierConcept {
  string concept = 1;
  optional float threshold = 2;
}

// A named group of concepts with its action and mode.
message ClassifierCategory {
  string name = 1;
  repeated ClassifierConcept concepts = 2;
  optional float adjustment = 3;
  optional Action action = 4;
  optional ClassifierMode classifier_mode = 5;
}

// Classifier configuration (Request.classifier) or result (Artifact.classifier).
message ClassifierParameters {
  repeated ClassifierCategory categories = 1;
  repeated ClassifierCategory exceeds = 2;
  optional Action realized_action = 3;
}
|
164
|
+
|
165
|
+
// A generation request: target engine, prompts, and (via `params`) image parameters.
message Request {
  string engine_id = 1;
  string request_id = 2;
  ArtifactType requested_type = 3;
  repeated Prompt prompt = 4;
  oneof params {
    ImageParameters image = 5;
  }
  optional ConditionerParameters conditioner = 6;
  optional ClassifierParameters classifier = 7;
}

// Server-streaming: one Request yields a stream of Answer messages.
service GenerationService {
  rpc Generate (Request) returns (stream Answer) {};
}
|
@@ -0,0 +1,30 @@
|
|
1
|
+
require_relative 'lib/stability_sdk/version'
|
2
|
+
|
3
|
+
# Packaging metadata for the stability_sdk gem. StabilitySDK::VERSION comes from the
# require_relative of lib/stability_sdk/version.rb at the top of this file.
Gem::Specification.new do |spec|
  spec.name = "stability_sdk"
  spec.version = StabilitySDK::VERSION
  spec.authors = ["Kosei Moriyama"]
  spec.email = ["cou929@gmail.com"]

  spec.summary = "Ruby client for interacting with stability.ai APIs (e.g. stable diffusion inference)"
  spec.description = "Ruby client of https://github.com/Stability-AI/stability-sdk"
  spec.homepage = "https://github.com/cou929/stability-sdk-ruby"
  spec.required_ruby_version = Gem::Requirement.new(">= 2.3.0")
  # NOTE(review): no spec.license is set, so the published metadata has `licenses: []`.

  spec.metadata["homepage_uri"] = spec.homepage
  spec.metadata["source_code_uri"] = "https://github.com/cou929/stability-sdk-ruby"
  # NOTE(review): changelog_uri is the repository root, not a changelog file.
  spec.metadata["changelog_uri"] = "https://github.com/cou929/stability-sdk-ruby"

  # Specify which files should be added to the gem when it is released.
  # The `git ls-files -z` loads the files in the RubyGem that have been added into git.
  # (Building the gem therefore requires a git checkout.)
  spec.files = Dir.chdir(File.expand_path('..', __FILE__)) do
    `git ls-files -z`.split("\x0").reject { |f| f.match(%r{^(test|spec|features)/}) }
  end
  spec.bindir = "exe"
  # Every tracked file under exe/ (e.g. exe/stability-client) is installed as an executable.
  spec.executables = spec.files.grep(%r{^exe/}) { |f| File.basename(f) }
  spec.require_paths = ["lib"]

  # NOTE(review): runtime dependencies are unconstrained; Gemfile.lock resolved
  # grpc 1.48.0 and mime-types 3.4.1 — consider pessimistic (~>) constraints.
  spec.add_dependency "grpc"
  spec.add_dependency "mime-types"
  spec.add_development_dependency "grpc-tools"
end
|
metadata
ADDED
@@ -0,0 +1,106 @@
|
|
1
|
+
--- !ruby/object:Gem::Specification
|
2
|
+
name: stability_sdk
|
3
|
+
version: !ruby/object:Gem::Version
|
4
|
+
version: 0.1.0
|
5
|
+
platform: ruby
|
6
|
+
authors:
|
7
|
+
- Kosei Moriyama
|
8
|
+
autorequire:
|
9
|
+
bindir: exe
|
10
|
+
cert_chain: []
|
11
|
+
date: 2022-09-07 00:00:00.000000000 Z
|
12
|
+
dependencies:
|
13
|
+
- !ruby/object:Gem::Dependency
|
14
|
+
name: grpc
|
15
|
+
requirement: !ruby/object:Gem::Requirement
|
16
|
+
requirements:
|
17
|
+
- - ">="
|
18
|
+
- !ruby/object:Gem::Version
|
19
|
+
version: '0'
|
20
|
+
type: :runtime
|
21
|
+
prerelease: false
|
22
|
+
version_requirements: !ruby/object:Gem::Requirement
|
23
|
+
requirements:
|
24
|
+
- - ">="
|
25
|
+
- !ruby/object:Gem::Version
|
26
|
+
version: '0'
|
27
|
+
- !ruby/object:Gem::Dependency
|
28
|
+
name: mime-types
|
29
|
+
requirement: !ruby/object:Gem::Requirement
|
30
|
+
requirements:
|
31
|
+
- - ">="
|
32
|
+
- !ruby/object:Gem::Version
|
33
|
+
version: '0'
|
34
|
+
type: :runtime
|
35
|
+
prerelease: false
|
36
|
+
version_requirements: !ruby/object:Gem::Requirement
|
37
|
+
requirements:
|
38
|
+
- - ">="
|
39
|
+
- !ruby/object:Gem::Version
|
40
|
+
version: '0'
|
41
|
+
- !ruby/object:Gem::Dependency
|
42
|
+
name: grpc-tools
|
43
|
+
requirement: !ruby/object:Gem::Requirement
|
44
|
+
requirements:
|
45
|
+
- - ">="
|
46
|
+
- !ruby/object:Gem::Version
|
47
|
+
version: '0'
|
48
|
+
type: :development
|
49
|
+
prerelease: false
|
50
|
+
version_requirements: !ruby/object:Gem::Requirement
|
51
|
+
requirements:
|
52
|
+
- - ">="
|
53
|
+
- !ruby/object:Gem::Version
|
54
|
+
version: '0'
|
55
|
+
description: Ruby client of https://github.com/Stability-AI/stability-sdk
|
56
|
+
email:
|
57
|
+
- cou929@gmail.com
|
58
|
+
executables:
|
59
|
+
- stability-client
|
60
|
+
extensions: []
|
61
|
+
extra_rdoc_files: []
|
62
|
+
files:
|
63
|
+
- ".gitignore"
|
64
|
+
- ".travis.yml"
|
65
|
+
- Gemfile
|
66
|
+
- Gemfile.lock
|
67
|
+
- README.md
|
68
|
+
- Rakefile
|
69
|
+
- bin/console
|
70
|
+
- bin/setup
|
71
|
+
- exe/stability-client
|
72
|
+
- lib/generation_pb.rb
|
73
|
+
- lib/generation_services_pb.rb
|
74
|
+
- lib/stability_sdk.rb
|
75
|
+
- lib/stability_sdk/cli.rb
|
76
|
+
- lib/stability_sdk/client.rb
|
77
|
+
- lib/stability_sdk/version.rb
|
78
|
+
- proto/generation.proto
|
79
|
+
- stability_sdk.gemspec
|
80
|
+
homepage: https://github.com/cou929/stability-sdk-ruby
|
81
|
+
licenses: []
|
82
|
+
metadata:
|
83
|
+
homepage_uri: https://github.com/cou929/stability-sdk-ruby
|
84
|
+
source_code_uri: https://github.com/cou929/stability-sdk-ruby
|
85
|
+
changelog_uri: https://github.com/cou929/stability-sdk-ruby
|
86
|
+
post_install_message:
|
87
|
+
rdoc_options: []
|
88
|
+
require_paths:
|
89
|
+
- lib
|
90
|
+
required_ruby_version: !ruby/object:Gem::Requirement
|
91
|
+
requirements:
|
92
|
+
- - ">="
|
93
|
+
- !ruby/object:Gem::Version
|
94
|
+
version: 2.3.0
|
95
|
+
required_rubygems_version: !ruby/object:Gem::Requirement
|
96
|
+
requirements:
|
97
|
+
- - ">="
|
98
|
+
- !ruby/object:Gem::Version
|
99
|
+
version: '0'
|
100
|
+
requirements: []
|
101
|
+
rubygems_version: 3.1.4
|
102
|
+
signing_key:
|
103
|
+
specification_version: 4
|
104
|
+
summary: Ruby client for interacting with stability.ai APIs (e.g. stable diffusion
|
105
|
+
inference)
|
106
|
+
test_files: []
|