fal 0.0.1 → 0.1.0
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- checksums.yaml +4 -4
- data/.env.example +1 -0
- data/.rubocop.yml +8 -0
- data/Gemfile +5 -0
- data/Gemfile.lock +90 -26
- data/README.md +272 -0
- data/lib/fal/client.rb +151 -0
- data/lib/fal/model.rb +166 -0
- data/lib/fal/price.rb +92 -0
- data/lib/fal/price_estimate.rb +89 -0
- data/lib/fal/request.rb +161 -0
- data/lib/fal/stream.rb +92 -0
- data/lib/fal/version.rb +1 -1
- data/lib/fal/webhook_request.rb +99 -0
- data/lib/fal.rb +102 -0
- data/rbi/fal/client.rbi +59 -0
- data/rbi/fal/fal.rbi +68 -0
- data/rbi/fal/model.rbi +130 -0
- data/rbi/fal/price.rbi +49 -0
- data/rbi/fal/price_estimate.rbi +61 -0
- data/rbi/fal/request.rbi +88 -0
- data/rbi/fal/stream.rbi +37 -0
- data/rbi/fal/version.rbi +6 -0
- data/rbi/fal/webhook_request.rbi +61 -0
- data/sorbet/config +2 -0
- data/sorbet/rbi/.gitignore +2 -0
- metadata +21 -2
data/lib/fal/model.rb
ADDED
|
@@ -0,0 +1,166 @@
|
|
|
1
|
+
# frozen_string_literal: true
|
|
2
|
+
|
|
3
|
+
module Fal
  # Represents a model endpoint discoverable via the Models API.
  # Provides helpers to list, search, and fetch pricing and to run requests.
  class Model
    MODELS_PATH = "/models"

    # @return [String]
    attr_reader :endpoint_id
    # Flattened metadata fields (copied from the entry's "metadata" hash)
    # @return [String, nil]
    attr_reader :display_name
    # @return [String, nil]
    attr_reader :category
    # @return [String, nil]
    attr_reader :description
    # @return [String, nil]
    attr_reader :status
    # @return [Array<String>, nil]
    attr_reader :tags
    # @return [String, nil]
    attr_reader :updated_at
    # @return [Boolean, nil]
    attr_reader :is_favorited
    # @return [String, nil]
    attr_reader :thumbnail_url
    # @return [String, nil]
    attr_reader :thumbnail_animated_url
    # @return [String, nil]
    attr_reader :model_url
    # @return [String, nil]
    attr_reader :github_url
    # @return [String, nil]
    attr_reader :license_type
    # @return [String, nil]
    attr_reader :date
    # @return [Hash, nil]
    attr_reader :group
    # @return [Boolean, nil]
    attr_reader :highlighted
    # @return [String, nil]
    attr_reader :kind
    # @return [Array<String>, nil]
    attr_reader :training_endpoint_ids
    # @return [Array<String>, nil]
    attr_reader :inference_endpoint_ids
    # @return [String, nil]
    attr_reader :stream_url
    # @return [Float, nil]
    attr_reader :duration_estimate
    # @return [Boolean, nil]
    attr_reader :pinned
    # @return [Hash, nil] the entry's top-level "openapi" value, when present
    attr_reader :openapi

    # @param attributes [Hash] raw model entry from the Models API; reads "endpoint_id",
    #   "metadata" (flattened into the readers above) and "openapi"
    # @param client [Fal::Client]
    def initialize(attributes, client: Fal.client)
      @client = client
      reset_attributes(attributes)
    end

    # Fetch and memoize the price object for this model's endpoint.
    # A nil result is memoized as well, so a model without pricing does not
    # query the API again on every call.
    # @return [Fal::Price, nil]
    def price
      return @price if defined?(@price)

      @price = Fal::Price.find_by(endpoint_id: @endpoint_id, client: @client)
    end

    # Run a queued request for this model endpoint.
    # @param input [Hash]
    # @param webhook_url [String, nil]
    # @return [Fal::Request]
    def run(input:, webhook_url: nil)
      Fal::Request.create!(endpoint_id: @endpoint_id, input: input, webhook_url: webhook_url, client: @client)
    end

    class << self
      # Find a specific model by endpoint_id.
      # The API is queried with an endpoint_id filter and the exact match is picked
      # from the returned "models" list.
      # @param endpoint_id [String]
      # @param client [Fal::Client]
      # @return [Fal::Model, nil]
      def find_by(endpoint_id:, client: Fal.client)
        response = client.get_api(MODELS_PATH, query: { endpoint_id: endpoint_id })
        entry = Array(response && response["models"]).find { |m| m["endpoint_id"] == endpoint_id }
        entry ? new(entry, client: client) : nil
      end

      # Iterate through models with optional search filters. Pages of 50 are requested
      # and "next_cursor" is followed until the API stops returning one.
      # @param client [Fal::Client]
      # @param query [String, nil] Free-text search query (sent as "q")
      # @param category [String, nil]
      # @param status [String, nil]
      # @param expand [Array<String>, String, nil]
      # @yield [Fal::Model]
      # @return [Enumerator<Fal::Model>, void] an Enumerator when called without a block
      def each(client: Fal.client, query: nil, category: nil, status: nil, expand: nil, &block)
        unless block
          return enum_for(:each, client: client, query: query, category: category, status: status, expand: expand)
        end

        cursor = nil
        loop do
          query_hash = { limit: 50, cursor: cursor, q: query, category: category, status: status }.compact
          query_hash[:expand] = expand if expand
          response = client.get_api(MODELS_PATH, query: query_hash)
          Array(response && response["models"]).each { |attributes| block.call(new(attributes, client: client)) }
          cursor = response && response["next_cursor"]
          break if cursor.nil?
        end
      end

      # Return an array of models for the given filters (or all models).
      # @param client [Fal::Client]
      # @param query [String, nil]
      # @param category [String, nil]
      # @param status [String, nil]
      # @param expand [Array<String>, String, nil]
      # @return [Array<Fal::Model>]
      def all(client: Fal.client, query: nil, category: nil, status: nil, expand: nil)
        each(client: client, query: query, category: category, status: status, expand: expand).to_a
      end

      # Convenience search wrapper that returns all matching models.
      # @param client [Fal::Client]
      # @param query [String, nil]
      # @param category [String, nil]
      # @param status [String, nil]
      # @param expand [Array<String>, String, nil]
      # @return [Array<Fal::Model>]
      def search(query: nil, category: nil, status: nil, expand: nil, client: Fal.client)
        all(client: client, query: query, category: category, status: status, expand: expand)
      end
    end

    private

    # Copy a raw API entry into instance variables; fields absent from the
    # entry (or a missing "metadata" hash) leave the corresponding readers nil.
    def reset_attributes(attributes)
      @endpoint_id = attributes["endpoint_id"]

      meta = attributes["metadata"] || {}
      @display_name = meta["display_name"]
      @category = meta["category"]
      @description = meta["description"]
      @status = meta["status"]
      @tags = meta["tags"]
      @updated_at = meta["updated_at"]
      @is_favorited = meta["is_favorited"]
      @thumbnail_url = meta["thumbnail_url"]
      @thumbnail_animated_url = meta["thumbnail_animated_url"]
      @model_url = meta["model_url"]
      @github_url = meta["github_url"]
      @license_type = meta["license_type"]
      @date = meta["date"]
      @group = meta["group"]
      @highlighted = meta["highlighted"]
      @kind = meta["kind"]
      @training_endpoint_ids = meta["training_endpoint_ids"]
      @inference_endpoint_ids = meta["inference_endpoint_ids"]
      @stream_url = meta["stream_url"]
      @duration_estimate = meta["duration_estimate"]
      @pinned = meta["pinned"]

      @openapi = attributes["openapi"]
    end
  end
