stability_sdk 0.1.0

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
checksums.yaml ADDED
@@ -0,0 +1,7 @@
1
+ ---
2
+ SHA256:
3
+ metadata.gz: dac029de10ccef05e11d7fee6a581f14fec50252e1cf7790ce2e849e5061489c
4
+ data.tar.gz: 66f8803ff804dae6971cdf3b8473225012255e87d67f00484583d65ecd8afb6a
5
+ SHA512:
6
+ metadata.gz: ad988347184eafe784ab654b0022f1742cf56e51c9e2d4d88b3507f97162c6aaa7a165ad8ee5a6d3c34c1a9427fc860a54679192235c154bfea3008765dd8f89
7
+ data.tar.gz: d4522e6b2644f0ef309dbfb862a0667ed80c77e93c30bc575764997aa511d73789b3f270641d44c83cc2fdc017b7a0d5562e7de79ed7bebf86ff894019ad2642
data/.gitignore ADDED
@@ -0,0 +1,8 @@
1
+ /.bundle/
2
+ /.yardoc
3
+ /_yardoc/
4
+ /coverage/
5
+ /doc/
6
+ /pkg/
7
+ /spec/reports/
8
+ /tmp/
data/.travis.yml ADDED
@@ -0,0 +1,6 @@
1
+ ---
2
+ language: ruby
3
+ cache: bundler
4
+ rvm:
5
+ - 2.7.2
6
+ before_install: gem install bundler -v 2.1.4
data/Gemfile ADDED
@@ -0,0 +1,7 @@
1
source 'https://rubygems.org'

# Runtime dependencies (grpc, mime-types) are declared in stability_sdk.gemspec
# and pulled in here.
gemspec

