ahamid-postgres-pr 0.6.1

Sign up to get free protection for your applications and to get access to all the features.
@@ -0,0 +1,174 @@
1
+ #
2
+ # Author:: Michael Neumann
3
+ # Copyright:: (c) 2005 by Michael Neumann
4
+ # License:: Same as Ruby's or BSD
5
+ #
6
+
7
+ require 'postgres-pr/message'
8
+ require 'postgres-pr/version'
9
+ require 'uri'
10
+ require 'socket'
11
+
12
+ module PostgresPR
13
+
14
+ PROTO_VERSION = 3 << 16 #196608
15
+
16
+ class Connection
17
+
18
+ # A block which is called with the NoticeResponse object as parameter.
19
+ attr_accessor :notice_processor
20
+
21
+ #
22
+ # Returns one of the following statuses:
23
+ #
24
+ # PQTRANS_IDLE = 0 (connection idle)
25
+ # PQTRANS_INTRANS = 2 (idle, within transaction block)
26
+ # PQTRANS_INERROR = 3 (idle, within failed transaction)
27
+ # PQTRANS_UNKNOWN = 4 (cannot determine status)
28
+ #
29
+ # Not yet implemented is:
30
+ #
31
+ # PQTRANS_ACTIVE = 1 (command in progress)
32
+ #
33
# Translates the transaction indicator stored from the last ReadyForQuery
# message into a libpq-style PQTRANS_* integer (0 idle, 2 in transaction,
# 3 in failed transaction, 4 unknown).
#
# The indicator may be held either as an Integer byte value or as a
# one-character String (the `?I` literal itself is an Integer on Ruby 1.8
# and a String on 1.9+), so both representations are accepted.
# NOTE(review): which form Buffer#read_byte yields is not visible here.
def transaction_status
  indicator = @transaction_status

  # Normalize a byte value such as 73 to its character form ("I").
  if indicator.is_a?(Integer) && indicator >= 0 && indicator <= 255
    indicator = indicator.chr
  end

  case indicator
  when 'I' then 0
  when 'T' then 2
  when 'E' then 3
  else 4
  end
end
45
+
46
# Connects to the server named by +uri+ (DEFAULT_URI when nil), sends the
# start-up packet for +user+/+database+ and processes server messages until
# ReadyForQuery arrives.
#
# database:: name sent as the 'database' start-up parameter
# user::     name sent as the 'user' start-up parameter
# password:: required only when the server asks for clear-text, crypt or
#            MD5 authentication; ArgumentError is raised if it is missing
# uri::      'tcp://host:port' or 'unix:/path/to/socket'
#
# Raises RuntimeError for server errors, unsupported authentication methods
# and unhandled message types. On any failure the socket opened by
# establish_connection is closed again so it does not leak.
def initialize(database, user, password=nil, uri = nil)
  uri ||= DEFAULT_URI

  @transaction_status = nil
  @params = {}
  establish_connection(uri)

  ready = false
  begin
    @conn << StartupMessage.new(PROTO_VERSION, 'user' => user, 'database' => database).dump

    loop do
      msg = Message.read(@conn)

      case msg
      when AuthentificationClearTextPassword
        raise ArgumentError, "no password specified" if password.nil?
        @conn << PasswordMessage.new(password).dump

      when AuthentificationCryptPassword
        raise ArgumentError, "no password specified" if password.nil?
        @conn << PasswordMessage.new(password.crypt(msg.salt)).dump

      when AuthentificationMD5Password
        raise ArgumentError, "no password specified" if password.nil?
        # Loaded lazily: only needed when the server requests MD5.
        require 'digest/md5'

        # 'md5' + md5(md5(password + user) + salt)
        m = Digest::MD5.hexdigest(password + user)
        m = Digest::MD5.hexdigest(m + msg.salt)
        m = 'md5' + m
        @conn << PasswordMessage.new(m).dump

      when AuthentificationKerberosV4, AuthentificationKerberosV5, AuthentificationSCMCredential
        raise "unsupported authentification"

      when AuthentificationOk
        # nothing to do; keep reading until ReadyForQuery
      when ErrorResponse
        raise msg.field_values.join("\t")
      when NoticeResponse
        @notice_processor.call(msg) if @notice_processor
      when ParameterStatus
        @params[msg.key] = msg.value
      when BackendKeyData
        # TODO
        #p msg
      when ReadyForQuery
        @transaction_status = msg.backend_transaction_status_indicator
        break
      else
        raise "unhandled message type"
      end
    end

    ready = true
  ensure
    # Start-up failed: release the socket instead of leaking it.
    unless ready
      begin
        @conn.close
      rescue IOError, SystemCallError
        # best effort; the original error is the one that matters
      end
      @conn = nil
    end
  end
