ahamid-postgres-pr 0.6.1

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -0,0 +1,174 @@
1
+ #
2
+ # Author:: Michael Neumann
3
+ # Copyright:: (c) 2005 by Michael Neumann
4
+ # License:: Same as Ruby's or BSD
5
+ #
6
+
7
+ require 'postgres-pr/message'
8
+ require 'postgres-pr/version'
9
+ require 'uri'
10
+ require 'socket'
11
+
12
+ module PostgresPR
13
+
14
# Protocol 3.0 as sent in the StartupMessage: major version in the high
# 16 bits, minor version (0) in the low 16 bits, i.e. 196608.
PROTO_VERSION = 0x0003_0000
15
+
16
class Connection

  DEFAULT_PORT = 5432
  DEFAULT_HOST = 'localhost'
  DEFAULT_PATH = '/tmp'

  # Where to connect when no URI is given: TCP on Windows Ruby builds, the
  # local UNIX-domain socket under DEFAULT_PATH everywhere else.
  # The platform regexp deliberately avoids a bare "win" substring test,
  # which would also match "darwin" and push macOS onto TCP.
  DEFAULT_URI =
    if RUBY_PLATFORM =~ /mswin|mingw|bccwin|wince|emx/
      'tcp://' + DEFAULT_HOST + ':' + DEFAULT_PORT.to_s
    else
      'unix:' + File.join(DEFAULT_PATH, '.s.PGSQL.' + DEFAULT_PORT.to_s)
    end

  # A block which is called with the NoticeResponse object as parameter.
  attr_accessor :notice_processor

  # Outcome of #query: data rows (appended across all statements), the last
  # RowDescription's field list and the last CommandComplete tag.
  class Result
    attr_accessor :rows, :fields, :cmd_tag
    def initialize(rows=[], fields=[])
      @rows, @fields = rows, fields
    end
  end

  #
  # Returns one of the following statuses:
  #
  #   PQTRANS_IDLE    = 0 (connection idle)
  #   PQTRANS_INTRANS = 2 (idle, within transaction block)
  #   PQTRANS_INERROR = 3 (idle, within failed transaction)
  #   PQTRANS_UNKNOWN = 4 (cannot determine status)
  #
  # Not yet implemented is:
  #
  #   PQTRANS_ACTIVE  = 1 (command in progress)
  #
  def transaction_status
    indicator = @transaction_status
    # The status byte may come back as an Integer (byte value) or as a
    # one-character String depending on Ruby version / Buffer; `?I` literals
    # only match one of those forms, so normalize to a String first.
    indicator = indicator.chr if indicator.is_a?(Integer)
    case indicator
    when 'I' then 0
    when 'T' then 2
    when 'E' then 3
    else 4
    end
  end

  # Connects to +uri+ (DEFAULT_URI when nil), sends the startup message for
  # +database+ / +user+ and answers the server's password request.
  # Returns once the server reports ReadyForQuery.
  #
  # Raises ArgumentError when a password is requested but +password+ is nil,
  # and RuntimeError on server errors or unsupported authentication methods.
  # If startup fails after the socket was opened, the socket is closed before
  # the exception propagates.
  def initialize(database, user, password=nil, uri = nil)
    uri ||= DEFAULT_URI

    @transaction_status = nil
    @params = {}
    establish_connection(uri)

    ready = false
    begin
      @conn << StartupMessage.new(PROTO_VERSION, 'user' => user, 'database' => database).dump
      handshake(user, password)
      ready = true
    ensure
      # Don't leak the socket when authentication/startup fails.
      discard_socket unless ready
    end
  end

  # Sends Terminate to the server (best effort) and closes the socket.
  # Raises RuntimeError if the connection is already closed.
  def close
    raise "connection already closed" if @conn.nil?
    begin
      @conn << Terminate.dump
    rescue IOError, SystemCallError
      # The link may already be dead; releasing the socket below is what matters.
    ensure
      # close (not just shutdown) so the file descriptor is actually released.
      discard_socket
    end
  end

  # Sends +sql+ as a simple query and collects the response up to
  # ReadyForQuery. Returns a Result.
  #
  # Raises RuntimeError if the connection is closed, or — after the whole
  # response was consumed — if the server sent any ErrorResponse (all error
  # texts are joined into the message).
  def query(sql)
    raise "connection already closed" if @conn.nil?
    @conn << Query.dump(sql)

    result = Result.new
    errors = []

    loop do
      msg = Message.read(@conn)
      case msg
      when DataRow
        result.rows << msg.columns
      when CommandComplete
        result.cmd_tag = msg.cmd_tag
      when ReadyForQuery
        @transaction_status = msg.backend_transaction_status_indicator
        break
      when RowDescription
        result.fields = msg.fields
      when ParameterStatus
        # Keep the parameter table from startup current.
        @params[msg.key] = msg.value
      when CopyInResponse, CopyOutResponse
        # TODO: COPY is not supported.
      when EmptyQueryResponse
      when ErrorResponse
        # Collect and keep reading until ReadyForQuery; raised below.
        errors << msg
      when NoticeResponse
        @notice_processor.call(msg) if @notice_processor
      else
        # TODO: other message types are ignored.
      end
    end

    raise errors.map{|e| e.field_values.join("\t") }.join("\n") unless errors.empty?

    result
  end

  private

  # Answers authentication requests and consumes startup messages until the
  # server reports ReadyForQuery.
  def handshake(user, password)
    loop do
      msg = Message.read(@conn)

      case msg
      when AuthentificationClearTextPassword
        raise ArgumentError, "no password specified" if password.nil?
        @conn << PasswordMessage.new(password).dump

      when AuthentificationCryptPassword
        raise ArgumentError, "no password specified" if password.nil?
        @conn << PasswordMessage.new(password.crypt(msg.salt)).dump

      when AuthentificationMD5Password
        raise ArgumentError, "no password specified" if password.nil?
        require 'digest/md5'
        # 'md5' + md5hex(md5hex(password + user) + salt)
        inner = Digest::MD5.hexdigest(password + user)
        @conn << PasswordMessage.new('md5' + Digest::MD5.hexdigest(inner + msg.salt)).dump

      when AuthentificationKerberosV4, AuthentificationKerberosV5,
           AuthentificationSCMCredential, UnknownAuthType
        raise "unsupported authentification"

      when AuthentificationOk
        # authenticated; keep consuming startup messages
      when ErrorResponse
        raise msg.field_values.join("\t")
      when NoticeResponse
        @notice_processor.call(msg) if @notice_processor
      when ParameterStatus
        @params[msg.key] = msg.value
      when BackendKeyData
        # TODO
      when ReadyForQuery
        @transaction_status = msg.backend_transaction_status_indicator
        break
      else
        raise "unhandled message type"
      end
    end
  end

  # Closes the socket without the Terminate exchange and forgets it; used by
  # #close and when startup fails part-way.
  def discard_socket
    @conn.close if @conn && !@conn.closed?
  rescue IOError, SystemCallError
    # the socket is being abandoned anyway
  ensure
    @conn = nil
  end

  # Opens @conn for +uri+:
  #   tcp://localhost:5432     -> TCPSocket (host/port default when omitted)
  #   unix:/tmp/.s.PGSQL.5432  -> UNIXSocket on the given path
  # Raises RuntimeError for any other scheme.
  def establish_connection(uri)
    u = URI.parse(uri)
    case u.scheme
    when 'tcp'
      @conn = TCPSocket.new(u.host || DEFAULT_HOST, u.port || DEFAULT_PORT)
    when 'unix'
      @conn = UNIXSocket.new(u.path)
    else
      raise 'unrecognized uri scheme format (must be tcp or unix)'
    end
  end