# Development-only tooling: rake drives the Rakefile tasks, minitest is the
# test framework the Rakefile's test task is presumably written against.
gem 'rake', '~> 12.0'
gem 'minitest', '~> 5.0'
data/Gemfile.lock ADDED
@@ -0,0 +1,34 @@
1
+ PATH
2
+ remote: .
3
+ specs:
4
+ stability_sdk (0.1.0)
5
+ grpc
6
+ mime-types
7
+
8
+ GEM
9
+ remote: https://rubygems.org/
10
+ specs:
11
+ google-protobuf (3.21.5)
12
+ googleapis-common-protos-types (1.4.0)
13
+ google-protobuf (~> 3.14)
14
+ grpc (1.48.0)
15
+ google-protobuf (~> 3.19)
16
+ googleapis-common-protos-types (~> 1.0)
17
+ grpc-tools (1.48.0)
18
+ mime-types (3.4.1)
19
+ mime-types-data (~> 3.2015)
20
+ mime-types-data (3.2022.0105)
21
+ minitest (5.16.3)
22
+ rake (12.3.3)
23
+
24
+ PLATFORMS
25
+ arm64-darwin-20
26
+
27
+ DEPENDENCIES
28
+ grpc-tools
29
+ minitest (~> 5.0)
30
+ rake (~> 12.0)
31
+ stability_sdk!
32
+
33
+ BUNDLED WITH
34
+ 2.2.2
data/README.md ADDED
@@ -0,0 +1,89 @@
1
+ # StabilitySDK - Ruby client for stability.ai APIs, such as Stable Diffusion
2
+
3
+ A Ruby client for [stability.ai](https://stability.ai/) APIs, e.g., [Stable Diffusion](https://stability.ai/blog/stable-diffusion-public-release), based on the official SDK at https://github.com/Stability-AI/stability-sdk.
4
+
5
+ ## Installation
6
+
7
+ Add this line to your application's Gemfile:
8
+
9
+ ```ruby
10
+ gem 'stability_sdk'
11
+ ```
12
+
13
+ And then execute:
14
+
15
+ $ bundle install
16
+
17
+ Or install it yourself as:
18
+
19
+ $ gem install stability_sdk
20
+
21
+ ## Usage
22
+
23
+ First, create a [DreamStudio](https://beta.dreamstudio.ai/home) account and get its API key.
24
+
25
+ - Go to [DreamStudio](https://beta.dreamstudio.ai/dream) and create an account if you don't have one yet
26
+ - Go to the [membership page](https://beta.dreamstudio.ai/membership)
27
+ - Your API key is listed on the `API Key` tab
28
+
29
+ ### Command line usage
30
+
31
+ ```sh
32
+ STABILITY_SDK_API_KEY=YOUR_API_KEY stability-client 'A night in winter, oil-on-canvas landscape painting, by Vincent van Gogh'
33
+ ```
34
+
35
+ This command saves an image like this:
36
+
37
+ ![3749380973_A_night_in_winter__oil_on_canvas_landscape_painting__by_Vincent_van_Gogh](https://user-images.githubusercontent.com/25668/188884116-0b03494b-0b34-49de-bbbc-89fbc2f6029d.png)
38
+
39
+
40
+ ```sh
41
+ Usage: stability-client [options] YOUR_PROMPT_TEXT
42
+
43
+ Options:
44
+ --api_key=VAL api key of DreamStudio account. You can also specify by a STABILITY_SDK_API_KEY environment variable
45
+ -H, --height=VAL height of image in pixel. default 512
46
+ -W, --width=VAL width of image in pixel. default 512
47
+ -C, --cfg_scale=VAL CFG scale factor. default 7.0
48
+ -A, --sampler=VAL ddim, plms, k_euler, k_euler_ancestral, k_heun, k_dpm_2, k_dpm_2_ancestral, k_lms. default k_lms
49
+ -s, --steps=VAL number of steps. default 50
50
+ -S, --seed=VAL random seed to use in integer
51
+ -p, --prefix=VAL output prefixes for artifacts. default `generation`
52
+ --no-store do not write out artifacts
53
+ -n, --num_samples=VAL number of samples to generate
54
+ -e, --engine=VAL engine to use for inference. default `stable-diffusion-v1`
55
+ -v, --verbose
56
+ ```
57
+
58
+ ### SDK usage
59
+
60
+ This sample code saves a generated image as `result.png`.
61
+
62
+ ```ruby
63
+ require "stability_sdk"
64
+
65
+ client = StabilitySDK::Client.new(api_key: "YOUR_API_KEY")
66
+
67
+ prompt = "your prompt here"
68
+ options = {}
69
+
70
+ client.generate(prompt, options) do |answer|
71
+ answer.artifacts.each do |artifact|
72
+ if artifact.type == :ARTIFACT_IMAGE
73
+ File.open("result.png", "wb") do |f|
74
+ f.write(artifact.binary)
75
+ end
76
+ end
77
+ end
78
+ end
79
+ ```
80
+
81
+ ## Development
82
+
83
+ After checking out the repo, run `bin/setup` to install dependencies. Then, run `rake test` to run the tests. You can also run `bin/console` for an interactive prompt that will allow you to experiment.
84
+
85
+ To install this gem onto your local machine, run `bundle exec rake install`. To release a new version, update the version number in `version.rb`, and then run `bundle exec rake release`, which will create a git tag for the version, push git commits and tags, and push the `.gem` file to [rubygems.org](https://rubygems.org).
86
+
87
+ ## Contributing
88
+
89
+ Bug reports and pull requests are welcome on GitHub at https://github.com/cou929/stability-sdk-ruby.
data/Rakefile ADDED
@@ -0,0 +1,10 @@
1
+ require "bundler/gem_tasks"
2
+ require "rake/testtask"
3
+
4
# `rake test`: run every test/**/*_test.rb with test/ and lib/ on the load path.
Rake::TestTask.new(:test) do |test_task|
  test_task.libs.push("test", "lib")
  test_task.test_files = FileList["test/**/*_test.rb"]
end

# A bare `rake` runs the test suite.
task default: :test
data/bin/console ADDED
@@ -0,0 +1,14 @@
1
#!/usr/bin/env ruby

# Interactive console with the gem (and its bundled dependencies) preloaded.
require "bundler/setup"
require "stability_sdk"

# Fixtures or setup code for experimenting can go here. To use pry instead of
# IRB, add it to the Gemfile and start it in place of the two lines below.
require "irb"
IRB.start(__FILE__)
data/bin/setup ADDED
@@ -0,0 +1,8 @@
1
#!/usr/bin/env bash
# Installs the development dependencies; aborts on the first failing command.
set -o errexit -o nounset -o pipefail
IFS=$'\n\t'
# Echo each line and each expanded command while running.
set -o verbose -o xtrace

bundle install

# Add any further automated setup steps below.
@@ -0,0 +1,43 @@
1
#!/usr/bin/env ruby

# stability-client: generate artifacts for a text prompt with
# StabilitySDK::Client and save each streamed answer via
# StabilitySDK::CLI.save_answer.

require "optparse"
require "mime/types"
require "logger"
require "stability_sdk"

logger = Logger.new(STDOUT)
logger.level = Logger::WARN

opt = OptionParser.new
# OptionParser reports the top-level `Version` constant for `--version`.
Version = StabilitySDK::VERSION

options = {}

opt.banner = "Usage: stability-client [options] YOUR_PROMPT_TEXT"
opt.separator ""
opt.separator "Options:"
opt.on("--api_key=VAL", "api key of DreamStudio account. You can also specify by a STABILITY_SDK_API_KEY environment variable") { |v| options[:api_key] = v }
# Numeric options are converted by OptionParser itself, so malformed values
# (e.g. `-H abc`) are rejected with a message instead of reaching the client
# as raw strings. DecimalInteger (not Integer) keeps "010" meaning 10, not octal.
opt.on("-H", "--height=VAL", OptionParser::DecimalInteger, "height of image in pixel. default 512") { |v| options[:height] = v }
opt.on("-W", "--width=VAL", OptionParser::DecimalInteger, "width of image in pixel. default 512") { |v| options[:width] = v }
opt.on("-C", "--cfg_scale=VAL", Float, "CFG scale factor. default 7.0") { |v| options[:cfg_scale] = v }
opt.on("-A", "--sampler=VAL", "ddim, plms, k_euler, k_euler_ancestral, k_heun, k_dpm_2, k_dpm_2_ancestral, k_lms. default k_lms") { |v| options[:sampler] = v }
opt.on("-s", "--steps=VAL", OptionParser::DecimalInteger, "number of steps. default 50") { |v| options[:steps] = v }
opt.on("-S", "--seed=VAL", OptionParser::DecimalInteger, "random seed to use in integer") { |v| options[:seed] = v }
opt.on("-p", "--prefix=VAL", "output prefixes for artifacts. default `generation`") { |v| options[:prefix] = v }
opt.on("--no-store", "do not write out artifacts") { |v| options[:no_store] = v }
opt.on("-n", "--num_samples=VAL", OptionParser::DecimalInteger, "number of samples to generate") { |v| options[:num_samples] = v }
opt.on("-e", "--engine=VAL", "engine to use for inference. default `stable-diffusion-v1`") { |v| options[:engine_id] = v }
opt.on("-v", "--verbose") { logger.level = Logger::DEBUG }

begin
  opt.parse!(ARGV)
rescue OptionParser::ParseError => e
  # Show the problem plus usage instead of a stack trace.
  abort "#{e.message}\n\n#{opt.help}"
end

prompt = ARGV.join(" ")
raise StabilitySDK::InsufficientParameter, "prompt is required" if prompt.strip.empty?

# An explicit --api_key wins; the environment variable is only a fallback.
options[:api_key] = ENV["STABILITY_SDK_API_KEY"] if options[:api_key].to_s.empty?
raise StabilitySDK::InsufficientParameter, "api key is required" if options[:api_key].to_s.empty?

client = StabilitySDK::Client.new(api_key: options[:api_key])

client.generate(prompt, options) do |answer|
  StabilitySDK::CLI.save_answer(answer, options, logger)
end
@@ -0,0 +1,182 @@
1
+ # Generated by the protocol buffer compiler. DO NOT EDIT!
2
+ # source: generation.proto
3
+
4
+ require 'google/protobuf'
5
+
6
# Registers every gooseai message and enum declared in generation.proto with
# the global protobuf descriptor pool. This is protoc output (see the file
# header): regenerate it from proto/generation.proto rather than editing it by
# hand -- presumably via grpc-tools' grpc_tools_ruby_protoc (TODO confirm).
#
# `proto3_optional` below corresponds to fields declared `optional` in
# generation.proto; plain `optional` is protoc's spelling for an ordinary
# proto3 singular field.
Google::Protobuf::DescriptorPool.generated_pool.build do
  add_file("generation.proto", :syntax => :proto3) do
    add_message "gooseai.Token" do
      proto3_optional :text, :string, 1
      optional :id, :uint32, 2
    end
    add_message "gooseai.Tokens" do
      repeated :tokens, :message, 1, "gooseai.Token"
      proto3_optional :tokenizer_id, :string, 2
    end
    add_message "gooseai.Artifact" do
      optional :id, :uint64, 1
      optional :type, :enum, 2, "gooseai.ArtifactType"
      optional :mime, :string, 3
      proto3_optional :magic, :string, 4
      optional :index, :uint32, 8
      optional :finish_reason, :enum, 9, "gooseai.FinishReason"
      optional :seed, :uint32, 10
      oneof :data do
        optional :binary, :bytes, 5
        optional :text, :string, 6
        optional :tokens, :message, 7, "gooseai.Tokens"
        optional :classifier, :message, 11, "gooseai.ClassifierParameters"
      end
    end
    add_message "gooseai.PromptParameters" do
      proto3_optional :init, :bool, 1
      proto3_optional :weight, :float, 2
    end
    add_message "gooseai.Prompt" do
      proto3_optional :parameters, :message, 1, "gooseai.PromptParameters"
      oneof :prompt do
        optional :text, :string, 2
        optional :tokens, :message, 3, "gooseai.Tokens"
        optional :artifact, :message, 4, "gooseai.Artifact"
      end
    end
    add_message "gooseai.AnswerMeta" do
      proto3_optional :gpu_id, :string, 1
      proto3_optional :cpu_id, :string, 2
      proto3_optional :node_id, :string, 3
      proto3_optional :engine_id, :string, 4
    end
    add_message "gooseai.Answer" do
      optional :answer_id, :string, 1
      optional :request_id, :string, 2
      optional :received, :uint64, 3
      optional :created, :uint64, 4
      proto3_optional :meta, :message, 6, "gooseai.AnswerMeta"
      repeated :artifacts, :message, 7, "gooseai.Artifact"
    end
    add_message "gooseai.SamplerParameters" do
      proto3_optional :eta, :float, 1
      proto3_optional :sampling_steps, :uint64, 2
      proto3_optional :latent_channels, :uint64, 3
      proto3_optional :downsampling_factor, :uint64, 4
      proto3_optional :cfg_scale, :float, 5
    end
    add_message "gooseai.ConditionerParameters" do
      proto3_optional :vector_adjust_prior, :string, 1
    end
    add_message "gooseai.StepParameter" do
      optional :scaled_step, :float, 1
      proto3_optional :sampler, :message, 2, "gooseai.SamplerParameters"
    end
    add_message "gooseai.TransformType" do
      oneof :type do
        optional :diffusion, :enum, 1, "gooseai.DiffusionSampler"
        optional :upscaler, :enum, 2, "gooseai.Upscaler"
      end
    end
    add_message "gooseai.ImageParameters" do
      proto3_optional :height, :uint64, 1
      proto3_optional :width, :uint64, 2
      repeated :seed, :uint32, 3
      proto3_optional :samples, :uint64, 4
      proto3_optional :steps, :uint64, 5
      proto3_optional :transform, :message, 6, "gooseai.TransformType"
      repeated :parameters, :message, 7, "gooseai.StepParameter"
    end
    add_message "gooseai.ClassifierConcept" do
      optional :concept, :string, 1
      proto3_optional :threshold, :float, 2
    end
    add_message "gooseai.ClassifierCategory" do
      optional :name, :string, 1
      repeated :concepts, :message, 2, "gooseai.ClassifierConcept"
      proto3_optional :adjustment, :float, 3
      proto3_optional :action, :enum, 4, "gooseai.Action"
      proto3_optional :classifier_mode, :enum, 5, "gooseai.ClassifierMode"
    end
    add_message "gooseai.ClassifierParameters" do
      repeated :categories, :message, 1, "gooseai.ClassifierCategory"
      repeated :exceeds, :message, 2, "gooseai.ClassifierCategory"
      proto3_optional :realized_action, :enum, 3, "gooseai.Action"
    end
    add_message "gooseai.Request" do
      optional :engine_id, :string, 1
      optional :request_id, :string, 2
      optional :requested_type, :enum, 3, "gooseai.ArtifactType"
      repeated :prompt, :message, 4, "gooseai.Prompt"
      proto3_optional :conditioner, :message, 6, "gooseai.ConditionerParameters"
      proto3_optional :classifier, :message, 7, "gooseai.ClassifierParameters"
      oneof :params do
        optional :image, :message, 5, "gooseai.ImageParameters"
      end
    end
    add_enum "gooseai.FinishReason" do
      value :NULL, 0
      value :LENGTH, 1
      value :STOP, 2
      value :ERROR, 3
      value :FILTER, 4
    end
    add_enum "gooseai.ArtifactType" do
      value :ARTIFACT_NONE, 0
      value :ARTIFACT_IMAGE, 1
      value :ARTIFACT_VIDEO, 2
      value :ARTIFACT_TEXT, 3
      value :ARTIFACT_TOKENS, 4
      value :ARTIFACT_EMBEDDING, 5
      value :ARTIFACT_CLASSIFICATIONS, 6
    end
    add_enum "gooseai.DiffusionSampler" do
      value :SAMPLER_DDIM, 0
      value :SAMPLER_DDPM, 1
      value :SAMPLER_K_EULER, 2
      value :SAMPLER_K_EULER_ANCESTRAL, 3
      value :SAMPLER_K_HEUN, 4
      value :SAMPLER_K_DPM_2, 5
      value :SAMPLER_K_DPM_2_ANCESTRAL, 6
      value :SAMPLER_K_LMS, 7
    end
    add_enum "gooseai.Upscaler" do
      value :UPSCALER_RGB, 0
      value :UPSCALER_GFPGAN, 1
      value :UPSCALER_ESRGAN, 2
    end
    add_enum "gooseai.Action" do
      value :ACTION_PASSTHROUGH, 0
      value :ACTION_REGENERATE_DUPLICATE, 1
      value :ACTION_REGENERATE, 2
      value :ACTION_OBFUSCATE_DUPLICATE, 3
      value :ACTION_OBFUSCATE, 4
      value :ACTION_DISCARD, 5
    end
    add_enum "gooseai.ClassifierMode" do
      value :CLSFR_MODE_ZEROSHOT, 0
      value :CLSFR_MODE_MULTICLASS, 1
    end
  end
end
158
+
159
# Ruby constants for the message classes (`msgclass`) and enum modules
# (`enummodule`) registered in the descriptor pool by the build block earlier
# in this file, looked up by fully-qualified proto name. Generated code.
module Gooseai
  Token = ::Google::Protobuf::DescriptorPool.generated_pool.lookup("gooseai.Token").msgclass
  Tokens = ::Google::Protobuf::DescriptorPool.generated_pool.lookup("gooseai.Tokens").msgclass
  Artifact = ::Google::Protobuf::DescriptorPool.generated_pool.lookup("gooseai.Artifact").msgclass
  PromptParameters = ::Google::Protobuf::DescriptorPool.generated_pool.lookup("gooseai.PromptParameters").msgclass
  Prompt = ::Google::Protobuf::DescriptorPool.generated_pool.lookup("gooseai.Prompt").msgclass
  AnswerMeta = ::Google::Protobuf::DescriptorPool.generated_pool.lookup("gooseai.AnswerMeta").msgclass
  Answer = ::Google::Protobuf::DescriptorPool.generated_pool.lookup("gooseai.Answer").msgclass
  SamplerParameters = ::Google::Protobuf::DescriptorPool.generated_pool.lookup("gooseai.SamplerParameters").msgclass
  ConditionerParameters = ::Google::Protobuf::DescriptorPool.generated_pool.lookup("gooseai.ConditionerParameters").msgclass
  StepParameter = ::Google::Protobuf::DescriptorPool.generated_pool.lookup("gooseai.StepParameter").msgclass
  TransformType = ::Google::Protobuf::DescriptorPool.generated_pool.lookup("gooseai.TransformType").msgclass
  ImageParameters = ::Google::Protobuf::DescriptorPool.generated_pool.lookup("gooseai.ImageParameters").msgclass
  ClassifierConcept = ::Google::Protobuf::DescriptorPool.generated_pool.lookup("gooseai.ClassifierConcept").msgclass
  ClassifierCategory = ::Google::Protobuf::DescriptorPool.generated_pool.lookup("gooseai.ClassifierCategory").msgclass
  ClassifierParameters = ::Google::Protobuf::DescriptorPool.generated_pool.lookup("gooseai.ClassifierParameters").msgclass
  Request = ::Google::Protobuf::DescriptorPool.generated_pool.lookup("gooseai.Request").msgclass
  # Enum modules: each exposes the enum values as Integer constants,
  # e.g. Gooseai::DiffusionSampler::SAMPLER_K_LMS.
  FinishReason = ::Google::Protobuf::DescriptorPool.generated_pool.lookup("gooseai.FinishReason").enummodule
  ArtifactType = ::Google::Protobuf::DescriptorPool.generated_pool.lookup("gooseai.ArtifactType").enummodule
  DiffusionSampler = ::Google::Protobuf::DescriptorPool.generated_pool.lookup("gooseai.DiffusionSampler").enummodule
  Upscaler = ::Google::Protobuf::DescriptorPool.generated_pool.lookup("gooseai.Upscaler").enummodule
  Action = ::Google::Protobuf::DescriptorPool.generated_pool.lookup("gooseai.Action").enummodule
  ClassifierMode = ::Google::Protobuf::DescriptorPool.generated_pool.lookup("gooseai.ClassifierMode").enummodule
end
@@ -0,0 +1,22 @@
1
+ # Generated by the protocol buffer compiler. DO NOT EDIT!
2
+ # Source: generation.proto for package 'gooseai'
3
+
4
+ require 'grpc'
5
+ require 'generation_pb'
6
+
7
module Gooseai
  module GenerationService
    # gRPC service definition for gooseai.GenerationService (generation.proto).
    # Generated code -- regenerate from the proto rather than editing.
    class Service

      include ::GRPC::GenericService

      # Messages are (de)serialized with the protobuf classes' encode/decode.
      self.marshal_class_method = :encode
      self.unmarshal_class_method = :decode
      self.service_name = 'gooseai.GenerationService'

      # Server-streaming RPC: one Request in, a stream of Answer messages out.
      rpc :Generate, ::Gooseai::Request, stream(::Gooseai::Answer)
    end

    # Client stub class; StabilitySDK::Client instantiates it and calls #generate.
    Stub = Service.rpc_stub_class
  end
end
@@ -0,0 +1,34 @@
1
module StabilitySDK
  # Helpers used by the `stability-client` executable.
  class CLI
    # File extension used when mime-types does not recognise an image's MIME
    # type (or the artifact carries none).
    FALLBACK_EXTENSION = "bin"

    # Writes the supported artifacts of one streamed answer to the current
    # directory as "<prefix>-<request_id>-<answer_id>-<index>.<ext>".
    #
    # ARTIFACT_IMAGE artifacts are written as their raw binary, using an
    # extension derived from the MIME type. ARTIFACT_CLASSIFICATIONS artifacts
    # are written as JSON (".pb.json"). Other artifact types are logged as
    # unsupported and skipped, as are artifacts with empty contents.
    #
    # @param answer [Gooseai::Answer] one answer from Client#generate
    # @param options [Hash] reads :no_store (skip writing) and :prefix
    #   (default "generation")
    # @param logger [Logger]
    # @return [nil] when :no_store is set; otherwise the artifacts collection
    def self.save_answer(answer, options, logger)
      logger.debug "received answer: #{answer}"
      return if options[:no_store]

      prefix = options[:prefix] || "generation"

      answer.artifacts.each_with_index do |artifact, idx|
        filename_base = "#{prefix}-#{answer.request_id}-#{answer.answer_id}-#{idx}"

        case artifact.type
        when :ARTIFACT_IMAGE
          ext = image_extension(artifact.mime)
          contents = artifact.binary
        when :ARTIFACT_CLASSIFICATIONS
          ext = "pb.json"
          contents = artifact.classifier.to_json
        else
          logger.warn "not implemented for ArtifactType #{artifact.type}"
          next
        end

        next if contents.nil? || contents.empty?

        File.binwrite("#{filename_base}.#{ext}", contents)
      end
    end

    # Preferred file extension for +mime+. Falls back to FALLBACK_EXTENSION
    # instead of crashing when MIME::Types has no match (e.g. an empty or
    # unregistered type string).
    def self.image_extension(mime)
      type = MIME::Types[mime.to_s].first
      type&.preferred_extension || FALLBACK_EXTENSION
    end
    private_class_method :image_extension
  end
end
@@ -0,0 +1,75 @@
1
+ require "grpc"
2
+ require "generation_services_pb"
3
+
4
+ module StabilitySDK
5
+ class Client
6
+ DEFAULT_API_HOST = "grpc.stability.ai:443"
7
+ DEFAULT_IMAGE_WIDTH = 512
8
+ DEFAULT_IMAGE_HEIGHT = 512
9
+ DEFAULT_SAMPLE_SIZE = 1
10
+ DEFAULT_STEPS = 50
11
+ DEFAULT_ENGINE_ID = "stable-diffusion-v1"
12
+ DEFAULT_CFG_SCALE = 7.0
13
+ DEFAULT_SAMPLER_ALGORITHM = Gooseai::DiffusionSampler::SAMPLER_K_LMS
14
+
15
+ sampler_algorithms = {
16
+ "ddim": Gooseai::DiffusionSampler::SAMPLER_DDIM,
17
+ "plms": Gooseai::DiffusionSampler::SAMPLER_DDPM,
18
+ "k_euler": Gooseai::DiffusionSampler::SAMPLER_K_EULER,
19
+ "k_euler_ancestral": Gooseai::DiffusionSampler::SAMPLER_K_EULER_ANCESTRAL,
20
+ "k_heun": Gooseai::DiffusionSampler::SAMPLER_K_HEUN,
21
+ "k_dpm_2": Gooseai::DiffusionSampler::SAMPLER_K_DPM_2,
22
+ "k_dpm_2_ancestral": Gooseai::DiffusionSampler::SAMPLER_K_DPM_2_ANCESTRAL,
23
+ "k_lms": Gooseai::DiffusionSampler::SAMPLER_K_LMS,
24
+ }
25
+
26
+ def initialize(options={})
27
+ host = options[:api_host] || DEFAULT_API_HOST
28
+ channel_creds = options.has_key?(:ca_cert) ? GRPC::Core::ChannelCredentials.new(options[:ca_cert]) : GRPC::Core::ChannelCredentials.new
29
+ call_creds = GRPC::Core::CallCredentials.new(proc { { "authorization" => "Bearer #{options[:api_key]}" } })
30
+ creds = channel_creds.compose(call_creds)
31
+
32
+ @stub = Gooseai::GenerationService::Stub.new(host, creds)
33
+ end
34
+
35
+ def generate(prompt, options, &block)
36
+ image_param = image_param(options)
37
+ req = Gooseai::Request.new(
38
+ prompt: [Gooseai::Prompt.new(text: prompt)],
39
+ engine_id: options[:engine_id] || DEFAULT_ENGINE_ID,
40
+ image: image_param
41
+ )
42
+
43
+ @stub.generate(req).each do |answer|
44
+ block.call(answer)
45
+ end
46
+ end
47
+
48
+ def image_param(options={})
49
+ width = options.has_key?(:width) ? options[:width].to_i : DEFAULT_IMAGE_WIDTH
50
+ height = options.has_key?(:height) ? options[:height] : DEFAULT_IMAGE_HEIGHT
51
+ samples = options.has_key?(:num_samples) ? [:num_samples].to_i : DEFAULT_SAMPLE_SIZE
52
+ steps = options.has_key?(:steps) ? options[:steps].to_i : DEFAULT_STEPS
53
+ seed = options.has_key?(:seed) ? [options[:seed]] : [rand(4294967295)]
54
+ transform = Gooseai::TransformType.new(
55
+ diffusion: options.has_key?(:sampler) ? sampler_algorithms[options[:sampler]] : DEFAULT_SAMPLER_ALGORITHM,
56
+ )
57
+ parameters = [Gooseai::StepParameter.new(
58
+ scaled_step: 0,
59
+ sampler: Gooseai::SamplerParameters.new(
60
+ cfg_scale: options.has_key?(:cfg_scale) ? options[:cfg_scale].to_f : DEFAULT_CFG_SCALE,
61
+ ),
62
+ )]
63
+
64
+ return Gooseai::ImageParameters.new(
65
+ width: width,
66
+ height: height,
67
+ samples: samples,
68
+ steps: steps,
69
+ seed: seed,
70
+ transform: transform,
71
+ parameters: parameters,
72
+ )
73
+ end
74
+ end
75
+ end
@@ -0,0 +1,3 @@
1
module StabilitySDK
  # Gem version. Read by stability_sdk.gemspec and reported by
  # `stability-client --version`. Frozen so the shared constant cannot be
  # mutated in place.
  VERSION = "0.1.0".freeze
end
@@ -0,0 +1,8 @@
1
+ require "stability_sdk/version"
2
+ require "stability_sdk/client"
3
+ require "stability_sdk/cli"
4
+
5
module StabilitySDK
  # Base class for errors raised by this gem; rescue it to catch them all.
  class Error < StandardError; end

  # Raised when a required parameter is missing, e.g. the prompt or API key
  # in the stability-client executable. Inherits from Error (and therefore
  # still from StandardError) so `rescue StabilitySDK::Error` covers it.
  class InsufficientParameter < Error; end
end
@@ -0,0 +1,179 @@
1
// Wire schema for gooseai.GenerationService. lib/generation_pb.rb and
// lib/generation_services_pb.rb are generated from this file (see their
// "source: generation.proto" headers); regenerate them after any change here.
syntax = 'proto3';
package gooseai;
option go_package = "./;generation";

enum FinishReason {
  NULL = 0;
  LENGTH = 1;
  STOP = 2;
  ERROR = 3;
  FILTER = 4;
}


enum ArtifactType {
  ARTIFACT_NONE = 0;
  ARTIFACT_IMAGE = 1;
  ARTIFACT_VIDEO = 2;
  ARTIFACT_TEXT = 3;
  ARTIFACT_TOKENS = 4;
  ARTIFACT_EMBEDDING = 5;
  ARTIFACT_CLASSIFICATIONS = 6;
}

message Token {
  optional string text = 1;
  uint32 id = 2;
}

message Tokens {
  repeated Token tokens = 1;
  optional string tokenizer_id = 2;
}

// One generated output. The Ruby CLI writes `binary` for ARTIFACT_IMAGE and
// `classifier` (as JSON) for ARTIFACT_CLASSIFICATIONS.
message Artifact {
  uint64 id = 1;
  ArtifactType type = 2;
  string mime = 3;
  optional string magic = 4;
  oneof data {
    bytes binary = 5;
    string text = 6;
    Tokens tokens = 7;
    ClassifierParameters classifier = 11;
  }
  uint32 index = 8;
  FinishReason finish_reason = 9;
  uint32 seed = 10;
}

message PromptParameters {
  optional bool init = 1;
  optional float weight = 2;
}

message Prompt {
  optional PromptParameters parameters = 1;
  oneof prompt {
    string text = 2;
    Tokens tokens = 3;
    Artifact artifact = 4;
  }
}

message AnswerMeta {
  optional string gpu_id = 1;
  optional string cpu_id = 2;
  optional string node_id = 3;
  optional string engine_id = 4;
}

message Answer {
  string answer_id = 1;
  string request_id = 2;
  uint64 received = 3;
  uint64 created = 4;
  optional AnswerMeta meta = 6;
  repeated Artifact artifacts = 7;
}

enum DiffusionSampler {
  SAMPLER_DDIM = 0;
  SAMPLER_DDPM = 1;
  SAMPLER_K_EULER = 2;
  SAMPLER_K_EULER_ANCESTRAL = 3;
  SAMPLER_K_HEUN = 4;
  SAMPLER_K_DPM_2 = 5;
  SAMPLER_K_DPM_2_ANCESTRAL = 6;
  SAMPLER_K_LMS = 7;
}

message SamplerParameters {
  optional float eta = 1;
  optional uint64 sampling_steps = 2;
  optional uint64 latent_channels = 3;
  optional uint64 downsampling_factor = 4;
  optional float cfg_scale = 5;
}

message ConditionerParameters {
  optional string vector_adjust_prior = 1;
}

enum Upscaler {
  UPSCALER_RGB = 0;
  UPSCALER_GFPGAN = 1;
  UPSCALER_ESRGAN = 2;
}

message StepParameter {
  float scaled_step = 1;
  optional SamplerParameters sampler = 2;
}

message TransformType {
  oneof type {
    DiffusionSampler diffusion = 1;
    Upscaler upscaler = 2;
  }
}

// `seed` is repeated; the Ruby client currently sends a single value.
message ImageParameters {
  optional uint64 height = 1;
  optional uint64 width = 2;
  repeated uint32 seed = 3;
  optional uint64 samples = 4;
  optional uint64 steps = 5;
  optional TransformType transform = 6;
  repeated StepParameter parameters = 7;
}

enum Action {
  ACTION_PASSTHROUGH = 0;
  ACTION_REGENERATE_DUPLICATE = 1;
  ACTION_REGENERATE = 2;
  ACTION_OBFUSCATE_DUPLICATE = 3;
  ACTION_OBFUSCATE = 4;
  ACTION_DISCARD = 5;
}

enum ClassifierMode {
  CLSFR_MODE_ZEROSHOT = 0;
  CLSFR_MODE_MULTICLASS = 1;
  /*CLSFR_MODE_ODDSRATIO = 2;*/
}

message ClassifierConcept {
  string concept = 1;
  optional float threshold = 2;
}

message ClassifierCategory {
  string name = 1;
  repeated ClassifierConcept concepts = 2;
  optional float adjustment = 3;
  optional Action action = 4;
  optional ClassifierMode classifier_mode = 5;
}

message ClassifierParameters {
  repeated ClassifierCategory categories = 1;
  repeated ClassifierCategory exceeds = 2;
  optional Action realized_action = 3;
}

// Only ImageParameters is currently defined in the `params` oneof.
message Request {
  string engine_id = 1;
  string request_id = 2;
  ArtifactType requested_type = 3;
  repeated Prompt prompt = 4;
  oneof params {
    ImageParameters image = 5;
  }
  optional ConditionerParameters conditioner = 6;
  optional ClassifierParameters classifier = 7;
}

// Server-streaming: one Request yields a stream of Answer messages.
service GenerationService {
  rpc Generate (Request) returns (stream Answer) {};
}
@@ -0,0 +1,30 @@
1
+ require_relative 'lib/stability_sdk/version'
2
+
3
Gem::Specification.new do |s|
  s.name    = "stability_sdk"
  s.version = StabilitySDK::VERSION
  s.authors = ["Kosei Moriyama"]
  s.email   = ["cou929@gmail.com"]

  s.summary     = "Ruby client for interacting with stability.ai APIs (e.g. stable diffusion inference)"
  s.description = "Ruby client of https://github.com/Stability-AI/stability-sdk"
  s.homepage    = "https://github.com/cou929/stability-sdk-ruby"
  s.required_ruby_version = Gem::Requirement.new(">= 2.3.0")

  repo_url = "https://github.com/cou929/stability-sdk-ruby"
  s.metadata.merge!(
    "homepage_uri" => s.homepage,
    "source_code_uri" => repo_url,
    "changelog_uri" => repo_url
  )

  # Package every git-tracked file except those under test/, spec/ or features/.
  s.files = Dir.chdir(File.expand_path("..", __FILE__)) do
    `git ls-files -z`.split("\x0").grep_v(%r{^(test|spec|features)/})
  end
  s.bindir        = "exe"
  s.executables   = s.files.grep(%r{^exe/}).map { |path| File.basename(path) }
  s.require_paths = ["lib"]

  # Runtime dependencies.
  s.add_dependency "grpc"
  s.add_dependency "mime-types"
  # Development only -- presumably used to regenerate lib/generation*_pb.rb
  # from proto/generation.proto.
  s.add_development_dependency "grpc-tools"
end
+ end
metadata ADDED
@@ -0,0 +1,106 @@
1
+ --- !ruby/object:Gem::Specification
2
+ name: stability_sdk
3
+ version: !ruby/object:Gem::Version
4
+ version: 0.1.0
5
+ platform: ruby
6
+ authors:
7
+ - Kosei Moriyama
8
+ autorequire:
9
+ bindir: exe
10
+ cert_chain: []
11
+ date: 2022-09-07 00:00:00.000000000 Z
12
+ dependencies:
13
+ - !ruby/object:Gem::Dependency
14
+ name: grpc
15
+ requirement: !ruby/object:Gem::Requirement
16
+ requirements:
17
+ - - ">="
18
+ - !ruby/object:Gem::Version
19
+ version: '0'
20
+ type: :runtime
21
+ prerelease: false
22
+ version_requirements: !ruby/object:Gem::Requirement
23
+ requirements:
24
+ - - ">="
25
+ - !ruby/object:Gem::Version
26
+ version: '0'
27
+ - !ruby/object:Gem::Dependency
28
+ name: mime-types
29
+ requirement: !ruby/object:Gem::Requirement
30
+ requirements:
31
+ - - ">="
32
+ - !ruby/object:Gem::Version
33
+ version: '0'
34
+ type: :runtime
35
+ prerelease: false
36
+ version_requirements: !ruby/object:Gem::Requirement
37
+ requirements:
38
+ - - ">="
39
+ - !ruby/object:Gem::Version
40
+ version: '0'
41
+ - !ruby/object:Gem::Dependency
42
+ name: grpc-tools
43
+ requirement: !ruby/object:Gem::Requirement
44
+ requirements:
45
+ - - ">="
46
+ - !ruby/object:Gem::Version
47
+ version: '0'
48
+ type: :development
49
+ prerelease: false
50
+ version_requirements: !ruby/object:Gem::Requirement
51
+ requirements:
52
+ - - ">="
53
+ - !ruby/object:Gem::Version
54
+ version: '0'
55
+ description: Ruby client of https://github.com/Stability-AI/stability-sdk
56
+ email:
57
+ - cou929@gmail.com
58
+ executables:
59
+ - stability-client
60
+ extensions: []
61
+ extra_rdoc_files: []
62
+ files:
63
+ - ".gitignore"
64
+ - ".travis.yml"
65
+ - Gemfile
66
+ - Gemfile.lock
67
+ - README.md
68
+ - Rakefile
69
+ - bin/console
70
+ - bin/setup
71
+ - exe/stability-client
72
+ - lib/generation_pb.rb
73
+ - lib/generation_services_pb.rb
74
+ - lib/stability_sdk.rb
75
+ - lib/stability_sdk/cli.rb
76
+ - lib/stability_sdk/client.rb
77
+ - lib/stability_sdk/version.rb
78
+ - proto/generation.proto
79
+ - stability_sdk.gemspec
80
+ homepage: https://github.com/cou929/stability-sdk-ruby
81
+ licenses: []
82
+ metadata:
83
+ homepage_uri: https://github.com/cou929/stability-sdk-ruby
84
+ source_code_uri: https://github.com/cou929/stability-sdk-ruby
85
+ changelog_uri: https://github.com/cou929/stability-sdk-ruby
86
+ post_install_message:
87
+ rdoc_options: []
88
+ require_paths:
89
+ - lib
90
+ required_ruby_version: !ruby/object:Gem::Requirement
91
+ requirements:
92
+ - - ">="
93
+ - !ruby/object:Gem::Version
94
+ version: 2.3.0
95
+ required_rubygems_version: !ruby/object:Gem::Requirement
96
+ requirements:
97
+ - - ">="
98
+ - !ruby/object:Gem::Version
99
+ version: '0'
100
+ requirements: []
101
+ rubygems_version: 3.1.4
102
+ signing_key:
103
+ specification_version: 4
104
+ summary: Ruby client for interacting with stability.ai APIs (e.g. stable diffusion
105
+ inference)
106
+ test_files: []