end
|
data/lib/fal/price.rb
ADDED
|
@@ -0,0 +1,92 @@
|
|
|
1
|
+
# frozen_string_literal: true
|
|
2
|
+
|
|
3
|
+
module Fal
  # Represents pricing information for a model endpoint.
  # Fetches data from the Platform API at /models and /models/pricing.
  class Price
    MODELS_PATH = "/models"
    PRICING_PATH = "/models/pricing"

    # Billing units returned by the pricing service.
    module Unit
      # Output-based units
      IMAGES = "image"
      VIDEOS = "video"
      MEGAPIXELS = "megapixels"

      # Compute-based units (provider-specific)
      GPU_SECONDS = "gpu_second"
      GPU_MINUTES = "gpu_minute"
      GPU_HOURS = "gpu_hour"
    end

    # @return [String]
    attr_reader :endpoint_id
    # @return [Float]
    attr_reader :unit_price
    # @return [String] billing unit as returned by the API (compare against Unit constants)
    attr_reader :unit
    # @return [String]
    attr_reader :currency

    # @param attributes [Hash] Raw attributes from pricing API
    # @param client [Fal::Client]
    def initialize(attributes, client: Fal.client)
      @client = client
      reset_attributes(attributes)
    end

    class << self
      # Find pricing for a specific model endpoint.
      # @param endpoint_id [String]
      # @param client [Fal::Client]
      # @return [Fal::Price, nil] nil when the response has no entry for endpoint_id
      def find_by(endpoint_id:, client: Fal.client)
        response = client.get_api(PRICING_PATH, query: { endpoint_id: endpoint_id })
        entry = Array(response && response["prices"]).find { |p| p["endpoint_id"] == endpoint_id }
        entry ? new(entry, client: client) : nil
      end

      # Iterate over all prices by paging through models (50 per page) and fetching
      # pricing for each page's endpoint ids in a single batched request.
      # @param client [Fal::Client]
      # @yield [Fal::Price]
      # @return [Enumerator<Fal::Price>, void] an Enumerator when called without a block
      def each(client: Fal.client, &block)
        return enum_for(:each, client: client) unless block

        cursor = nil
        loop do
          models_response = client.get_api(MODELS_PATH, query: { limit: 50, cursor: cursor }.compact)
          models = Array(models_response && models_response["models"])
          # Entries without an endpoint_id cannot be priced; skip them.
          endpoint_ids = models.map { |m| m["endpoint_id"] }.compact

          unless endpoint_ids.empty?
            pricing_response = client.get_api(PRICING_PATH, query: { endpoint_id: endpoint_ids })
            Array(pricing_response && pricing_response["prices"]).each do |attributes|
              block.call(new(attributes, client: client))
            end
          end

          cursor = models_response && models_response["next_cursor"]
          break if cursor.nil?
        end
      end

      # Return an array of all prices.
      # @param client [Fal::Client]
      # @return [Array<Fal::Price>]
      def all(client: Fal.client)
        each(client: client).to_a
      end
    end

    private

    # Copy the raw pricing entry into instance variables.
    def reset_attributes(attributes)
      @endpoint_id = attributes["endpoint_id"]
      @unit_price = attributes["unit_price"]
      @unit = attributes["unit"]
      @currency = attributes["currency"]
    end
  end