end
173
+
174
+ end # module PostgresPR
@@ -0,0 +1,542 @@
1
+ #
2
+ # Author:: Michael Neumann
3
+ # Copyright:: (c) 2005 by Michael Neumann
4
+ # License:: Same as Ruby's or BSD
5
+ #
6
+
7
+ require 'buffer'
8
class IO
  # Reads exactly +n+ bytes, issuing follow-up reads when the first read
  # comes back short.
  #
  # Raises EOFError if the stream ends before +n+ bytes were collected.
  def read_exactly_n_bytes(n)
    data = read(n)
    raise EOFError unless data

    remaining = n - data.size
    until remaining <= 0
      chunk = read(remaining)
      raise EOFError unless chunk
      data << chunk
      remaining -= chunk.size
    end
    data
  end
end
25
+
26
+ module PostgresPR
27
+
28
# Raised when a buffer cannot be decoded into a message.
class ParseError < RuntimeError
end

# Raised when a message object cannot be encoded into its wire form.
class DumpError < RuntimeError
end
30
+
31
+
32
# Base class representing a PostgreSQL protocol message.
#
# Subclasses call register_message_type to claim a one-character type code
# and usually `fields` to get accessors plus a positional constructor; they
# override #dump / #parse to encode and decode their body.
class Message
  # One character message-typecode to class map. Unregistered codes resolve
  # to UnknownMessageType (the default block runs lazily, so that class may
  # be defined later in the file).
  MsgTypeMap = Hash.new { UnknownMessageType }

  # Claims message code +type+ for the calling class and defines the MsgType
  # constant and #message_type reader on it.
  # Raises RuntimeError if +type+ is already registered.
  def self.register_message_type(type)
    raise "duplicate message type registration" if MsgTypeMap.has_key?(type)

    MsgTypeMap[type] = self

    self.const_set(:MsgType, type)
    class_eval "def message_type; MsgType end"
  end

  # Reads one message from +stream+ (anything responding to
  # read_exactly_n_bytes) and returns an instance of the class registered for
  # its type code. With +startup+ true there is no type byte and the body is
  # decoded as a StartupMessage.
  #
  # Raises ParseError if the declared length is below 4 (the size of the
  # length field itself), including negative lengths.
  def self.read(stream, startup=false)
    type = stream.read_exactly_n_bytes(1) unless startup
    length = stream.read_exactly_n_bytes(4).unpack('N').first
    # The length is a signed Int32 on the wire but 'N' decodes it unsigned;
    # fold values >= 2**31 back to negatives so garbage such as 0xFFFFFFFF is
    # rejected below instead of triggering a ~4 GB buffer allocation.
    length -= 0x1_0000_0000 if length >= 0x8000_0000

    raise ParseError, "invalid message length #{length}" unless length >= 4

    # initialize buffer with the complete message (type byte, length, body)
    buffer = Buffer.of_size(startup ? length : 1+length)
    buffer.write(type) unless startup
    buffer.write_int32_network(length)
    buffer.copy_from_stream(stream, length-4)

    (startup ? StartupMessage : MsgTypeMap[type]).create(buffer)
  end

  # Builds an instance without running initialize and fills it from +buffer+.
  def self.create(buffer)
    obj = allocate
    obj.parse(buffer)
    obj
  end

  # Shorthand for new(*args).dump.
  def self.dump(*args)
    new(*args).dump
  end

  # Writes the header (type byte + Int32 length) into a buffer sized for
  # +body_size+ body bytes, yields the buffer so subclasses can write the
  # body, and returns the buffer content.
  # Raises DumpError if the buffer is not exactly filled afterwards.
  def dump(body_size=0)
    buffer = Buffer.of_size(5 + body_size)
    buffer.write(self.message_type)
    buffer.write_int32_network(4 + body_size)
    yield buffer if block_given?
    raise DumpError unless buffer.at_end?
    return buffer.content
  end

  # Skips the 5-byte header, yields so subclasses can decode the body, then
  # raises ParseError if unread bytes remain.
  def parse(buffer)
    buffer.position = 5
    yield buffer if block_given?
    raise ParseError, buffer.inspect unless buffer.at_end?
  end

  # Defines attr_accessors for the given names plus an initialize assigning
  # them positionally; e.g. `fields :key, :value` gives initialize(key, value).
  def self.fields(*attribs)
    names = attribs.map {|name, _type| name.to_s}
    arg_list = names.join(", ")
    ivar_list = names.map {|name| "@" + name }.join(", ")
    sym_list = names.map {|name| ":" + name }.join(", ")
    class_eval %[
      attr_accessor #{ sym_list }
      def initialize(#{ arg_list })
        #{ ivar_list } = #{ arg_list }
      end
    ]
  end
end
99
+
100
# Stand-in for any message code without a registered class. Message#parse
# accepts it only when no bytes follow the header, and it can never be
# serialized.
class UnknownMessageType < Message
  # Always raises DumpError (a RuntimeError); previously a bare `raise`
  # produced an uninformative "unhandled exception".
  def dump
    raise DumpError, "cannot dump a message of unknown type"
  end
end
105
+
106
# Authentication request ('R'). The Int32 following the header selects the
# concrete subclass through AuthTypeMap.
class Authentification < Message
  register_message_type 'R'

  # auth-type code => subclass; unregistered codes resolve lazily to
  # UnknownAuthType.
  AuthTypeMap = Hash.new { UnknownAuthType }

  # Peeks at the auth-type code, then lets the matching subclass decode the
  # whole buffer (its #parse repositions just past the header itself).
  def self.create(buffer)
    buffer.position = 5
    handler = AuthTypeMap[buffer.read_int32_network]
    handler.allocate.tap { |message| message.parse(buffer) }
  end

  # Claims auth-type code +type+ for the calling subclass and defines the
  # AuthType constant and #auth_type reader on it.
  def self.register_auth_type(type)
    if AuthTypeMap.has_key?(type)
      raise "duplicate auth type registration"
    end
    AuthTypeMap[type] = self
    const_set(:AuthType, type)
    class_eval "def auth_type() AuthType end"
  end

  # Handle on Message#dump (taken before the override below) so mixins can
  # emit a different body size.
  alias message__dump dump

  # Body: just the Int32 auth-type code.
  def dump
    super(4) { |buffer| buffer.write_int32_network(auth_type) }
  end

  # Checks the auth-type code matches this class, then yields so subclasses
  # and mixins can read any remaining body.
  def parse(buffer)
    super do
      code = buffer.read_int32_network
      raise ParseError unless code == auth_type
      yield if block_given?
    end
  end
end
144
+
145
# Authentication request whose code has no registered subclass. The
# inherited parse would call the undefined #auth_type and die with
# NoMethodError; instead record the code and any trailing payload so callers
# receive a real message object they can reject cleanly.
# NOTE(review): #dump re-emits only the code, not +data+.
class UnknownAuthType < Authentification
  attr_reader :auth_type, :data

  def parse(buffer)
    buffer.position = 5
    @auth_type = buffer.read_int32_network
    remaining = buffer.size - buffer.position
    @data = remaining > 0 ? buffer.read(remaining) : ''
  end
end
147
+
148
# Authentication requests without a body beyond the Int32 auth-type code,
# each registered under its code.
class AuthentificationOk < Authentification; register_auth_type(0); end
class AuthentificationKerberosV4 < Authentification; register_auth_type(1); end
class AuthentificationKerberosV5 < Authentification; register_auth_type(2); end
class AuthentificationClearTextPassword < Authentification; register_auth_type(3); end
163
+
164
# Shared behaviour for authentication requests that carry a fixed-length
# salt after the auth-type code. Including classes must define #salt_size
# and sit below Authentification (for #auth_type and #message__dump).
module SaltedAuthentificationMixin
  attr_accessor :salt

  def initialize(salt)
    @salt = salt
  end

  # Raises DumpError unless the salt is exactly salt_size bytes.
  def dump
    # The salt is raw bytes on the wire, so compare its byte length;
    # String#size counts characters and disagrees for multibyte strings.
    raise DumpError unless @salt.bytesize == self.salt_size

    message__dump(4 + self.salt_size) do |buffer|
      buffer.write_int32_network(self.auth_type)
      buffer.write(@salt)
    end
  end

  def parse(buffer)
    super do
      @salt = buffer.read(self.salt_size)
    end
  end
end
186
+
187
# crypt()-style password request; body carries a 2-byte salt.
class AuthentificationCryptPassword < Authentification
  register_auth_type 4
  include SaltedAuthentificationMixin

  def salt_size
    2
  end
end

# MD5 password request; body carries a 4-byte salt.
class AuthentificationMD5Password < Authentification
  register_auth_type 5
  include SaltedAuthentificationMixin

  def salt_size
    4
  end
end

# Auth code 6; Connection rejects it as unsupported.
class AuthentificationSCMCredential < Authentification
  register_auth_type 6
end
203
+
204
# Password response ('p'); body is the NUL-terminated password string.
class PasswordMessage < Message
  register_message_type 'p'
  fields :password

  def dump
    # Wire lengths count bytes; String#size counts characters and would
    # mis-size non-ASCII passwords.
    super(@password.bytesize + 1) do |buffer|
      buffer.write_cstring(@password)
    end
  end

  def parse(buffer)
    super do
      @password = buffer.read_cstring
    end
  end
end
220
+
221
# Run-time parameter report ('S'): NUL-terminated key, then value.
class ParameterStatus < Message
  register_message_type 'S'
  fields :key, :value

  def dump
    # Byte lengths, not character counts (values may contain UTF-8).
    super(@key.bytesize + 1 + @value.bytesize + 1) do |buffer|
      buffer.write_cstring(@key)
      buffer.write_cstring(@value)
    end
  end

  def parse(buffer)
    super do
      @key = buffer.read_cstring
      @value = buffer.read_cstring
    end
  end
end
239
+
240
# Backend key data ('K'): two Int32 values, process id then secret key.
class BackendKeyData < Message
  register_message_type 'K'
  fields :process_id, :secret_key

  def dump
    super(8) do |buffer|
      [@process_id, @secret_key].each { |word| buffer.write_int32_network(word) }
    end
  end

  def parse(buffer)
    super do
      @process_id, @secret_key = buffer.read_int32_network, buffer.read_int32_network
    end
  end
end
258
+
259
# Ready-for-query marker ('Z'). Its single body byte is the transaction
# status indicator that Connection stores after each exchange.
class ReadyForQuery < Message
  register_message_type 'Z'
  fields :backend_transaction_status_indicator

  def dump
    indicator = @backend_transaction_status_indicator
    super(1) { |buffer| buffer.write_byte(indicator) }
  end

  def parse(buffer)
    super { @backend_transaction_status_indicator = buffer.read_byte }
  end
end
275
+
276
# One result row ('D'): Int16 column count, then per column an Int32 length
# (-1 for NULL, mapped to nil) followed by that many bytes.
class DataRow < Message
  register_message_type 'D'
  fields :columns

  def dump
    # Column lengths are byte counts; bytesize keeps multibyte text correct.
    sz = @columns.inject(2) {|sum, col| sum + 4 + (col ? col.bytesize : 0)}
    super(sz) do |buffer|
      buffer.write_int16_network(@columns.size)
      @columns.each do |col|
        if col
          buffer.write_int32_network(col.bytesize)
          buffer.write(col)
        else
          buffer.write_int32_network(-1) # NULL
        end
      end
    end
  end

  def parse(buffer)
    super do
      n_cols = buffer.read_int16_network
      @columns = (1..n_cols).map do
        len = buffer.read_int32_network
        len == -1 ? nil : buffer.read(len)
      end
    end
  end
end
305
+
306
# Command completion ('C'); body is the NUL-terminated command tag.
class CommandComplete < Message
  register_message_type 'C'
  fields :cmd_tag

  def dump
    # Byte length, not character count.
    super(@cmd_tag.bytesize + 1) do |buffer|
      buffer.write_cstring(@cmd_tag)
    end
  end

  def parse(buffer)
    super do
      @cmd_tag = buffer.read_cstring
    end
  end
end
322
+
323
# Header-only message ('I'); Connection#query ignores it.
class EmptyQueryResponse < Message
  register_message_type('I')
end
326
+
327
# Shared encoding for NoticeResponse and ErrorResponse.
#
# The body is decoded as one leading type byte (+field_type+), then
# NUL-terminated strings up to a final zero byte. Only the first type byte is
# split off; later strings are stored verbatim.
# NOTE(review): per-field code characters sent by the server presumably end up
# as the first character of each later value — confirm before relying on it.
module NoticeErrorMixin
  attr_accessor :field_type, :field_values

  # Raises ArgumentError if +field_type+ is 0 but values are given.
  def initialize(field_type=0, field_values=[])
    raise ArgumentError if field_type == 0 && !field_values.empty?
    @field_type, @field_values = field_type, field_values
  end

  # Raises ArgumentError if field_type is 0 but values are present.
  def dump
    raise ArgumentError if @field_type == 0 && !@field_values.empty?

    # type byte, then (unless empty) each value + NUL, then the terminator;
    # bytesize because wire lengths count bytes.
    sz = 1
    sz += @field_values.inject(1) {|sum, fld| sum + fld.bytesize + 1} unless @field_type == 0

    super(sz) do |buffer|
      buffer.write_byte(@field_type)
      # `next` ends just this block; `break` would abort the enclosing dump
      # call, skipping its checks and returning nil.
      next if @field_type == 0
      @field_values.each {|fld| buffer.write_cstring(fld) }
      buffer.write_byte(0)
    end
  end

  # Raises ParseError if the trailing terminator byte is not zero.
  def parse(buffer)
    super do
      @field_type = buffer.read_byte
      # Always initialize so an empty message still exposes [] rather than nil.
      @field_values = []
      next if @field_type == 0
      while buffer.position < buffer.size-1
        @field_values << buffer.read_cstring
      end
      terminator = buffer.read_byte
      raise ParseError unless terminator == 0
    end
  end
end
362
+
363
# Notice ('N') and error ('E') reports share their layout via
# NoticeErrorMixin. Connection raises on ErrorResponse and passes
# NoticeResponse objects to its notice_processor.
class NoticeResponse < Message
  register_message_type('N')
  include(NoticeErrorMixin)
end

class ErrorResponse < Message
  register_message_type('E')
  include(NoticeErrorMixin)
end
372
+
373
# TODO: COPY support. CopyInResponse ('G') and CopyOutResponse ('H') are
# registered so their codes are known, but neither decodes a body, and
# Message#parse raises ParseError when bytes follow the header.
# NOTE(review): the server presumably includes format information in these
# messages, so receiving one likely fails today — confirm.
class CopyInResponse < Message
  register_message_type('G')
end

# TODO: see CopyInResponse.
class CopyOutResponse < Message
  register_message_type('H')
end
382
+
383
# Extended-protocol Parse ('P'): statement name, query text (both
# NUL-terminated), Int16 count of parameter type oids, then each Int32 oid.
class Parse < Message
  register_message_type 'P'
  fields :query, :stmt_name, :parameter_oids

  def initialize(query, stmt_name="", parameter_oids=[])
    @query, @stmt_name, @parameter_oids = query, stmt_name, parameter_oids
  end

  def dump
    # String lengths in bytes (bytesize), not characters.
    sz = @stmt_name.bytesize + 1 + @query.bytesize + 1 + 2 + (4 * @parameter_oids.size)
    super(sz) do |buffer|
      buffer.write_cstring(@stmt_name)
      buffer.write_cstring(@query)
      buffer.write_int16_network(@parameter_oids.size)
      @parameter_oids.each {|oid| buffer.write_int32_network(oid) }
    end
  end

  def parse(buffer)
    super do
      @stmt_name = buffer.read_cstring
      @query = buffer.read_cstring
      n_oids = buffer.read_int16_network
      # TODO: zero means unspecified. map to nil?
      @parameter_oids = (1..n_oids).map { buffer.read_int32_network }
    end
  end
end
413
+
414
# Header-only message ('1'); no body is decoded.
class ParseComplete < Message
  register_message_type('1')
end
417
+
418
# Simple query ('Q'); body is the NUL-terminated SQL text.
class Query < Message
  register_message_type 'Q'
  fields :query

  def dump
    # The length field counts bytes. String#size counts characters and
    # under-counts multibyte SQL (e.g. UTF-8 literals), corrupting the frame.
    super(@query.bytesize + 1) do |buffer|
      buffer.write_cstring(@query)
    end
  end

  def parse(buffer)
    super do
      @query = buffer.read_cstring
    end
  end
end
434
+
435
# Result column description ('T'): Int16 field count, then per field a
# NUL-terminated name followed by fixed-width Int32/Int16 columns in the
# order of FieldInfo's members.
class RowDescription < Message
  register_message_type 'T'
  fields :fields

  # Per-field description, members in wire order.
  FieldInfo = Struct.new(:name, :oid, :attr_nr, :type_oid, :typlen, :atttypmod, :formatcode)

  def dump
    # Per field: name bytes + NUL, then 4+2+4+2+4+2 = 18 fixed bytes.
    sz = @fields.inject(2) {|sum, fld| sum + 18 + fld.name.bytesize + 1 }
    super(sz) do |buffer|
      buffer.write_int16_network(@fields.size)
      @fields.each { |f|
        buffer.write_cstring(f.name)
        buffer.write_int32_network(f.oid)
        buffer.write_int16_network(f.attr_nr)
        buffer.write_int32_network(f.type_oid)
        buffer.write_int16_network(f.typlen)
        buffer.write_int32_network(f.atttypmod)
        buffer.write_int16_network(f.formatcode)
      }
    end
  end

  def parse(buffer)
    super do
      n_fields = buffer.read_int16_network
      @fields = (1..n_fields).map {
        f = FieldInfo.new
        f.name = buffer.read_cstring
        f.oid = buffer.read_int32_network
        f.attr_nr = buffer.read_int16_network
        f.type_oid = buffer.read_int32_network
        f.typlen = buffer.read_int16_network
        f.atttypmod = buffer.read_int32_network
        f.formatcode = buffer.read_int16_network
        f
      }
    end
  end
end
474
+
475
# First message on a connection; it has no type byte. Layout: Int32 total
# length, Int32 protocol version, NUL-terminated key/value pairs, final NUL.
class StartupMessage < Message
  fields :proto_version, :params

  # Raises DumpError if the computed size does not match what was written.
  def dump
    # Byte counts (bytesize) so non-ASCII user/database names size correctly.
    sz = @params.inject(4 + 4) {|sum, kv| sum + kv[0].bytesize + 1 + kv[1].bytesize + 1} + 1

    buffer = Buffer.of_size(sz)
    buffer.write_int32_network(sz)
    buffer.write_int32_network(@proto_version)
    @params.each_pair {|key, value|
      buffer.write_cstring(key)
      buffer.write_cstring(value)
    }
    buffer.write_byte(0)

    raise DumpError unless buffer.at_end?
    return buffer.content
  end

  # Raises ParseError if the trailing NUL is missing or bytes remain.
  def parse(buffer)
    buffer.position = 4

    @proto_version = buffer.read_int32_network
    @params = {}

    while buffer.position < buffer.size-1
      key = buffer.read_cstring
      val = buffer.read_cstring
      @params[key] = val
    end

    nul = buffer.read_byte
    raise ParseError unless nul == 0
    raise ParseError unless buffer.at_end?
  end
end
511
+
512
# Pre-startup request without a type byte: Int32 length (always 8) followed
# by the Int32 request code.
class SSLRequest < Message
  fields :ssl_request_code

  def dump
    length = 4 + 4
    buf = Buffer.of_size(length)
    [length, @ssl_request_code].each { |word| buf.write_int32_network(word) }
    raise DumpError unless buf.at_end?
    buf.content
  end

  def parse(buffer)
    buffer.position = 4
    @ssl_request_code = buffer.read_int32_network
    raise ParseError unless buffer.at_end?
  end
end
530
+
531
+ =begin
532
+ # TODO: duplicate message-type, split into client/server messages
533
+ class Sync < Message
534
+ register_message_type 'S'
535
+ end
536
+ =end
537
+
538
# Header-only message ('X') with no body.
class Terminate < Message
  register_message_type('X')
end
541
+
542
+ end # module PostgresPR