stability_sdk 0.1.0

checksums.yaml ADDED
@@ -0,0 +1,7 @@
1
+ ---
2
+ SHA256:
3
+ metadata.gz: dac029de10ccef05e11d7fee6a581f14fec50252e1cf7790ce2e849e5061489c
4
+ data.tar.gz: 66f8803ff804dae6971cdf3b8473225012255e87d67f00484583d65ecd8afb6a
5
+ SHA512:
6
+ metadata.gz: ad988347184eafe784ab654b0022f1742cf56e51c9e2d4d88b3507f97162c6aaa7a165ad8ee5a6d3c34c1a9427fc860a54679192235c154bfea3008765dd8f89
7
+ data.tar.gz: d4522e6b2644f0ef309dbfb862a0667ed80c77e93c30bc575764997aa511d73789b3f270641d44c83cc2fdc017b7a0d5562e7de79ed7bebf86ff894019ad2642
data/.gitignore ADDED
@@ -0,0 +1,8 @@
1
+ /.bundle/
2
+ /.yardoc
3
+ /_yardoc/
4
+ /coverage/
5
+ /doc/
6
+ /pkg/
7
+ /spec/reports/
8
+ /tmp/
data/.travis.yml ADDED
@@ -0,0 +1,6 @@
1
+ ---
2
+ language: ruby
3
+ cache: bundler
4
+ rvm:
5
+ - 2.7.2
6
+ before_install: gem install bundler -v 2.1.4
data/Gemfile ADDED
@@ -0,0 +1,7 @@
1
+ source "https://rubygems.org"
2
+
3
+ # Specify your gem's dependencies in stability_sdk.gemspec
4
+ gemspec
5
+
6
+ gem "rake", "~> 12.0"
7
+ gem "minitest", "~> 5.0"
data/Gemfile.lock ADDED
@@ -0,0 +1,34 @@
1
+ PATH
2
+ remote: .
3
+ specs:
4
+ stability_sdk (0.1.0)
5
+ grpc
6
+ mime-types
7
+
8
+ GEM
9
+ remote: https://rubygems.org/
10
+ specs:
11
+ google-protobuf (3.21.5)
12
+ googleapis-common-protos-types (1.4.0)
13
+ google-protobuf (~> 3.14)
14
+ grpc (1.48.0)
15
+ google-protobuf (~> 3.19)
16
+ googleapis-common-protos-types (~> 1.0)
17
+ grpc-tools (1.48.0)
18
+ mime-types (3.4.1)
19
+ mime-types-data (~> 3.2015)
20
+ mime-types-data (3.2022.0105)
21
+ minitest (5.16.3)
22
+ rake (12.3.3)
23
+
24
+ PLATFORMS
25
+ arm64-darwin-20
26
+
27
+ DEPENDENCIES
28
+ grpc-tools
29
+ minitest (~> 5.0)
30
+ rake (~> 12.0)
31
+ stability_sdk!
32
+
33
+ BUNDLED WITH
34
+ 2.2.2
data/README.md ADDED
@@ -0,0 +1,89 @@
1
+ # StabilitySDK - Ruby client for stability.ai APIs, such as Stable Diffusion
2
+
3
+ A Ruby client for [stability.ai](https://stability.ai/) APIs, e.g., [Stable Diffusion](https://stability.ai/blog/stable-diffusion-public-release). Modeled on https://github.com/Stability-AI/stability-sdk.
4
+
5
+ ## Installation
6
+
7
+ Add this line to your application's Gemfile:
8
+
9
+ ```ruby
10
+ gem 'stability_sdk'
11
+ ```
12
+
13
+ And then execute:
14
+
15
+ $ bundle install
16
+
17
+ Or install it yourself as:
18
+
19
+ $ gem install stability_sdk
20
+
21
+ ## Usage
22
+
23
+ First, create a [DreamStudio](https://beta.dreamstudio.ai/home) account and get an API key for it.
24
+
25
+ - Open [DreamStudio](https://beta.dreamstudio.ai/dream) and create an account if you do not have one yet
26
+ - Go to the [membership page](https://beta.dreamstudio.ai/membership)
27
+ - Copy your API key from the `API Key` tab
28
+
29
+ ### Command line usage
30
+
31
+ ```sh
32
+ STABILITY_SDK_API_KEY=YOUR_API_KEY stability-client 'A night in winter, oil-on-canvas landscape painting, by Vincent van Gogh'
33
+ ```
34
+
35
+ This command saves an image like this:
36
+
37
+ ![3749380973_A_night_in_winter__oil_on_canvas_landscape_painting__by_Vincent_van_Gogh](https://user-images.githubusercontent.com/25668/188884116-0b03494b-0b34-49de-bbbc-89fbc2f6029d.png)
38
+
39
+
40
+ ```sh
41
+ Usage: stability-client [options] YOUR_PROMPT_TEXT
42
+
43
+ Options:
44
+ --api_key=VAL api key of DreamStudio account. You can also specify by a STABILITY_SDK_API_KEY environment variable
45
+ -H, --height=VAL height of image in pixel. default 512
46
+ -W, --width=VAL width of image in pixel. default 512
47
+ -C, --cfg_scale=VAL CFG scale factor. default 7.0
48
+ -A, --sampler=VAL ddim, plms, k_euler, k_euler_ancestral, k_heun, k_dpm_2, k_dpm_2_ancestral, k_lms. default k_lms
49
+ -s, --steps=VAL number of steps. default 50
50
+ -S, --seed=VAL random seed to use in integer
51
+ -p, --prefix=VAL output prefixes for artifacts. default `generation`
52
+ --no-store do not write out artifacts
53
+ -n, --num_samples=VAL number of samples to generate
54
+ -e, --engine=VAL engine to use for inference. default `stable-diffusion-v1`
55
+ -v, --verbose
56
+ ```
57
+
58
+ ### SDK usage
59
+
60
+ This sample code saves a generated image as `result.png`.
61
+
62
+ ```ruby
63
+ require "stability_sdk"
64
+
65
+ client = StabilitySDK::Client.new(api_key: "YOUR_API_KEY")
66
+
67
+ prompt = "your prompt here"
68
+ options = {}
69
+
70
+ client.generate(prompt, options) do |answer|
71
+ answer.artifacts.each do |artifact|
72
+ if artifact.type == :ARTIFACT_IMAGE
73
+ File.open("result.png", "wb") do |f|
74
+ f.write(artifact.binary)
75
+ end
76
+ end
77
+ end
78
+ end
79
+ ```
80
+
81
+ ## Development
82
+
83
+ After checking out the repo, run `bin/setup` to install dependencies. Then, run `rake test` to run the tests. You can also run `bin/console` for an interactive prompt that will allow you to experiment.
84
+
85
+ To install this gem onto your local machine, run `bundle exec rake install`. To release a new version, update the version number in `version.rb`, and then run `bundle exec rake release`, which will create a git tag for the version, push git commits and tags, and push the `.gem` file to [rubygems.org](https://rubygems.org).
86
+
87
+ ## Contributing
88
+
89
+ Bug reports and pull requests are welcome on GitHub at https://github.com/cou929/stability-sdk-ruby.
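The README sample writes every image to `result.png`, so a request that returns several images keeps only the last one. The following is a minimal sketch of one way to give each image its own file. `FakeArtifact`, `FakeAnswer`, and `save_images` are hypothetical names introduced for illustration; the structs stand in for the protobuf messages so the sketch runs without the gem or an API key, and the `.png` extension is an assumption about the returned format.

```ruby
require "tmpdir"

# Stand-ins for the protobuf messages (assumption: only these fields are needed here).
FakeArtifact = Struct.new(:type, :binary)
FakeAnswer = Struct.new(:answer_id, :artifacts)

# Hypothetical helper: writes every image artifact of one answer to its own
# file inside dir and returns the paths that were written.
def save_images(answer, dir)
  paths = []
  answer.artifacts.each_with_index do |artifact, idx|
    next unless artifact.type == :ARTIFACT_IMAGE

    # The artifact index keeps names unique within one answer.
    path = File.join(dir, "result-#{answer.answer_id}-#{idx}.png")
    File.binwrite(path, artifact.binary)
    paths << path
  end
  paths
end

answer = FakeAnswer.new("a1", [
  FakeArtifact.new(:ARTIFACT_IMAGE, "first".b),
  FakeArtifact.new(:ARTIFACT_CLASSIFICATIONS, nil),
  FakeArtifact.new(:ARTIFACT_IMAGE, "second".b),
])

SAVED_NAMES = Dir.mktmpdir do |dir|
  save_images(answer, dir).map { |path| File.basename(path) }
end
puts SAVED_NAMES
```

With the real client, the same helper could be called from inside the `client.generate` block, passing each streamed answer.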
data/Rakefile ADDED
@@ -0,0 +1,10 @@
1
+ require "bundler/gem_tasks"
2
+ require "rake/testtask"
3
+
4
+ Rake::TestTask.new(:test) do |t|
5
+ t.libs << "test"
6
+ t.libs << "lib"
7
+ t.test_files = FileList["test/**/*_test.rb"]
8
+ end
9
+
10
+ task :default => :test
data/bin/console ADDED
@@ -0,0 +1,14 @@
1
+ #!/usr/bin/env ruby
2
+
3
+ require "bundler/setup"
4
+ require "stability_sdk"
5
+
6
+ # You can add fixtures and/or initialization code here to make experimenting
7
+ # with your gem easier. You can also use a different console, if you like.
8
+
9
+ # (If you use this, don't forget to add pry to your Gemfile!)
10
+ # require "pry"
11
+ # Pry.start
12
+
13
+ require "irb"
14
+ IRB.start(__FILE__)
data/bin/setup ADDED
@@ -0,0 +1,8 @@
1
+ #!/usr/bin/env bash
2
+ set -euo pipefail
3
+ IFS=$'\n\t'
4
+ set -vx
5
+
6
+ bundle install
7
+
8
+ # Do any other automated setup that you need to do here
data/exe/stability-client ADDED
@@ -0,0 +1,43 @@
1
+ #!/usr/bin/env ruby
2
+
3
+ require "optparse"
4
+ require "mime/types"
5
+ require "logger"
6
+ require "stability_sdk"
7
+
8
+ logger = Logger.new(STDOUT)
9
+ logger.level = Logger::WARN
10
+
11
+ opt = OptionParser.new
12
+ Version = StabilitySDK::VERSION
13
+
14
+ options = {}
15
+
16
+ opt.banner = "Usage: stability-client [options] YOUR_PROMPT_TEXT"
17
+ opt.separator ""
18
+ opt.separator "Options:"
19
+ opt.on("--api_key=VAL", "api key of DreamStudio account. You can also specify by a STABILITY_SDK_API_KEY environment variable") {|v| options[:api_key] = v }
20
+ opt.on("-H", "--height=VAL", "height of image in pixel. default 512") {|v| options[:height] = v }
21
+ opt.on("-W", "--width=VAL", "width of image in pixel. default 512") {|v| options[:width] = v }
22
+ opt.on("-C", "--cfg_scale=VAL", "CFG scale factor. default 7.0") {|v| options[:cfg_scale] = v }
23
+ opt.on("-A", "--sampler=VAL", "ddim, plms, k_euler, k_euler_ancestral, k_heun, k_dpm_2, k_dpm_2_ancestral, k_lms. default k_lms") {|v| options[:sampler] = v }
24
+ opt.on("-s", "--steps=VAL", "number of steps. default 50") {|v| options[:steps] = v }
25
+ opt.on("-S", "--seed=VAL", "random seed to use in integer") {|v| options[:seed] = v }
26
+ opt.on("-p", "--prefix=VAL", "output prefixes for artifacts. default `generation`") {|v| options[:prefix] = v }
27
+ opt.on("--no-store", "do not write out artifacts") { options[:no_store] = true }
28
+ opt.on("-n", "--num_samples=VAL", "number of samples to generate") {|v| options[:num_samples] = v }
29
+ opt.on("-e", "--engine=VAL", "engine to use for inference. default `stable-diffusion-v1`") {|v| options[:engine_id] = v }
30
+ opt.on("-v", "--verbose") { logger.level = Logger::DEBUG }
31
+ opt.parse!(ARGV)
32
+
33
+ prompt = ARGV.join(" ")
34
+ raise StabilitySDK::InsufficientParameter, "prompt is required" if prompt.nil? || prompt == ""
35
+
36
+ options[:api_key] ||= ENV["STABILITY_SDK_API_KEY"] if ENV["STABILITY_SDK_API_KEY"]
37
+ raise StabilitySDK::InsufficientParameter, "api key is required" if !options.has_key?(:api_key) || options[:api_key] == ""
38
+
39
+ client = StabilitySDK::Client.new(api_key: options[:api_key])
40
+
41
+ client.generate(prompt, options) do |answer|
42
+ StabilitySDK::CLI.save_answer(answer, options, logger)
43
+ end
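The script above stores every flag value as a string and leaves conversion to the client. As a minimal sketch of an alternative, `OptionParser` can coerce values itself when a class such as `Integer` or `Float` is passed to `on`. `parse_cli` is a hypothetical helper covering only a subset of the flags; it is not part of the gem.

```ruby
require "optparse"

# Hypothetical helper: parses a subset of the flags above and lets
# OptionParser coerce numeric values instead of passing strings through.
def parse_cli(argv)
  options = {}
  parser = OptionParser.new
  parser.on("-H", "--height=VAL", Integer, "height of image in pixel") { |v| options[:height] = v }
  parser.on("-W", "--width=VAL", Integer, "width of image in pixel") { |v| options[:width] = v }
  parser.on("-C", "--cfg_scale=VAL", Float, "CFG scale factor") { |v| options[:cfg_scale] = v }
  parser.on("-S", "--seed=VAL", Integer, "random seed") { |v| options[:seed] = v }
  # The block ignores the yielded value and records the flag explicitly.
  parser.on("--no-store", "do not write out artifacts") { options[:no_store] = true }

  # parse (without !) leaves argv untouched and returns the non-option arguments.
  rest = parser.parse(argv)
  [options, rest.join(" ")]
end

options, prompt = parse_cli(["-H", "768", "--cfg_scale=8.5", "--no-store", "a", "red", "fox"])
p options
p prompt
```

With this approach a non-numeric value such as `-H tall` is rejected at parse time with `OptionParser::InvalidArgument` rather than failing later when the request is built.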
data/lib/generation_pb.rb ADDED
@@ -0,0 +1,182 @@
1
+ # Generated by the protocol buffer compiler. DO NOT EDIT!
2
+ # source: generation.proto
3
+
4
+ require 'google/protobuf'
5
+
6
+ Google::Protobuf::DescriptorPool.generated_pool.build do
7
+ add_file("generation.proto", :syntax => :proto3) do
8
+ add_message "gooseai.Token" do
9
+ proto3_optional :text, :string, 1
10
+ optional :id, :uint32, 2
11
+ end
12
+ add_message "gooseai.Tokens" do
13
+ repeated :tokens, :message, 1, "gooseai.Token"
14
+ proto3_optional :tokenizer_id, :string, 2
15
+ end
16
+ add_message "gooseai.Artifact" do
17
+ optional :id, :uint64, 1
18
+ optional :type, :enum, 2, "gooseai.ArtifactType"
19
+ optional :mime, :string, 3
20
+ proto3_optional :magic, :string, 4
21
+ optional :index, :uint32, 8
22
+ optional :finish_reason, :enum, 9, "gooseai.FinishReason"
23
+ optional :seed, :uint32, 10
24
+ oneof :data do
25
+ optional :binary, :bytes, 5
26
+ optional :text, :string, 6
27
+ optional :tokens, :message, 7, "gooseai.Tokens"
28
+ optional :classifier, :message, 11, "gooseai.ClassifierParameters"
29
+ end
30
+ end
31
+ add_message "gooseai.PromptParameters" do
32
+ proto3_optional :init, :bool, 1
33
+ proto3_optional :weight, :float, 2
34
+ end
35
+ add_message "gooseai.Prompt" do
36
+ proto3_optional :parameters, :message, 1, "gooseai.PromptParameters"
37
+ oneof :prompt do
38
+ optional :text, :string, 2
39
+ optional :tokens, :message, 3, "gooseai.Tokens"
40
+ optional :artifact, :message, 4, "gooseai.Artifact"
41
+ end
42
+ end
43
+ add_message "gooseai.AnswerMeta" do
44
+ proto3_optional :gpu_id, :string, 1
45
+ proto3_optional :cpu_id, :string, 2
46
+ proto3_optional :node_id, :string, 3
47
+ proto3_optional :engine_id, :string, 4
48
+ end
49
+ add_message "gooseai.Answer" do
50
+ optional :answer_id, :string, 1
51
+ optional :request_id, :string, 2
52
+ optional :received, :uint64, 3
53
+ optional :created, :uint64, 4
54
+ proto3_optional :meta, :message, 6, "gooseai.AnswerMeta"
55
+ repeated :artifacts, :message, 7, "gooseai.Artifact"
56
+ end
57
+ add_message "gooseai.SamplerParameters" do
58
+ proto3_optional :eta, :float, 1
59
+ proto3_optional :sampling_steps, :uint64, 2
60
+ proto3_optional :latent_channels, :uint64, 3
61
+ proto3_optional :downsampling_factor, :uint64, 4
62
+ proto3_optional :cfg_scale, :float, 5
63
+ end
64
+ add_message "gooseai.ConditionerParameters" do
65
+ proto3_optional :vector_adjust_prior, :string, 1
66
+ end
67
+ add_message "gooseai.StepParameter" do
68
+ optional :scaled_step, :float, 1
69
+ proto3_optional :sampler, :message, 2, "gooseai.SamplerParameters"
70
+ end
71
+ add_message "gooseai.TransformType" do
72
+ oneof :type do
73
+ optional :diffusion, :enum, 1, "gooseai.DiffusionSampler"
74
+ optional :upscaler, :enum, 2, "gooseai.Upscaler"
75
+ end
76
+ end
77
+ add_message "gooseai.ImageParameters" do
78
+ proto3_optional :height, :uint64, 1
79
+ proto3_optional :width, :uint64, 2
80
+ repeated :seed, :uint32, 3
81
+ proto3_optional :samples, :uint64, 4
82
+ proto3_optional :steps, :uint64, 5
83
+ proto3_optional :transform, :message, 6, "gooseai.TransformType"
84
+ repeated :parameters, :message, 7, "gooseai.StepParameter"
85
+ end
86
+ add_message "gooseai.ClassifierConcept" do
87
+ optional :concept, :string, 1
88
+ proto3_optional :threshold, :float, 2
89
+ end
90
+ add_message "gooseai.ClassifierCategory" do
91
+ optional :name, :string, 1
92
+ repeated :concepts, :message, 2, "gooseai.ClassifierConcept"
93
+ proto3_optional :adjustment, :float, 3
94
+ proto3_optional :action, :enum, 4, "gooseai.Action"
95
+ proto3_optional :classifier_mode, :enum, 5, "gooseai.ClassifierMode"
96
+ end
97
+ add_message "gooseai.ClassifierParameters" do
98
+ repeated :categories, :message, 1, "gooseai.ClassifierCategory"
99
+ repeated :exceeds, :message, 2, "gooseai.ClassifierCategory"
100
+ proto3_optional :realized_action, :enum, 3, "gooseai.Action"
101
+ end
102
+ add_message "gooseai.Request" do
103
+ optional :engine_id, :string, 1
104
+ optional :request_id, :string, 2
105
+ optional :requested_type, :enum, 3, "gooseai.ArtifactType"
106
+ repeated :prompt, :message, 4, "gooseai.Prompt"
107
+ proto3_optional :conditioner, :message, 6, "gooseai.ConditionerParameters"
108
+ proto3_optional :classifier, :message, 7, "gooseai.ClassifierParameters"
109
+ oneof :params do
110
+ optional :image, :message, 5, "gooseai.ImageParameters"
111
+ end
112
+ end
113
+ add_enum "gooseai.FinishReason" do
114
+ value :NULL, 0
115
+ value :LENGTH, 1
116
+ value :STOP, 2
117
+ value :ERROR, 3
118
+ value :FILTER, 4
119
+ end
120
+ add_enum "gooseai.ArtifactType" do
121
+ value :ARTIFACT_NONE, 0
122
+ value :ARTIFACT_IMAGE, 1
123
+ value :ARTIFACT_VIDEO, 2
124
+ value :ARTIFACT_TEXT, 3
125
+ value :ARTIFACT_TOKENS, 4
126
+ value :ARTIFACT_EMBEDDING, 5
127
+ value :ARTIFACT_CLASSIFICATIONS, 6
128
+ end
129
+ add_enum "gooseai.DiffusionSampler" do
130
+ value :SAMPLER_DDIM, 0
131
+ value :SAMPLER_DDPM, 1
132
+ value :SAMPLER_K_EULER, 2
133
+ value :SAMPLER_K_EULER_ANCESTRAL, 3
134
+ value :SAMPLER_K_HEUN, 4
135
+ value :SAMPLER_K_DPM_2, 5
136
+ value :SAMPLER_K_DPM_2_ANCESTRAL, 6
137
+ value :SAMPLER_K_LMS, 7
138
+ end
139
+ add_enum "gooseai.Upscaler" do
140
+ value :UPSCALER_RGB, 0
141
+ value :UPSCALER_GFPGAN, 1
142
+ value :UPSCALER_ESRGAN, 2
143
+ end
144
+ add_enum "gooseai.Action" do
145
+ value :ACTION_PASSTHROUGH, 0
146
+ value :ACTION_REGENERATE_DUPLICATE, 1
147
+ value :ACTION_REGENERATE, 2
148
+ value :ACTION_OBFUSCATE_DUPLICATE, 3
149
+ value :ACTION_OBFUSCATE, 4
150
+ value :ACTION_DISCARD, 5
151
+ end
152
+ add_enum "gooseai.ClassifierMode" do
153
+ value :CLSFR_MODE_ZEROSHOT, 0
154
+ value :CLSFR_MODE_MULTICLASS, 1
155
+ end
156
+ end
157
+ end
158
+
159
+ module Gooseai
160
+ Token = ::Google::Protobuf::DescriptorPool.generated_pool.lookup("gooseai.Token").msgclass
161
+ Tokens = ::Google::Protobuf::DescriptorPool.generated_pool.lookup("gooseai.Tokens").msgclass
162
+ Artifact = ::Google::Protobuf::DescriptorPool.generated_pool.lookup("gooseai.Artifact").msgclass
163
+ PromptParameters = ::Google::Protobuf::DescriptorPool.generated_pool.lookup("gooseai.PromptParameters").msgclass
164
+ Prompt = ::Google::Protobuf::DescriptorPool.generated_pool.lookup("gooseai.Prompt").msgclass
165
+ AnswerMeta = ::Google::Protobuf::DescriptorPool.generated_pool.lookup("gooseai.AnswerMeta").msgclass
166
+ Answer = ::Google::Protobuf::DescriptorPool.generated_pool.lookup("gooseai.Answer").msgclass
167
+ SamplerParameters = ::Google::Protobuf::DescriptorPool.generated_pool.lookup("gooseai.SamplerParameters").msgclass
168
+ ConditionerParameters = ::Google::Protobuf::DescriptorPool.generated_pool.lookup("gooseai.ConditionerParameters").msgclass
169
+ StepParameter = ::Google::Protobuf::DescriptorPool.generated_pool.lookup("gooseai.StepParameter").msgclass
170
+ TransformType = ::Google::Protobuf::DescriptorPool.generated_pool.lookup("gooseai.TransformType").msgclass
171
+ ImageParameters = ::Google::Protobuf::DescriptorPool.generated_pool.lookup("gooseai.ImageParameters").msgclass
172
+ ClassifierConcept = ::Google::Protobuf::DescriptorPool.generated_pool.lookup("gooseai.ClassifierConcept").msgclass
173
+ ClassifierCategory = ::Google::Protobuf::DescriptorPool.generated_pool.lookup("gooseai.ClassifierCategory").msgclass
174
+ ClassifierParameters = ::Google::Protobuf::DescriptorPool.generated_pool.lookup("gooseai.ClassifierParameters").msgclass
175
+ Request = ::Google::Protobuf::DescriptorPool.generated_pool.lookup("gooseai.Request").msgclass
176
+ FinishReason = ::Google::Protobuf::DescriptorPool.generated_pool.lookup("gooseai.FinishReason").enummodule
177
+ ArtifactType = ::Google::Protobuf::DescriptorPool.generated_pool.lookup("gooseai.ArtifactType").enummodule
178
+ DiffusionSampler = ::Google::Protobuf::DescriptorPool.generated_pool.lookup("gooseai.DiffusionSampler").enummodule
179
+ Upscaler = ::Google::Protobuf::DescriptorPool.generated_pool.lookup("gooseai.Upscaler").enummodule
180
+ Action = ::Google::Protobuf::DescriptorPool.generated_pool.lookup("gooseai.Action").enummodule
181
+ ClassifierMode = ::Google::Protobuf::DescriptorPool.generated_pool.lookup("gooseai.ClassifierMode").enummodule
182
+ end
data/lib/generation_services_pb.rb ADDED
@@ -0,0 +1,22 @@
1
+ # Generated by the protocol buffer compiler. DO NOT EDIT!
2
+ # Source: generation.proto for package 'gooseai'
3
+
4
+ require 'grpc'
5
+ require 'generation_pb'
6
+
7
+ module Gooseai
8
+ module GenerationService
9
+ class Service
10
+
11
+ include ::GRPC::GenericService
12
+
13
+ self.marshal_class_method = :encode
14
+ self.unmarshal_class_method = :decode
15
+ self.service_name = 'gooseai.GenerationService'
16
+
17
+ rpc :Generate, ::Gooseai::Request, stream(::Gooseai::Answer)
18
+ end
19
+
20
+ Stub = Service.rpc_stub_class
21
+ end
22
+ end
data/lib/stability_sdk/cli.rb ADDED
@@ -0,0 +1,34 @@
1
+ module StabilitySDK
2
+ class CLI
3
+ def self.save_answer(answer, options, logger)
4
+ logger.debug "received answer: #{answer}"
5
+ return if options[:no_store]
6
+
7
+ answer.artifacts.each_with_index do |artifact, idx|
8
+ filename_base = "#{options[:prefix] || "generation"}-#{answer.request_id}-#{answer.answer_id}-#{idx}"
9
+
10
+ filename = ""
11
+ contents = ""
12
+
13
+ case artifact.type
14
+ when :ARTIFACT_IMAGE
15
+ ext = MIME::Types[artifact.mime].first.preferred_extension
16
+ filename = "#{filename_base}.#{ext}"
17
+ contents = artifact.binary
18
+ when :ARTIFACT_CLASSIFICATIONS
19
+ ext = "pb.json"
20
+ filename = "#{filename_base}.#{ext}"
21
+ contents = artifact.classifier.to_json
22
+ else
23
+ logger.warn "not implemented for ArtifactType #{artifact.type}"
24
+ end
25
+
26
+ next if filename == "" || contents == ""
27
+
28
+ File.open(filename, "wb") do |f|
29
+ f.write(contents)
30
+ end
31
+ end
32
+ end
33
+ end
34
+ end
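`save_answer` derives the file extension from the artifact's MIME type and would fail on `nil` if the lookup returned no match. The sketch below shows one possible way to build the same file name layout with a fallback extension. `FALLBACK_EXTENSIONS` and `artifact_filename` are hypothetical; the small table stands in for the mime-types lookup so the example has no gem dependency, and the `bin` fallback is an assumption.

```ruby
# Hypothetical lookup table; the gem itself asks the mime-types library instead.
FALLBACK_EXTENSIONS = {
  "image/png"  => "png",
  "image/jpeg" => "jpeg",
  "image/webp" => "webp",
}.freeze

# Builds "<prefix>-<request_id>-<answer_id>-<idx>.<ext>", the same layout as
# save_answer, but falls back to "bin" when the MIME type is not in the table.
def artifact_filename(prefix, request_id, answer_id, idx, mime)
  ext = FALLBACK_EXTENSIONS.fetch(mime, "bin")
  "#{prefix || "generation"}-#{request_id}-#{answer_id}-#{idx}.#{ext}"
end

puts artifact_filename(nil, "req1", "ans1", 0, "image/png")
puts artifact_filename("cat", "req1", "ans1", 1, "application/x-unknown")
```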
data/lib/stability_sdk/client.rb ADDED
@@ -0,0 +1,75 @@
1
+ require "grpc"
2
+ require "generation_services_pb"
3
+
4
+ module StabilitySDK
5
+ class Client
6
+ DEFAULT_API_HOST = "grpc.stability.ai:443"
7
+ DEFAULT_IMAGE_WIDTH = 512
8
+ DEFAULT_IMAGE_HEIGHT = 512
9
+ DEFAULT_SAMPLE_SIZE = 1
10
+ DEFAULT_STEPS = 50
11
+ DEFAULT_ENGINE_ID = "stable-diffusion-v1"
12
+ DEFAULT_CFG_SCALE = 7.0
13
+ DEFAULT_SAMPLER_ALGORITHM = Gooseai::DiffusionSampler::SAMPLER_K_LMS
14
+
15
+     def sampler_algorithms # sampler name => enum value; a method so image_param can call it
16
+       { "ddim" => Gooseai::DiffusionSampler::SAMPLER_DDIM,
17
+         "plms" => Gooseai::DiffusionSampler::SAMPLER_DDPM,
18
+         "k_euler" => Gooseai::DiffusionSampler::SAMPLER_K_EULER,
19
+         "k_euler_ancestral" => Gooseai::DiffusionSampler::SAMPLER_K_EULER_ANCESTRAL,
20
+         "k_heun" => Gooseai::DiffusionSampler::SAMPLER_K_HEUN,
21
+         "k_dpm_2" => Gooseai::DiffusionSampler::SAMPLER_K_DPM_2,
22
+         "k_dpm_2_ancestral" => Gooseai::DiffusionSampler::SAMPLER_K_DPM_2_ANCESTRAL,
23
+         "k_lms" => Gooseai::DiffusionSampler::SAMPLER_K_LMS }
24
+     end
25
+
26
+ def initialize(options={})
27
+ host = options[:api_host] || DEFAULT_API_HOST
28
+ channel_creds = options.has_key?(:ca_cert) ? GRPC::Core::ChannelCredentials.new(options[:ca_cert]) : GRPC::Core::ChannelCredentials.new
29
+ call_creds = GRPC::Core::CallCredentials.new(proc { { "authorization" => "Bearer #{options[:api_key]}" } })
30
+ creds = channel_creds.compose(call_creds)
31
+
32
+ @stub = Gooseai::GenerationService::Stub.new(host, creds)
33
+ end
34
+
35
+ def generate(prompt, options, &block)
36
+ image_param = image_param(options)
37
+ req = Gooseai::Request.new(
38
+ prompt: [Gooseai::Prompt.new(text: prompt)],
39
+ engine_id: options[:engine_id] || DEFAULT_ENGINE_ID,
40
+ image: image_param
41
+ )
42
+
43
+ @stub.generate(req).each do |answer|
44
+ block.call(answer)
45
+ end
46
+ end
47
+
48
+ def image_param(options={})
49
+ width = options.has_key?(:width) ? options[:width].to_i : DEFAULT_IMAGE_WIDTH
50
+       height = options.has_key?(:height) ? options[:height].to_i : DEFAULT_IMAGE_HEIGHT
51
+       samples = options.has_key?(:num_samples) ? options[:num_samples].to_i : DEFAULT_SAMPLE_SIZE
52
+ steps = options.has_key?(:steps) ? options[:steps].to_i : DEFAULT_STEPS
53
+       seed = options.has_key?(:seed) ? [options[:seed].to_i] : [rand(4294967295)]
54
+ transform = Gooseai::TransformType.new(
55
+ diffusion: options.has_key?(:sampler) ? sampler_algorithms[options[:sampler]] : DEFAULT_SAMPLER_ALGORITHM,
56
+ )
57
+ parameters = [Gooseai::StepParameter.new(
58
+ scaled_step: 0,
59
+ sampler: Gooseai::SamplerParameters.new(
60
+ cfg_scale: options.has_key?(:cfg_scale) ? options[:cfg_scale].to_f : DEFAULT_CFG_SCALE,
61
+ ),
62
+ )]
63
+
64
+ return Gooseai::ImageParameters.new(
65
+ width: width,
66
+ height: height,
67
+ samples: samples,
68
+ steps: steps,
69
+ seed: seed,
70
+ transform: transform,
71
+ parameters: parameters,
72
+ )
73
+ end
74
+ end
75
+ end
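`image_param` mixes default handling, string-to-number conversion, and sampler lookup. As a minimal sketch under stated assumptions, the same normalization can be expressed as one pure function that is easy to test. `SAMPLER_NAMES`, `IMAGE_DEFAULTS`, and `normalize_image_options` are hypothetical names; the symbols stand in for the `Gooseai::DiffusionSampler` constants so the example runs without grpc, and raising on an unknown sampler is a design choice of this sketch rather than the gem's behavior.

```ruby
# Stand-in values; the real client maps to Gooseai::DiffusionSampler constants.
SAMPLER_NAMES = {
  "ddim"              => :SAMPLER_DDIM,
  "plms"              => :SAMPLER_DDPM,
  "k_euler"           => :SAMPLER_K_EULER,
  "k_euler_ancestral" => :SAMPLER_K_EULER_ANCESTRAL,
  "k_heun"            => :SAMPLER_K_HEUN,
  "k_dpm_2"           => :SAMPLER_K_DPM_2,
  "k_dpm_2_ancestral" => :SAMPLER_K_DPM_2_ANCESTRAL,
  "k_lms"             => :SAMPLER_K_LMS,
}.freeze

# Defaults copied from the DEFAULT_* constants in the client.
IMAGE_DEFAULTS = { width: 512, height: 512, samples: 1, steps: 50, cfg_scale: 7.0, sampler: "k_lms" }.freeze

# Hypothetical helper: coerces CLI-style string options to numeric types
# and rejects unknown sampler names instead of passing nil along.
def normalize_image_options(options)
  sampler_name = (options[:sampler] || IMAGE_DEFAULTS[:sampler]).to_s
  sampler = SAMPLER_NAMES.fetch(sampler_name) { raise ArgumentError, "unknown sampler: #{sampler_name}" }
  {
    width: Integer(options.fetch(:width, IMAGE_DEFAULTS[:width])),
    height: Integer(options.fetch(:height, IMAGE_DEFAULTS[:height])),
    samples: Integer(options.fetch(:num_samples, IMAGE_DEFAULTS[:samples])),
    steps: Integer(options.fetch(:steps, IMAGE_DEFAULTS[:steps])),
    cfg_scale: Float(options.fetch(:cfg_scale, IMAGE_DEFAULTS[:cfg_scale])),
    # A random seed in the uint32 range is chosen when none is given.
    seed: [Integer(options.fetch(:seed) { rand(2**32) })],
    sampler: sampler,
  }
end

p normalize_image_options(width: "768", sampler: :k_euler, seed: "42")
```

`Integer()` and `Float()` raise `ArgumentError` on malformed input, whereas `to_i` and `to_f` silently return zero, which is why the sketch prefers them.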
data/lib/stability_sdk/version.rb ADDED
@@ -0,0 +1,3 @@
1
+ module StabilitySDK
2
+ VERSION = "0.1.0"
3
+ end
data/lib/stability_sdk.rb ADDED
@@ -0,0 +1,8 @@
1
+ require "stability_sdk/version"
2
+ require "stability_sdk/client"
3
+ require "stability_sdk/cli"
4
+
5
+ module StabilitySDK
6
+ class Error < StandardError; end
7
+ class InsufficientParameter < StandardError; end
8
+ end
data/proto/generation.proto ADDED
@@ -0,0 +1,179 @@
1
+ syntax = 'proto3';
2
+ package gooseai;
3
+ option go_package = "./;generation";
4
+
5
+ enum FinishReason {
6
+ NULL = 0;
7
+ LENGTH = 1;
8
+ STOP = 2;
9
+ ERROR = 3;
10
+ FILTER = 4;
11
+ }
12
+
13
+
14
+ enum ArtifactType {
15
+ ARTIFACT_NONE = 0;
16
+ ARTIFACT_IMAGE = 1;
17
+ ARTIFACT_VIDEO = 2;
18
+ ARTIFACT_TEXT = 3;
19
+ ARTIFACT_TOKENS = 4;
20
+ ARTIFACT_EMBEDDING = 5;
21
+ ARTIFACT_CLASSIFICATIONS = 6;
22
+ }
23
+
24
+ message Token {
25
+ optional string text = 1;
26
+ uint32 id = 2;
27
+ }
28
+
29
+ message Tokens {
30
+ repeated Token tokens = 1;
31
+ optional string tokenizer_id = 2;
32
+ }
33
+
34
+ message Artifact {
35
+ uint64 id = 1;
36
+ ArtifactType type = 2;
37
+ string mime = 3;
38
+ optional string magic = 4;
39
+ oneof data {
40
+ bytes binary = 5;
41
+ string text = 6;
42
+ Tokens tokens = 7;
43
+ ClassifierParameters classifier = 11;
44
+ }
45
+ uint32 index = 8;
46
+ FinishReason finish_reason = 9;
47
+ uint32 seed = 10;
48
+ }
49
+
50
+ message PromptParameters {
51
+ optional bool init = 1;
52
+ optional float weight = 2;
53
+ }
54
+
55
+ message Prompt {
56
+ optional PromptParameters parameters = 1;
57
+ oneof prompt {
58
+ string text = 2;
59
+ Tokens tokens = 3;
60
+ Artifact artifact = 4;
61
+ }
62
+ }
63
+
64
+ message AnswerMeta {
65
+ optional string gpu_id = 1;
66
+ optional string cpu_id = 2;
67
+ optional string node_id = 3;
68
+ optional string engine_id = 4;
69
+ }
70
+
71
+ message Answer {
72
+ string answer_id = 1;
73
+ string request_id = 2;
74
+ uint64 received = 3;
75
+ uint64 created = 4;
76
+ optional AnswerMeta meta = 6;
77
+ repeated Artifact artifacts = 7;
78
+ }
79
+
80
+ enum DiffusionSampler {
81
+ SAMPLER_DDIM = 0;
82
+ SAMPLER_DDPM = 1;
83
+ SAMPLER_K_EULER = 2;
84
+ SAMPLER_K_EULER_ANCESTRAL = 3;
85
+ SAMPLER_K_HEUN = 4;
86
+ SAMPLER_K_DPM_2 = 5;
87
+ SAMPLER_K_DPM_2_ANCESTRAL = 6;
88
+ SAMPLER_K_LMS = 7;
89
+ }
90
+
91
+ message SamplerParameters {
92
+ optional float eta = 1;
93
+ optional uint64 sampling_steps = 2;
94
+ optional uint64 latent_channels = 3;
95
+ optional uint64 downsampling_factor = 4;
96
+ optional float cfg_scale = 5;
97
+ }
98
+
99
+ message ConditionerParameters {
100
+ optional string vector_adjust_prior = 1;
101
+ }
102
+
103
+ enum Upscaler {
104
+ UPSCALER_RGB = 0;
105
+ UPSCALER_GFPGAN = 1;
106
+ UPSCALER_ESRGAN = 2;
107
+ }
108
+
109
+ message StepParameter {
110
+ float scaled_step = 1;
111
+ optional SamplerParameters sampler = 2;
112
+ }
113
+
114
+ message TransformType {
115
+ oneof type {
116
+ DiffusionSampler diffusion = 1;
117
+ Upscaler upscaler = 2;
118
+ }
119
+ }
120
+
121
+ message ImageParameters {
122
+ optional uint64 height = 1;
123
+ optional uint64 width = 2;
124
+ repeated uint32 seed = 3;
125
+ optional uint64 samples = 4;
126
+ optional uint64 steps = 5;
127
+ optional TransformType transform = 6;
128
+ repeated StepParameter parameters = 7;
129
+ }
130
+
131
+ enum Action {
132
+ ACTION_PASSTHROUGH = 0;
133
+ ACTION_REGENERATE_DUPLICATE = 1;
134
+ ACTION_REGENERATE = 2;
135
+ ACTION_OBFUSCATE_DUPLICATE = 3;
136
+ ACTION_OBFUSCATE = 4;
137
+ ACTION_DISCARD = 5;
138
+ }
139
+
140
+ enum ClassifierMode {
141
+ CLSFR_MODE_ZEROSHOT = 0;
142
+ CLSFR_MODE_MULTICLASS = 1;
143
+ /*CLSFR_MODE_ODDSRATIO = 2;*/
144
+ }
145
+
146
+ message ClassifierConcept {
147
+ string concept = 1;
148
+ optional float threshold = 2;
149
+ }
150
+
151
+ message ClassifierCategory {
152
+ string name = 1;
153
+ repeated ClassifierConcept concepts = 2;
154
+ optional float adjustment = 3;
155
+ optional Action action = 4;
156
+ optional ClassifierMode classifier_mode = 5;
157
+ }
158
+
159
+ message ClassifierParameters {
160
+ repeated ClassifierCategory categories = 1;
161
+ repeated ClassifierCategory exceeds = 2;
162
+ optional Action realized_action = 3;
163
+ }
164
+
165
+ message Request {
166
+ string engine_id = 1;
167
+ string request_id = 2;
168
+ ArtifactType requested_type = 3;
169
+ repeated Prompt prompt = 4;
170
+ oneof params {
171
+ ImageParameters image = 5;
172
+ }
173
+ optional ConditionerParameters conditioner = 6;
174
+ optional ClassifierParameters classifier = 7;
175
+ }
176
+
177
+ service GenerationService {
178
+ rpc Generate (Request) returns (stream Answer) {};
179
+ }
data/stability_sdk.gemspec ADDED
@@ -0,0 +1,30 @@
1
+ require_relative 'lib/stability_sdk/version'
2
+
3
+ Gem::Specification.new do |spec|
4
+ spec.name = "stability_sdk"
5
+ spec.version = StabilitySDK::VERSION
6
+ spec.authors = ["Kosei Moriyama"]
7
+ spec.email = ["cou929@gmail.com"]
8
+
9
+ spec.summary = "Ruby client for interacting with stability.ai APIs (e.g. stable diffusion inference)"
10
+ spec.description = "Ruby client of https://github.com/Stability-AI/stability-sdk"
11
+ spec.homepage = "https://github.com/cou929/stability-sdk-ruby"
12
+ spec.required_ruby_version = Gem::Requirement.new(">= 2.3.0")
13
+
14
+ spec.metadata["homepage_uri"] = spec.homepage
15
+ spec.metadata["source_code_uri"] = "https://github.com/cou929/stability-sdk-ruby"
16
+ spec.metadata["changelog_uri"] = "https://github.com/cou929/stability-sdk-ruby"
17
+
18
+ # Specify which files should be added to the gem when it is released.
19
+ # The `git ls-files -z` loads the files in the RubyGem that have been added into git.
20
+ spec.files = Dir.chdir(File.expand_path('..', __FILE__)) do
21
+ `git ls-files -z`.split("\x0").reject { |f| f.match(%r{^(test|spec|features)/}) }
22
+ end
23
+ spec.bindir = "exe"
24
+ spec.executables = spec.files.grep(%r{^exe/}) { |f| File.basename(f) }
25
+ spec.require_paths = ["lib"]
26
+
27
+ spec.add_dependency "grpc"
28
+ spec.add_dependency "mime-types"
29
+ spec.add_development_dependency "grpc-tools"
30
+ end
metadata ADDED
@@ -0,0 +1,106 @@
1
+ --- !ruby/object:Gem::Specification
2
+ name: stability_sdk
3
+ version: !ruby/object:Gem::Version
4
+ version: 0.1.0
5
+ platform: ruby
6
+ authors:
7
+ - Kosei Moriyama
8
+ autorequire:
9
+ bindir: exe
10
+ cert_chain: []
11
+ date: 2022-09-07 00:00:00.000000000 Z
12
+ dependencies:
13
+ - !ruby/object:Gem::Dependency
14
+ name: grpc
15
+ requirement: !ruby/object:Gem::Requirement
16
+ requirements:
17
+ - - ">="
18
+ - !ruby/object:Gem::Version
19
+ version: '0'
20
+ type: :runtime
21
+ prerelease: false
22
+ version_requirements: !ruby/object:Gem::Requirement
23
+ requirements:
24
+ - - ">="
25
+ - !ruby/object:Gem::Version
26
+ version: '0'
27
+ - !ruby/object:Gem::Dependency
28
+ name: mime-types
29
+ requirement: !ruby/object:Gem::Requirement
30
+ requirements:
31
+ - - ">="
32
+ - !ruby/object:Gem::Version
33
+ version: '0'
34
+ type: :runtime
35
+ prerelease: false
36
+ version_requirements: !ruby/object:Gem::Requirement
37
+ requirements:
38
+ - - ">="
39
+ - !ruby/object:Gem::Version
40
+ version: '0'
41
+ - !ruby/object:Gem::Dependency
42
+ name: grpc-tools
43
+ requirement: !ruby/object:Gem::Requirement
44
+ requirements:
45
+ - - ">="
46
+ - !ruby/object:Gem::Version
47
+ version: '0'
48
+ type: :development
49
+ prerelease: false
50
+ version_requirements: !ruby/object:Gem::Requirement
51
+ requirements:
52
+ - - ">="
53
+ - !ruby/object:Gem::Version
54
+ version: '0'
55
+ description: Ruby client of https://github.com/Stability-AI/stability-sdk
56
+ email:
57
+ - cou929@gmail.com
58
+ executables:
59
+ - stability-client
60
+ extensions: []
61
+ extra_rdoc_files: []
62
+ files:
63
+ - ".gitignore"
64
+ - ".travis.yml"
65
+ - Gemfile
66
+ - Gemfile.lock
67
+ - README.md
68
+ - Rakefile
69
+ - bin/console
70
+ - bin/setup
71
+ - exe/stability-client
72
+ - lib/generation_pb.rb
73
+ - lib/generation_services_pb.rb
74
+ - lib/stability_sdk.rb
75
+ - lib/stability_sdk/cli.rb
76
+ - lib/stability_sdk/client.rb
77
+ - lib/stability_sdk/version.rb
78
+ - proto/generation.proto
79
+ - stability_sdk.gemspec
80
+ homepage: https://github.com/cou929/stability-sdk-ruby
81
+ licenses: []
82
+ metadata:
83
+ homepage_uri: https://github.com/cou929/stability-sdk-ruby
84
+ source_code_uri: https://github.com/cou929/stability-sdk-ruby
85
+ changelog_uri: https://github.com/cou929/stability-sdk-ruby
86
+ post_install_message:
87
+ rdoc_options: []
88
+ require_paths:
89
+ - lib
90
+ required_ruby_version: !ruby/object:Gem::Requirement
91
+ requirements:
92
+ - - ">="
93
+ - !ruby/object:Gem::Version
94
+ version: 2.3.0
95
+ required_rubygems_version: !ruby/object:Gem::Requirement
96
+ requirements:
97
+ - - ">="
98
+ - !ruby/object:Gem::Version
99
+ version: '0'
100
+ requirements: []
101
+ rubygems_version: 3.1.4
102
+ signing_key:
103
+ specification_version: 4
104
+ summary: Ruby client for interacting with stability.ai APIs (e.g. stable diffusion
105
+ inference)
106
+ test_files: []