end
|
|
@@ -0,0 +1,89 @@
|
|
|
1
|
+
# frozen_string_literal: true
|
|
2
|
+
|
|
3
|
+
module Fal
  # Represents a cost estimate response from the Platform API.
  # Computes estimates via POST /models/pricing/estimate.
  class PriceEstimate
    ESTIMATE_PATH = "/models/pricing/estimate"

    # Supported estimate types.
    module EstimateType
      # @return [String]
      HISTORICAL_API_PRICE = "historical_api_price"
      # @return [String]
      UNIT_PRICE = "unit_price"
    end

    # Simple value object for endpoint inputs.
    class Endpoint
      # @return [String]
      attr_reader :endpoint_id
      # @return [Integer, nil]
      attr_reader :call_quantity
      # @return [Float, nil]
      attr_reader :unit_quantity

      # @param endpoint_id [String]
      # @param call_quantity [Integer, nil]
      # @param unit_quantity [Float, nil]
      def initialize(endpoint_id:, call_quantity: nil, unit_quantity: nil)
        @endpoint_id = endpoint_id
        @call_quantity = call_quantity
        @unit_quantity = unit_quantity
      end
    end

    # @return [String]
    attr_reader :estimate_type
    # @return [Float]
    attr_reader :total_cost
    # @return [String]
    attr_reader :currency

    # @param attributes [Hash]
    # @param client [Fal::Client]
    def initialize(attributes, client: Fal.client)
      @client = client
      reset_attributes(attributes)
    end

    class << self
      # Create a new cost estimate.
      # Corresponds to POST https://api.fal.ai/v1/models/pricing/estimate
      #
      # Each endpoint contributes one entry keyed by its endpoint_id (a later entry with
      # the same id replaces an earlier one). For UNIT_PRICE the quantity sent is
      # unit_quantity, falling back to call_quantity; for any other estimate type it is
      # call_quantity, falling back to unit_quantity.
      # @param estimate_type [String] one of EstimateType constants
      # @param endpoints [Array<Fal::PriceEstimate::Endpoint, Hash>, Fal::PriceEstimate::Endpoint, Hash]
      #   Hash entries may use symbol or string keys
      # @param client [Fal::Client]
      # @return [Fal::PriceEstimate]
      def create(estimate_type:, endpoints:, client: Fal.client)
        unit_based = estimate_type == EstimateType::UNIT_PRICE
        # Array(hash) would explode a single Hash into key/value pairs; wrap it instead.
        entries = endpoints.is_a?(Hash) ? [endpoints] : Array(endpoints)

        endpoint_map = entries.each_with_object({}) do |ep, map|
          endpoint = coerce_endpoint(ep)
          map[endpoint.endpoint_id] =
            if unit_based
              { "unit_quantity" => endpoint.unit_quantity || endpoint.call_quantity }
            else
              { "call_quantity" => endpoint.call_quantity || endpoint.unit_quantity }
            end
        end

        payload = {
          estimate_type: estimate_type,
          endpoints: endpoint_map
        }

        attributes = client.post_api(ESTIMATE_PATH, payload)
        new(attributes, client: client)
      end

      private

      # Accept an Endpoint as-is, or build one from a Hash with symbol or string keys.
      # @param ep [Fal::PriceEstimate::Endpoint, Hash]
      # @return [Fal::PriceEstimate::Endpoint]
      def coerce_endpoint(ep)
        return ep if ep.is_a?(Endpoint)

        Endpoint.new(**ep.to_h.transform_keys(&:to_sym))
      end
    end

    private

    # Copy the raw estimate response into instance variables.
    def reset_attributes(attributes)
      @estimate_type = attributes["estimate_type"]
      @total_cost = attributes["total_cost"]
      @currency = attributes["currency"]
    end
  end