end
97
+
98
# Closes the connection.
#
# Sends a Terminate message (best effort), shuts the socket down and then
# closes it so the file descriptor is released; `shutdown` alone leaves the
# descriptor open. @conn is reset to nil even when shutdown raises.
#
# Raises RuntimeError ("connection already closed") if there is no open
# connection. Returns nil.
def close
  raise "connection already closed" if @conn.nil?

  begin
    begin
      @conn << Terminate.new.dump
    rescue IOError, SystemCallError
      # The peer may already be gone; closing locally is still required.
    end
    @conn.shutdown
  ensure
    begin
      @conn.close
    rescue IOError, SystemCallError
      # already closed underneath us
    end
    @conn = nil
  end

  nil
end
103
+
104
# Value holder returned by Connection#query: the collected rows, the field
# descriptions and the command tag of the completed command.
class Result
  attr_accessor :rows, :fields, :cmd_tag

  # Both arguments default to a fresh empty array per call.
  def initialize(rows=[], fields=[])
    @rows = rows
    @fields = fields
  end
end
110
+
111
# Sends +sql+ as a simple Query message and reads server messages until
# ReadyForQuery, collecting them into a Result.
#
# Returns the Result (rows, fields, cmd_tag). If any ErrorResponse was
# received, raises a RuntimeError after ReadyForQuery whose message joins
# each error's field values with tabs and the errors with newlines.
def query(sql)
  @conn << Query.dump(sql)

  outcome = Result.new
  failures = []

  loop do
    reply = Message.read(@conn)

    case reply
    when RowDescription
      outcome.fields = reply.fields
    when DataRow
      outcome.rows << reply.columns
    when CommandComplete
      outcome.cmd_tag = reply.cmd_tag
    when ErrorResponse
      # TODO
      failures << reply
    when NoticeResponse
      @notice_processor.call(reply) if @notice_processor
    when ReadyForQuery
      @transaction_status = reply.backend_transaction_status_indicator
      break
    else
      # TODO: CopyInResponse, CopyOutResponse, EmptyQueryResponse and any
      # other message are currently ignored.
    end
  end

  unless failures.empty?
    raise failures.map { |err| err.field_values.join("\t") }.join("\n")
  end

  outcome
end
146
+
147
# Connection defaults.
DEFAULT_PORT = 5432
DEFAULT_HOST = 'localhost'
DEFAULT_PATH = '/tmp'

# Default connection URI: TCP on Windows-style platforms, otherwise the
# Unix-domain socket under DEFAULT_PATH.
#
# The platform test must not be a plain substring check for 'win', because
# that also matches "darwin" and would make macOS default to TCP.
DEFAULT_URI =
  if RUBY_PLATFORM =~ /mswin|mingw|bccwin|cygwin|wince/
    'tcp://' + DEFAULT_HOST + ':' + DEFAULT_PORT.to_s
  else
    'unix:' + File.join(DEFAULT_PATH, '.s.PGSQL.' + DEFAULT_PORT.to_s)
  end
156
+
157
+ private
158
+
159
+ # tcp://localhost:5432
160
+ # unix:/tmp/.s.PGSQL.5432
161
# Opens the socket described by +uri+ and stores it in @conn.
#
# 'tcp' URIs fall back to DEFAULT_HOST / DEFAULT_PORT for a missing host or
# port; 'unix' URIs use the URI path as the socket path. Any other scheme
# raises a RuntimeError and leaves @conn untouched.
def establish_connection(uri)
  target = URI.parse(uri)

  @conn =
    case target.scheme
    when 'unix'
      UNIXSocket.new(target.path)
    when 'tcp'
      TCPSocket.new(target.host || DEFAULT_HOST, target.port || DEFAULT_PORT)
    else
      raise 'unrecognized uri scheme format (must be tcp or unix)'
    end