end
|
data/lib/fal/request.rb
ADDED
|
@@ -0,0 +1,161 @@
|
|
|
1
|
+
# frozen_string_literal: true
|
|
2
|
+
|
|
3
|
+
module Fal
  # Represents a queued request submitted to a fal model endpoint.
  # Provides helpers to create, query status, cancel, and fetch response payloads
  # using the Queue API described in the fal docs.
  # See: https://docs.fal.ai/model-apis/model-endpoints/queue
  class Request
    # Request status values returned by the Queue API.
    module Status
      # @return [String]
      IN_QUEUE = "IN_QUEUE"
      # @return [String]
      IN_PROGRESS = "IN_PROGRESS"
      # @return [String]
      COMPLETED = "COMPLETED"
    end

    # @return [String] The request identifier (request_id)
    attr_reader :id
    # @return [String] The current status, one of Fal::Request::Status constants
    attr_reader :status
    # @return [Integer, nil] The current position in the queue, if available
    attr_reader :queue_position
    # @return [Array<Hash>, nil] Log entries when requested via logs=1
    attr_reader :logs
    # @return [Hash, nil] Response payload when status is COMPLETED
    attr_reader :response
    # @return [String] The model identifier used when creating this request
    attr_reader :endpoint_id

    # @param attributes [Hash] Raw attributes from fal Queue API
    # @param endpoint_id [String] Model ID in "namespace/name" format
    # @param client [Fal::Client] HTTP client to use for subsequent calls
    def initialize(attributes, endpoint_id:, client: Fal.client)
      @client = client
      @endpoint_id = endpoint_id
      reset_attributes(attributes)
    end

    class << self
      # Create a new queued request for a model.
      # Corresponds to POST https://queue.fal.run/{endpoint_id}
      # Optionally appends the URL-escaped fal_webhook query param per docs.
      # @param endpoint_id [String]
      # @param input [Hash] request body (an empty Hash is sent when nil)
      # @param webhook_url [String, nil]
      # @param client [Fal::Client]
      # @return [Fal::Request]
      def create!(endpoint_id:, input:, webhook_url: nil, client: Fal.client)
        path = "/#{endpoint_id}"
        path = "#{path}?fal_webhook=#{CGI.escape(webhook_url)}" if webhook_url
        attributes = client.post(path, input || {})
        new(attributes, endpoint_id: endpoint_id, client: client)
      end

      # Find the current status for a given request.
      # Corresponds to GET https://queue.fal.run/{endpoint_id}/requests/{request_id}/status
      # @param id [String]
      # @param endpoint_id [String]
      # @param logs [Boolean] include logs if true
      # @param client [Fal::Client]
      # @return [Fal::Request]
      def find_by!(id:, endpoint_id:, logs: false, client: Fal.client)
        attributes = client.get("/#{endpoint_id_without_subpath(endpoint_id)}/requests/#{id}/status",
                                query: (logs ? { logs: 1 } : nil)) || {}
        # The status payload is not guaranteed to echo request_id; keep the id that was
        # queried so later reload!/cancel! calls target the right request.
        attributes = attributes.merge("request_id" => attributes["request_id"] || id)
        new(attributes, endpoint_id: endpoint_id, client: client)
      end

      # Stream a synchronous request using SSE and yield response chunks as they arrive.
      # It returns a Fal::Request initialized with the last streamed data in the response field.
      # @param endpoint_id [String]
      # @param input [Hash]
      # @param client [Fal::Client]
      # @yield [chunk] yields each parsed chunk from the stream
      # @yieldparam chunk [Hash, Object] the decoded "data" of each event
      # @return [Fal::Request]
      def stream!(endpoint_id:, input:, client: Fal.client, &block)
        last_data = nil

        Stream.new(path: "/#{endpoint_id}/stream", input: input, client: client).each do |event|
          last_data = event["data"]
          block&.call(last_data)
        end

        # Only a Hash chunk can carry request_id/status/response fields; any other JSON
        # value (string, array, number) is kept verbatim as the response.
        last_hash = last_data.is_a?(Hash) ? last_data : nil
        response_payload = last_hash&.key?("response") ? last_hash["response"] : last_data
        attributes = {
          "request_id" => last_hash && last_hash["request_id"],
          "status" => last_hash && last_hash["status"],
          "response" => response_payload
        }.compact
        new(attributes, endpoint_id: endpoint_id, client: client)
      end

      # Strip any subpath from an endpoint id; the queue status/result/cancel paths
      # built by this class use only the first two segments.
      # @param endpoint_id [String] e.g. "fal-ai/flux/dev"
      # @return [String] e.g. "fal-ai/flux"
      def endpoint_id_without_subpath(endpoint_id)
        endpoint_id.split("/").first(2).join("/")
      end
    end

    # @return [String] The model ID without the subpath
    def endpoint_id_without_subpath
      self.class.endpoint_id_without_subpath(@endpoint_id)
    end

    # Reload the current status from the Queue API.
    # While queued or in progress the status is refreshed; once COMPLETED the
    # result payload is fetched into #response.
    # @param logs [Boolean] include logs if true
    # @return [Fal::Request]
    def reload!(logs: false)
      if in_progress? || in_queue?
        attributes = @client.get("#{request_path}/status", query: (logs ? { logs: 1 } : nil))
        reset_attributes(attributes)
      end

      @response = @client.get(request_path) if completed?

      self
    end

    # Attempt to cancel the request if still in queue.
    # @return [Hash] cancellation response
    def cancel!
      @client.put("#{request_path}/cancel")
    end

    # @return [Boolean]
    def in_queue?
      @status == Status::IN_QUEUE
    end

    # @return [Boolean]
    def in_progress?
      @status == Status::IN_PROGRESS
    end

    # @return [Boolean]
    def completed?
      @status == Status::COMPLETED
    end

    private

    # Base queue path for this request: "/{owner/app}/requests/{id}".
    # @return [String]
    def request_path
      "/#{endpoint_id_without_subpath}/requests/#{@id}"
    end

    # Normalize attributes from different Queue API responses.
    # Missing request_id/status keep their previous values (status defaults to IN_QUEUE).
    # @param attributes [Hash]
    # @return [void]
    def reset_attributes(attributes)
      @id = attributes["request_id"] || @id
      @status = attributes["status"] || @status || Status::IN_QUEUE
      @queue_position = attributes["queue_position"]
      @logs = attributes["logs"]
      @response = attributes["response"]
    end
  end
end
|
data/lib/fal/stream.rb
ADDED
|
@@ -0,0 +1,92 @@
|
|
|
1
|
+
# frozen_string_literal: true
|
|
2
|
+
|
|
3
|
+
module Fal
  # Streaming helper for Server-Sent Events from fal.run synchronous endpoints.
  # It parses SSE lines and yields decoded event hashes with string keys:
  # "data" holds the JSON-decoded payload; "event", "id" and "retry" appear when sent.
  class Stream
    # @return [String] endpoint path under fal.run, e.g. "/fal-ai/flux/dev/stream"
    attr_reader :path

    # @param path [String] full path under sync_base (leading slash), ex: "/fal-ai/flux/dev/stream"
    # @param input [Hash] request input payload
    # @param client [Fal::Client] HTTP client
    def initialize(path:, input:, client: Fal.client)
      @path = path
      @input = input
      @client = client
    end

    # Stream events; yields a Hash for each event data chunk. Blocks until stream ends.
    # Without a block, returns an Enumerator over the same events.
    # @yield [event] yields decoded event hash
    # @yieldparam event [Hash]
    # @return [void, Enumerator<Hash>]
    def each(&block)
      return enum_for(:each) unless block

      decoder = SSEDecoder.new
      buffer = ""
      dispatch = lambda do |line|
        event = decoder.decode(line)
        block.call(event) if event
      end

      @client.post_stream(@path, @input, on_data: proc do |chunk, _total_bytes|
        buffer += chunk
        # A trailing CR may be the first half of a CRLF pair split across chunks.
        # Hold it back so the pair is not read as two line breaks, which would
        # dispatch the current event before all of its data lines arrived.
        held_cr = buffer.end_with?("\r")
        text = held_cr ? buffer.delete_suffix("\r") : buffer
        lines = text.gsub(/\r\n?/, "\n").split("\n", -1)
        buffer = lines.pop || ""
        buffer += "\r" if held_cr
        lines.each(&dispatch)
      end)

      # A CR still held back when the stream ends was a genuine line terminator.
      dispatch.call(buffer.delete_suffix("\r")) if buffer.end_with?("\r")
    end

    # Minimal SSE decoder for parsing standard server-sent event stream lines.
    # Feed it one line at a time (without its terminator); a blank line dispatches
    # the accumulated event.
    class SSEDecoder
      def initialize
        reset
      end

      # @param line [String] a single line without its line terminator
      # @return [Hash, nil] the completed event when line is blank and data was buffered
      def decode(line)
        return flush_event if line.empty?
        return if line.start_with?(":") # comment / keep-alive line

        field, _, value = line.partition(":")
        # The SSE spec strips exactly one leading space from the field value.
        value = value.delete_prefix(" ")

        case field
        when "event" then @event = value
        when "data" then @data += "#{value}\n"
        when "id" then @id = value
        when "retry"
          # Per the spec, retry is only honored when the value is all ASCII digits.
          @retry = value.to_i if value.match?(/\A\d+\z/)
        end

        nil
      end

      private

      # Build the event hash from the buffered fields (data is JSON-decoded), then
      # clear the buffers. Returns nil when no data line has been buffered.
      def flush_event
        return if @data.empty?

        parsed = JSON.parse(@data.chomp)

        event = { "data" => parsed }
        event["event"] = @event unless @event.empty?
        event["id"] = @id if @id
        event["retry"] = @retry if @retry

        reset
        event
      end

      # Clear all per-event buffers.
      def reset
        @event = ""
        @data = ""
        @id = nil
        @retry = nil
      end
    end
  end
end
|
data/lib/fal/version.rb
CHANGED