end
172
+ end
173
+
174
+ end # module PostgresPR
@@ -0,0 +1,542 @@
1
+ #
2
+ # Author:: Michael Neumann
3
+ # Copyright:: (c) 2005 by Michael Neumann
4
+ # License:: Same as Ruby's or BSD
5
+ #
6
+
7
+ require 'buffer'
8
class IO
  # Reads exactly +n+ bytes, issuing further reads while the stream delivers
  # fewer bytes than requested. Raises EOFError as soon as a read returns
  # nil, i.e. the stream ended before +n+ bytes were available.
  def read_exactly_n_bytes(n)
    data = read(n)
    raise EOFError if data.nil?

    remaining = n - data.size
    until remaining <= 0
      chunk = read(remaining)
      raise EOFError if chunk.nil?
      data << chunk
      remaining -= chunk.size
    end

    data
  end
end
25
+
26
+ module PostgresPR
27
+
28
# Raised when an incoming message cannot be decoded.
class ParseError < RuntimeError
end

# Raised when a message cannot be serialized.
class DumpError < RuntimeError
end
30
+
31
+
32
# Base class of all protocol messages. Subclasses register a one-character
# type code, declare their attributes with +fields+ and wrap their body
# encoding/decoding in the blocks given to #dump and #parse.
class Message
  # Type code => message class. Unregistered codes resolve to
  # UnknownMessageType through the default block (nothing is stored).
  MsgTypeMap = Hash.new { UnknownMessageType }

  class << self
    # Registers the receiving class under +type+, stores the code in the
    # class constant MsgType and defines #message_type returning it.
    # Raises RuntimeError when +type+ is already taken.
    def register_message_type(type)
      if MsgTypeMap.has_key?(type)
        raise "duplicate message type registration"
      end

      MsgTypeMap[type] = self

      const_set(:MsgType, type)
      class_eval "def message_type; MsgType end"
    end

    # Reads one complete message from +stream+ and returns the parsed
    # object. With +startup+ true the message has no leading type byte and
    # is decoded as a StartupMessage.
    def read(stream, startup=false)
      type = startup ? nil : stream.read_exactly_n_bytes(1)
      # FIXME: length should be signed, not unsigned
      length, = stream.read_exactly_n_bytes(4).unpack('N')

      # The length field counts itself, so it can never be below 4.
      raise ParseError if length < 4

      # Rebuild the full wire image (type byte, length, body) in a buffer.
      buffer = Buffer.of_size(startup ? length : 1 + length)
      buffer.write(type) unless startup
      buffer.write_int32_network(length)
      buffer.copy_from_stream(stream, length - 4)

      klass = startup ? StartupMessage : MsgTypeMap[type]
      klass.create(buffer)
    end

    # Builds an instance without calling initialize and lets it parse
    # +buffer+.
    def create(buffer)
      message = allocate
      message.parse(buffer)
      message
    end

    # Shortcut for new(*args).dump.
    def dump(*args)
      new(*args).dump
    end

    # Declares attributes: generates accessors and an initialize that takes
    # one positional argument per attribute, in the given order.
    def fields(*attribs)
      names = attribs.map { |name, _type| name.to_s }
      params = names.join(", ")
      ivars = names.map { |n| "@" + n }.join(", ")
      syms = names.map { |n| ":" + n }.join(", ")

      class_eval <<-CODE
        attr_accessor #{ syms }
        def initialize(#{ params })
          #{ ivars } = #{ params }
        end
      CODE
    end
  end

  # Serializes the message: type byte, int32 length (4 + body_size), then
  # whatever the given block writes. Raises DumpError unless exactly
  # +body_size+ body bytes were written. Returns the buffer content.
  def dump(body_size=0)
    buffer = Buffer.of_size(5 + body_size)
    buffer.write(message_type)
    buffer.write_int32_network(4 + body_size)
    yield buffer if block_given?
    raise DumpError unless buffer.at_end?
    buffer.content
  end

  # Positions +buffer+ just past the 5-byte header, yields it for body
  # decoding and raises ParseError if the body was not fully consumed.
  def parse(buffer)
    buffer.position = 5
    yield buffer if block_given?
    raise ParseError, buffer.inspect unless buffer.at_end?
  end
end
99
+
100
# Placeholder class returned for type codes nobody registered.
class UnknownMessageType < Message
  # An unknown message has no type code or layout, so it cannot be
  # serialized. Raises DumpError (a RuntimeError) with an explanatory
  # message; a bare `raise` here would only report "unhandled exception".
  def dump
    raise DumpError, "cannot dump a message of unknown type"
  end
end
105
+
106
# Authentication request family (message type 'R'). The body starts with an
# int32 auth type code that selects the concrete subclass.
class Authentification < Message
  register_message_type 'R'

  # Auth type code => subclass; unregistered codes yield UnknownAuthType.
  AuthTypeMap = Hash.new { UnknownAuthType }

  # Looks at the auth type code following the 5-byte header, allocates the
  # matching subclass and lets it parse the buffer.
  def self.create(buffer)
    buffer.position = 5
    message = AuthTypeMap[buffer.read_int32_network].allocate
    message.parse(buffer)
    message
  end

  # Registers the receiving class for auth type +type+, stores the code in
  # the class constant AuthType and defines #auth_type returning it.
  # Raises RuntimeError when +type+ is already taken.
  def self.register_auth_type(type)
    if AuthTypeMap.has_key?(type)
      raise "duplicate auth type registration"
    end
    AuthTypeMap[type] = self
    const_set(:AuthType, type)
    class_eval "def auth_type() AuthType end"
  end

  # the dump method of class Message
  alias_method :message__dump, :dump

  # Serializes a body that consists of the auth type code only.
  def dump
    super(4) { |buffer| buffer.write_int32_network(auth_type) }
  end

  # Checks that the encoded auth type matches this class (ParseError
  # otherwise), then yields so subclasses can decode extra payload.
  def parse(buffer)
    super do
      raise ParseError if buffer.read_int32_network != auth_type
      yield if block_given?
    end
  end
end
144
+
145
# Returned by Authentification.create for auth type codes that no class has
# registered.
#
# No auth type is registered for this class, so the inherited parse would
# fail with NoMethodError on `auth_type`. Instead the code sent by the
# server is recorded and the remaining payload is left undecoded, so callers
# receive a regular message object they can reject.
class UnknownAuthType < Authentification
  # The unrecognized auth type code, set by #parse.
  attr_reader :auth_type

  def parse(buffer)
    buffer.position = 5
    @auth_type = buffer.read_int32_network
    # Any method-specific payload after the code has no known layout.
  end
end
147
+
148
# Authentication message registered for auth type code 0.
class AuthentificationOk < Authentification
  register_auth_type(0)
end
151
+
152
# Authentication message registered for auth type code 1.
class AuthentificationKerberosV4 < Authentification
  register_auth_type(1)
end
155
+
156
# Authentication message registered for auth type code 2.
class AuthentificationKerberosV5 < Authentification
  register_auth_type(2)
end
159
+
160
# Authentication message registered for auth type code 3.
class AuthentificationClearTextPassword < Authentification
  register_auth_type(3)
end
163
+
164
# Shared behaviour for authentication requests that carry a salt. The
# including class must provide #salt_size (in bytes) and #auth_type.
module SaltedAuthentificationMixin
  attr_accessor :salt

  def initialize(salt)
    @salt = salt
  end

  # Serializes auth type code plus salt. Raises DumpError unless the salt is
  # exactly salt_size bytes long. The check uses bytesize because the salt
  # is raw binary data and a character count can differ from its byte
  # length.
  def dump
    raise DumpError unless @salt.bytesize == self.salt_size

    message__dump(4 + self.salt_size) do |buffer|
      buffer.write_int32_network(self.auth_type)
      buffer.write(@salt)
    end
  end

  # Reads salt_size bytes of salt after the auth type code has been checked
  # by the superclass.
  def parse(buffer)
    super do
      @salt = buffer.read(self.salt_size)
    end
  end
end
186
+
187
# Salted authentication message registered for auth type code 4.
class AuthentificationCryptPassword < Authentification
  register_auth_type(4)
  include SaltedAuthentificationMixin

  # Number of salt bytes carried by this message.
  def salt_size
    2
  end
end
192
+
193
+
194
# Salted authentication message registered for auth type code 5.
class AuthentificationMD5Password < Authentification
  register_auth_type(5)
  include SaltedAuthentificationMixin

  # Number of salt bytes carried by this message.
  def salt_size
    4
  end
end
199
+
200
# Authentication message registered for auth type code 6.
class AuthentificationSCMCredential < Authentification
  register_auth_type(6)
end
203
+
204
# Message type 'p': a single NUL-terminated password string.
class PasswordMessage < Message
  register_message_type 'p'
  fields :password

  # Body size is the password's byte length plus the NUL terminator.
  # bytesize (not size) is used because wire lengths are byte counts and a
  # password may contain multi-byte characters.
  def dump
    super(@password.bytesize + 1) do |buffer|
      buffer.write_cstring(@password)
    end
  end

  def parse(buffer)
    super do
      @password = buffer.read_cstring
    end
  end
end
220
+
221
# Message type 'S': a key/value pair, each encoded as a NUL-terminated
# string.
class ParameterStatus < Message
  register_message_type 'S'
  fields :key, :value

  # Body size is counted in bytes (bytesize), one NUL terminator per
  # string, so values with multi-byte characters serialize correctly.
  def dump
    super(@key.bytesize + 1 + @value.bytesize + 1) do |buffer|
      buffer.write_cstring(@key)
      buffer.write_cstring(@value)
    end
  end

  def parse(buffer)
    super do
      @key = buffer.read_cstring
      @value = buffer.read_cstring
    end
  end
end
239
+
240
# Message type 'K': two int32 values, process id and secret key.
class BackendKeyData < Message
  register_message_type 'K'
  fields :process_id, :secret_key

  # Body is two 4-byte integers.
  def dump
    super(8) do |out|
      out.write_int32_network(@process_id)
      out.write_int32_network(@secret_key)
    end
  end

  def parse(buffer)
    super do
      @process_id = buffer.read_int32_network
      @secret_key = buffer.read_int32_network
    end
  end
end
258
+
259
# Message type 'Z': one byte holding the backend transaction status
# indicator.
class ReadyForQuery < Message
  register_message_type 'Z'
  fields :backend_transaction_status_indicator

  # Body is the single indicator byte.
  def dump
    super(1) { |out| out.write_byte(@backend_transaction_status_indicator) }
  end

  def parse(buffer)
    super { @backend_transaction_status_indicator = buffer.read_byte }
  end
end
275
+
276
# Message type 'D': one result row. +columns+ is an array whose entries are
# strings or nil; nil is encoded with length -1 and no data.
class DataRow < Message
  register_message_type 'D'
  fields :columns

  # Layout: int16 column count, then per column an int32 byte length
  # followed by that many bytes. Lengths use bytesize because column values
  # may contain multi-byte characters and the wire format counts bytes.
  def dump
    sz = @columns.inject(2) { |sum, col| sum + 4 + (col ? col.bytesize : 0) }
    super(sz) do |buffer|
      buffer.write_int16_network(@columns.size)
      @columns.each { |col|
        buffer.write_int32_network(col ? col.bytesize : -1)
        buffer.write(col) if col
      }
    end
  end

  def parse(buffer)
    super do
      n_cols = buffer.read_int16_network
      @columns = (1..n_cols).collect {
        len = buffer.read_int32_network
        if len == -1
          nil
        else
          buffer.read(len)
        end
      }
    end
  end
end
305
+
306
# Message type 'C': a single NUL-terminated command tag string.
class CommandComplete < Message
  register_message_type 'C'
  fields :cmd_tag

  # Body size is the tag's byte length (bytesize) plus the NUL terminator.
  def dump
    super(@cmd_tag.bytesize + 1) do |buffer|
      buffer.write_cstring(@cmd_tag)
    end
  end

  def parse(buffer)
    super do
      @cmd_tag = buffer.read_cstring
    end
  end
end
322
+
323
# Message type 'I'; carries no body.
class EmptyQueryResponse < Message
  register_message_type('I')
end
326
+
327
# Shared encoding for NoticeResponse and ErrorResponse: a leading byte
# (+field_type+) followed, unless that byte is 0, by NUL-terminated strings
# (+field_values+) and a final 0 byte.
module NoticeErrorMixin
  attr_accessor :field_type, :field_values

  # Raises ArgumentError when field values are given together with a zero
  # field type, since that combination cannot be encoded.
  def initialize(field_type=0, field_values=[])
    raise ArgumentError if field_type == 0 && !field_values.empty?
    @field_type, @field_values = field_type, field_values
  end

  # Serializes the message and returns the encoded bytes.
  #
  # `next` (not `break`) leaves the block early for the zero field type:
  # `break` would abort the superclass dump itself and make this method
  # return nil instead of the encoded message. String lengths are counted
  # with bytesize because the wire format counts bytes.
  def dump
    raise ArgumentError if @field_type == 0 && !@field_values.empty?

    sz = 1
    unless @field_type == 0
      sz += @field_values.inject(1) { |sum, fld| sum + fld.bytesize + 1 }
    end

    super(sz) do |buffer|
      buffer.write_byte(@field_type)
      next if @field_type == 0
      @field_values.each { |fld| buffer.write_cstring(fld) }
      buffer.write_byte(0)
    end
  end

  # Decodes field type and field values. @field_values is always set (empty
  # for a zero field type), and the superclass end-of-buffer check always
  # runs because the block is left with `next` rather than `break`.
  # Raises ParseError when the final terminator byte is not 0.
  def parse(buffer)
    super do
      @field_type = buffer.read_byte
      @field_values = []
      next if @field_type == 0

      while buffer.position < buffer.size-1
        @field_values << buffer.read_cstring
      end
      terminator = buffer.read_byte
      raise ParseError unless terminator == 0
    end
  end
end
362
+
363
# Message type 'N'; body handling comes from NoticeErrorMixin.
class NoticeResponse < Message
  register_message_type('N')
  include NoticeErrorMixin
end
367
+
368
# Message type 'E'; body handling comes from NoticeErrorMixin.
class ErrorResponse < Message
  register_message_type('E')
  include NoticeErrorMixin
end
372
+
373
+ # TODO
374
# Message type 'G'. No body decoding is implemented, so only the inherited
# Message#parse applies.
class CopyInResponse < Message
  register_message_type('G')
end
377
+
378
+ # TODO
379
# Message type 'H'. No body decoding is implemented, so only the inherited
# Message#parse applies.
class CopyOutResponse < Message
  register_message_type('H')
end
382
+
383
# Message type 'P': statement name, query text and a list of parameter type
# OIDs.
class Parse < Message
  register_message_type 'P'
  fields :query, :stmt_name, :parameter_oids

  # Overrides the initializer generated by +fields+ to make the statement
  # name and the parameter OIDs optional.
  def initialize(query, stmt_name="", parameter_oids=[])
    @query, @stmt_name, @parameter_oids = query, stmt_name, parameter_oids
  end

  # Layout: statement name and query as NUL-terminated strings, int16 OID
  # count, then one int32 per OID. String lengths use bytesize because the
  # wire format counts bytes and queries may contain multi-byte characters.
  def dump
    sz = @stmt_name.bytesize + 1 + @query.bytesize + 1 + 2 + (4 * @parameter_oids.size)
    super(sz) do |buffer|
      buffer.write_cstring(@stmt_name)
      buffer.write_cstring(@query)
      buffer.write_int16_network(@parameter_oids.size)
      @parameter_oids.each { |oid| buffer.write_int32_network(oid) }
    end
  end

  def parse(buffer)
    super do
      @stmt_name = buffer.read_cstring
      @query = buffer.read_cstring
      n_oids = buffer.read_int16_network
      @parameter_oids = (1..n_oids).collect {
        # TODO: zero means unspecified. map to nil?
        buffer.read_int32_network
      }
    end
  end
end
413
+
414
# Message type '1'; carries no body.
class ParseComplete < Message
  register_message_type('1')
end
417
+
418
# Message type 'Q': a single NUL-terminated query string.
class Query < Message
  register_message_type 'Q'
  fields :query

  # Body size is the query's byte length plus the NUL terminator. bytesize
  # (not size) is required: with a character count, any query containing
  # multi-byte characters would be given too small a body size.
  def dump
    super(@query.bytesize + 1) do |buffer|
      buffer.write_cstring(@query)
    end
  end

  def parse(buffer)
    super do
      @query = buffer.read_cstring
    end
  end
end
434
+
435
# Message type 'T': the list of field descriptions for a result set.
class RowDescription < Message
  register_message_type 'T'
  fields :fields

  # One described field, in wire order.
  class FieldInfo < Struct.new(:name, :oid, :attr_nr, :type_oid, :typlen, :atttypmod, :formatcode); end

  # Layout: int16 field count, then per field a NUL-terminated name followed
  # by 18 bytes of fixed-size integers (4+2+4+2+4+2). Name lengths use
  # bytesize because field names may contain multi-byte characters.
  def dump
    sz = @fields.inject(2) { |sum, fld| sum + 18 + fld.name.bytesize + 1 }
    super(sz) do |buffer|
      buffer.write_int16_network(@fields.size)
      @fields.each { |f|
        buffer.write_cstring(f.name)
        buffer.write_int32_network(f.oid)
        buffer.write_int16_network(f.attr_nr)
        buffer.write_int32_network(f.type_oid)
        buffer.write_int16_network(f.typlen)
        buffer.write_int32_network(f.atttypmod)
        buffer.write_int16_network(f.formatcode)
      }
    end
  end

  def parse(buffer)
    super do
      n_fields = buffer.read_int16_network
      @fields = (1..n_fields).collect {
        f = FieldInfo.new
        f.name       = buffer.read_cstring
        f.oid        = buffer.read_int32_network
        f.attr_nr    = buffer.read_int16_network
        f.type_oid   = buffer.read_int32_network
        f.typlen     = buffer.read_int16_network
        f.atttypmod  = buffer.read_int32_network
        f.formatcode = buffer.read_int16_network
        f
      }
    end
  end
end
474
+
475
# Start-up packet. It has no type byte: int32 total length, int32 protocol
# version, then key/value pairs as NUL-terminated strings and a final 0
# byte.
class StartupMessage < Message
  fields :proto_version, :params

  # Serializes the packet and returns its bytes. Raises DumpError if the
  # written data does not fill the buffer exactly.
  #
  # Key and value lengths are counted with bytesize because the length
  # field counts bytes, and user or database names may contain multi-byte
  # characters.
  def dump
    sz = @params.inject(4 + 4) { |sum, kv| sum + kv[0].bytesize + 1 + kv[1].bytesize + 1 } + 1

    buffer = Buffer.of_size(sz)
    buffer.write_int32_network(sz)
    buffer.write_int32_network(@proto_version)
    @params.each_pair { |key, value|
      buffer.write_cstring(key)
      buffer.write_cstring(value)
    }
    buffer.write_byte(0)

    raise DumpError unless buffer.at_end?
    return buffer.content
  end

  # Decodes protocol version and parameters. Raises ParseError when the
  # final terminator is missing or data is left over.
  def parse(buffer)
    # Skip the 4-byte length field (there is no type byte).
    buffer.position = 4

    @proto_version = buffer.read_int32_network
    @params = {}

    while buffer.position < buffer.size-1
      key = buffer.read_cstring
      val = buffer.read_cstring
      @params[key] = val
    end

    nul = buffer.read_byte
    raise ParseError unless nul == 0
    raise ParseError unless buffer.at_end?
  end
end
511
+
512
# SSL request packet. It has no type byte: int32 total length followed by
# the int32 request code.
class SSLRequest < Message
  fields :ssl_request_code

  # Serializes the fixed 8-byte packet; raises DumpError if the buffer is
  # not filled exactly.
  def dump
    total = 8
    out = Buffer.of_size(total)
    out.write_int32_network(total)
    out.write_int32_network(@ssl_request_code)
    raise DumpError unless out.at_end?
    out.content
  end

  # Reads the request code that follows the 4-byte length field; raises
  # ParseError if any data remains afterwards.
  def parse(buffer)
    buffer.position = 4
    @ssl_request_code = buffer.read_int32_network
    raise ParseError unless buffer.at_end?
  end
end
530
+
531
+ =begin
532
+ # TODO: duplicate message-type, split into client/server messages
533
+ class Sync < Message
534
+ register_message_type 'S'
535
+ end
536
+ =end
537
+
538
# Message type 'X'; carries no body.
class Terminate < Message
  register_message_type('X')
end
541
+
542
+ end # module PostgresPR