jeremyevans-postgres-pr 0.6.1

Sign up to get free protection for your applications and to get access to all the features.
@@ -0,0 +1,544 @@
1
+ #
2
+ # Author:: Michael Neumann
3
+ # Copyright:: (c) 2005 by Michael Neumann
4
+ # License:: Same as Ruby's or BSD
5
+ #
6
+
7
+ require 'buffer'
8
class IO
  # Reads exactly +n+ bytes from the stream.
  #
  # A single +read+ may hand back fewer bytes than requested, so reading is
  # repeated until the request is satisfied.
  #
  # Returns the bytes read as a String of size +n+.
  # Raises EOFError when the stream ends before +n+ bytes were read.
  def read_exactly_n_bytes(n)
    buf = read(n)
    raise EOFError if buf.nil?

    while buf.size < n
      chunk = read(n - buf.size)
      raise EOFError if chunk.nil?
      buf << chunk
    end

    buf
  end
end
25
+
26
+
27
+ module PostgresPR
28
+
29
# Root of the error hierarchy used by this protocol layer.
class PGError < StandardError
end

# Raised by the parse paths when a message cannot be decoded.
class ParseError < PGError
end

# Raised by the dump paths when a message cannot be serialized.
class DumpError < PGError
end
32
+
33
+
34
# Base class representing a PostgreSQL protocol message.
#
# Subclasses bind themselves to a one-character typecode with
# +register_message_type+ and usually override +dump+ and +parse+, passing a
# block to +super+ that writes or reads the message body.
class Message
  # One character message-typecode to class map. Typecodes nobody registered
  # resolve to UnknownMessageType.
  MsgTypeMap = Hash.new { UnknownMessageType }

  # Registers the receiving class for the typecode +type+, stores it in the
  # class constant +MsgType+ and defines +message_type+ to return it.
  #
  # Raises PGError if +type+ is already registered.
  def self.register_message_type(type)
    raise(PGError, "duplicate message type registration") if MsgTypeMap.has_key?(type)

    MsgTypeMap[type] = self

    const_set(:MsgType, type)
    define_method(:message_type) { type }
  end

  # Reads one message from +stream+ and returns the parsed message object.
  #
  # stream  - object responding to +read_exactly_n_bytes+.
  # startup - when true, no typecode byte precedes the length and the
  #           message is parsed as a StartupMessage.
  #
  # Raises ParseError if the announced length is smaller than 4.
  def self.read(stream, startup=false)
    type = stream.read_exactly_n_bytes(1) unless startup
    length = stream.read_exactly_n_bytes(4).unpack('N').first

    # The length field is a signed 32-bit integer. 'N' decodes it as
    # unsigned, so fold the upper half back to negative values; otherwise a
    # negative length would pass the check below and size a huge buffer.
    length -= 0x1_0000_0000 if length > 0x7FFF_FFFF

    raise ParseError unless length >= 4

    # The buffer holds the full message: optional typecode, the length
    # itself, then the remaining (length - 4) body bytes.
    buffer = Buffer.of_size(startup ? length : 1+length)
    buffer.write(type) unless startup
    buffer.write_int32_network(length)
    buffer.copy_from_stream(stream, length-4)

    (startup ? StartupMessage : MsgTypeMap[type]).create(buffer)
  end

  # Builds an instance without calling +initialize+ and fills it by parsing
  # +buffer+.
  def self.create(buffer)
    obj = allocate
    obj.parse(buffer)
    obj
  end

  # Convenience: instantiate with +args+ and return the serialized message.
  def self.dump(*args)
    new(*args).dump
  end

  # Serializes the message: typecode, length (4 + body_size), then whatever
  # the given block writes into the yielded buffer.
  #
  # Raises DumpError unless the block filled the buffer exactly.
  def dump(body_size=0)
    buffer = Buffer.of_size(5 + body_size)
    buffer.write(self.message_type)
    buffer.write_int32_network(4 + body_size)
    yield buffer if block_given?
    raise DumpError unless buffer.at_end?
    buffer.content
  end

  # Positions +buffer+ after the 5-byte header (typecode + length) and yields
  # it so the subclass can read the body.
  #
  # Raises ParseError unless the whole buffer was consumed.
  def parse(buffer)
    buffer.position = 5
    yield buffer if block_given?
    raise ParseError, buffer.inspect unless buffer.at_end?
  end

  # Declares the message's attributes: defines an accessor per name and an
  # +initialize+ taking the values positionally in the given order.
  def self.fields(*attribs)
    names = attribs.map {|name, _type| name.to_s}
    arg_list = names.join(", ")
    ivar_list = names.map {|name| "@" + name }.join(", ")
    sym_list = names.map {|name| ":" + name }.join(", ")
    class_eval %[
      attr_accessor #{ sym_list }
      def initialize(#{ arg_list })
        #{ ivar_list } = #{ arg_list }
      end
    ]
  end
end
101
+
102
# Message class with no registered typecode of its own.
class UnknownMessageType < Message
  # Serializing an unknown message is not supported; always raises PGError.
  def dump
    raise PGError, "dumping unknown message"
  end
end
107
+
108
# Common base of the authentication messages (typecode 'R'). The concrete
# class is selected by the 32-bit auth code that follows the header.
class Authentification < Message
  register_message_type 'R'

  # Auth code to class map; unregistered codes resolve to UnknownAuthType.
  AuthTypeMap = Hash.new { UnknownAuthType }

  # Peeks at the auth code stored right after the 5-byte header, allocates
  # the class registered for it and lets that instance parse +buffer+.
  def self.create(buffer)
    buffer.position = 5
    code = buffer.read_int32_network
    message = AuthTypeMap[code].allocate
    message.parse(buffer)
    message
  end

  # Registers the receiving class for auth code +type+, stores it in the
  # class constant +AuthType+ and defines +auth_type+.
  #
  # Raises PGError if +type+ is already registered.
  def self.register_auth_type(type)
    if AuthTypeMap.has_key?(type)
      raise(PGError, "duplicate auth type registration")
    end

    AuthTypeMap[type] = self
    const_set(:AuthType, type)
    class_eval "def auth_type() AuthType end"
  end

  # Keeps Message#dump reachable under another name for code that needs to
  # bypass the override below.
  alias message__dump dump

  # Serializes a body consisting only of the 4-byte auth code.
  def dump
    super(4) { |buffer| buffer.write_int32_network(auth_type) }
  end

  # Reads the auth code, checks it against this class's +auth_type+ and then
  # yields (without arguments) so a subclass can read further fields.
  #
  # Raises ParseError on an auth code mismatch.
  def parse(buffer)
    super(buffer) do |buf|
      code = buf.read_int32_network
      raise ParseError unless code == auth_type
      yield if block_given?
    end
  end
end
146
+
147
# Authentication message with no registered auth code of its own.
class UnknownAuthType < Authentification; end
149
+
150
# Authentication message registered for auth code 0.
class AuthentificationOk < Authentification
  register_auth_type(0)
end
153
+
154
# Authentication message registered for auth code 1.
class AuthentificationKerberosV4 < Authentification
  register_auth_type(1)
end
157
+
158
# Authentication message registered for auth code 2.
class AuthentificationKerberosV5 < Authentification
  register_auth_type(2)
end
161
+
162
# Authentication message registered for auth code 3.
class AuthentificationClearTextPassword < Authentification
  register_auth_type(3)
end
165
+
166
# Adds a fixed-size +salt+ field to an authentication message. The including
# class must provide +salt_size+, +auth_type+ and +message__dump+.
module SaltedAuthentificationMixin
  attr_accessor :salt

  def initialize(salt)
    @salt = salt
  end

  # Serializes the auth code followed by the salt.
  #
  # Raises DumpError if the salt does not have exactly +salt_size+ elements.
  def dump
    expected = salt_size
    raise DumpError unless @salt.size == expected

    message__dump(4 + salt_size) do |out|
      out.write_int32_network(auth_type)
      out.write(@salt)
    end
  end

  # Lets the parent parse the message and reads +salt_size+ bytes as salt
  # when the parent yields.
  def parse(buffer)
    super(buffer) { @salt = buffer.read(salt_size) }
  end
end
188
+
189
# Salted authentication message registered for auth code 4.
class AuthentificationCryptPassword < Authentification
  register_auth_type(4)
  include SaltedAuthentificationMixin

  # Size of the salt carried by this message.
  def salt_size
    2
  end
end
194
+
195
+
196
# Salted authentication message registered for auth code 5.
class AuthentificationMD5Password < Authentification
  register_auth_type(5)
  include SaltedAuthentificationMixin

  # Size of the salt carried by this message.
  def salt_size
    4
  end
end
201
+
202
# Authentication message registered for auth code 6.
class AuthentificationSCMCredential < Authentification
  register_auth_type(6)
end
205
+
206
# Message with typecode 'p' carrying a single C-string: the password.
class PasswordMessage < Message
  register_message_type 'p'
  fields :password

  # Serializes the password as a C-string.
  def dump
    # +1: presumably the terminator written by write_cstring.
    body_size = @password.size + 1
    super(body_size) { |out| out.write_cstring(@password) }
  end

  # Reads the password C-string from the message body.
  def parse(buffer)
    super(buffer) { |input| @password = input.read_cstring }
  end
end
222
+
223
# Message with typecode 'S' carrying a key/value pair of C-strings.
class ParameterStatus < Message
  register_message_type 'S'
  fields :key, :value

  # Serializes the key followed by the value, both as C-strings.
  def dump
    key_len = @key.size + 1
    value_len = @value.size + 1
    super(key_len + value_len) do |out|
      out.write_cstring(@key)
      out.write_cstring(@value)
    end
  end

  # Reads the key and then the value from the message body.
  def parse(buffer)
    super(buffer) do |input|
      @key = input.read_cstring
      @value = input.read_cstring
    end
  end
end
241
+
242
# Message with typecode 'K' carrying two 32-bit integers: a process id and a
# secret key.
class BackendKeyData < Message
  register_message_type 'K'
  fields :process_id, :secret_key

  # Serializes the process id followed by the secret key.
  def dump
    super(4 + 4) do |out|
      [@process_id, @secret_key].each { |value| out.write_int32_network(value) }
    end
  end

  # Reads the process id and then the secret key from the message body.
  def parse(buffer)
    super(buffer) do |input|
      @process_id = input.read_int32_network
      @secret_key = input.read_int32_network
    end
  end
end
260
+
261
# Message with typecode 'Z' carrying a single status byte.
class ReadyForQuery < Message
  register_message_type 'Z'
  fields :backend_transaction_status_indicator

  # Serializes the one-byte status indicator.
  def dump
    super(1) { |out| out.write_byte(@backend_transaction_status_indicator) }
  end

  # Reads the one-byte status indicator from the message body.
  def parse(buffer)
    super(buffer) { |input| @backend_transaction_status_indicator = input.read_byte }
  end
end
277
+
278
# Message with typecode 'D' carrying one row: a 16-bit column count followed
# by, per column, a 32-bit length and that many bytes. A length of -1 stands
# for a nil column and carries no bytes.
class DataRow < Message
  register_message_type 'D'
  fields :columns

  # Serializes the column count and every column value.
  def dump
    body_size = 2
    @columns.each { |value| body_size += 4 + (value ? value.size : 0) }

    super(body_size) do |out|
      out.write_int16_network(@columns.size)
      @columns.each do |value|
        if value
          out.write_int32_network(value.size)
          out.write(value)
        else
          out.write_int32_network(-1)
        end
      end
    end
  end

  # Reads the column count and then each column value into +columns+.
  def parse(buffer)
    super(buffer) do |input|
      count = input.read_int16_network
      @columns = (1..count).map do
        length = input.read_int32_network
        length == -1 ? nil : input.read(length)
      end
    end
  end
end
307
+
308
# Message with typecode 'C' carrying a single C-string: the command tag.
class CommandComplete < Message
  register_message_type 'C'
  fields :cmd_tag

  # Serializes the command tag as a C-string.
  def dump
    body_size = @cmd_tag.size + 1
    super(body_size) { |out| out.write_cstring(@cmd_tag) }
  end

  # Reads the command tag from the message body.
  def parse(buffer)
    super(buffer) { |input| @cmd_tag = input.read_cstring }
  end
end
324
+
325
# Body-less message registered for typecode 'I'.
class EmptyQueryResponse < Message
  register_message_type('I')
end
328
+
329
# Shared body handling for notice and error messages: a +field_type+ byte
# followed, unless it is 0, by a list of C-strings and a terminating 0 byte.
module NoticeErrorMixin
  attr_accessor :field_type, :field_values

  # field_type   - leading byte of the body; 0 means "no fields".
  # field_values - Array of strings; must be empty when field_type is 0.
  #
  # Raises PGError if values are given together with field_type 0.
  def initialize(field_type=0, field_values=[])
    raise PGError if field_type == 0 && !field_values.empty?
    @field_type, @field_values = field_type, field_values
  end

  # Serializes the message and returns its content.
  #
  # Raises PGError if values are present together with field_type 0.
  def dump
    raise PGError if @field_type == 0 && !@field_values.empty?

    sz = 1
    sz += @field_values.inject(1) {|sum, fld| sum + fld.size + 1} unless @field_type == 0

    super(sz) do |buffer|
      buffer.write_byte(@field_type)
      # `next` (not `break`) so control returns to the parent's dump, which
      # still has to verify the buffer and return its content.
      next if @field_type == 0
      @field_values.each {|fld| buffer.write_cstring(fld) }
      buffer.write_byte(0)
    end
  end

  # Parses the body into +field_type+ and +field_values+.
  #
  # Raises ParseError if the terminating byte is not 0.
  def parse(buffer)
    super(buffer) do |buf|
      @field_type = buf.read_byte
      @field_values = []
      # `next` (not `break`) so the parent's parse still runs its
      # end-of-buffer check for the empty form.
      next if @field_type == 0
      while buf.position < buf.size-1
        @field_values << buf.read_cstring
      end
      terminator = buf.read_byte
      raise ParseError unless terminator == 0
    end
  end
end
364
+
365
# Message registered for typecode 'N'; body handling comes from
# NoticeErrorMixin.
class NoticeResponse < Message
  register_message_type('N')
  include NoticeErrorMixin
end
369
+
370
# Message registered for typecode 'E'; body handling comes from
# NoticeErrorMixin.
class ErrorResponse < Message
  register_message_type('E')
  include NoticeErrorMixin
end
374
+
375
# Message registered for typecode 'G'.
# TODO: no body handling is implemented yet.
class CopyInResponse < Message
  register_message_type('G')
end
379
+
380
# Message registered for typecode 'H'.
# TODO: no body handling is implemented yet.
class CopyOutResponse < Message
  register_message_type('H')
end
384
+
385
# Message with typecode 'P' carrying a statement name, a query string and a
# list of 32-bit parameter OIDs.
class Parse < Message
  register_message_type 'P'
  fields :query, :stmt_name, :parameter_oids

  # Replaces the generated constructor so that the statement name and the
  # OID list are optional.
  def initialize(query, stmt_name="", parameter_oids=[])
    @query = query
    @stmt_name = stmt_name
    @parameter_oids = parameter_oids
  end

  # Serializes statement name, query, OID count and the OIDs, in that order.
  def dump
    strings_size = (@stmt_name.size + 1) + (@query.size + 1)
    oids_size = 2 + (4 * @parameter_oids.size)

    super(strings_size + oids_size) do |out|
      out.write_cstring(@stmt_name)
      out.write_cstring(@query)
      out.write_int16_network(@parameter_oids.size)
      @parameter_oids.each { |oid| out.write_int32_network(oid) }
    end
  end

  # Reads statement name, query and the OID list from the message body.
  def parse(buffer)
    super(buffer) do |input|
      @stmt_name = input.read_cstring
      @query = input.read_cstring
      count = input.read_int16_network
      # TODO: zero means unspecified. map to nil?
      @parameter_oids = (1..count).map { input.read_int32_network }
    end
  end
end
415
+
416
# Body-less message registered for typecode '1'.
class ParseComplete < Message
  register_message_type('1')
end
419
+
420
# Message with typecode 'Q' carrying a single C-string: the query text.
class Query < Message
  register_message_type 'Q'
  fields :query

  # Serializes the query text as a C-string.
  def dump
    body_size = @query.size + 1
    super(body_size) { |out| out.write_cstring(@query) }
  end

  # Reads the query text from the message body.
  def parse(buffer)
    super(buffer) { |input| @query = input.read_cstring }
  end
end
436
+
437
# Message with typecode 'T' describing the fields of a result: a 16-bit count
# followed by one FieldInfo record per field.
class RowDescription < Message
  register_message_type 'T'
  fields :fields

  # One field description, in wire order.
  class FieldInfo < Struct.new(:name, :oid, :attr_nr, :type_oid, :typlen, :atttypmod, :formatcode); end

  # Serializes the field count and every field description.
  def dump
    # Per field: 18 bytes of fixed-width integers (4+2+4+2+4+2) plus the
    # name and one extra byte.
    body_size = 2
    @fields.each { |info| body_size += 18 + info.name.size + 1 }

    super(body_size) do |out|
      out.write_int16_network(@fields.size)
      @fields.each do |info|
        out.write_cstring(info.name)
        out.write_int32_network(info.oid)
        out.write_int16_network(info.attr_nr)
        out.write_int32_network(info.type_oid)
        out.write_int16_network(info.typlen)
        out.write_int32_network(info.atttypmod)
        out.write_int16_network(info.formatcode)
      end
    end
  end

  # Reads the field count and then each field description into +fields+.
  def parse(buffer)
    super(buffer) do |input|
      count = input.read_int16_network
      @fields = (1..count).map do
        info = FieldInfo.new
        info.name       = input.read_cstring
        info.oid        = input.read_int32_network
        info.attr_nr    = input.read_int16_network
        info.type_oid   = input.read_int32_network
        info.typlen     = input.read_int16_network
        info.atttypmod  = input.read_int32_network
        info.formatcode = input.read_int16_network
        info
      end
    end
  end
end
476
+
477
# Typecode-less message: a 32-bit length, a 32-bit protocol version, a list
# of key/value C-string pairs and a terminating 0 byte.
class StartupMessage < Message
  fields :proto_version, :params

  # Serializes the message and returns its content.
  #
  # Raises DumpError unless the computed size was filled exactly.
  def dump
    # length + version + terminating byte, plus both strings of each pair
    # with one extra byte apiece.
    total = 4 + 4 + 1
    @params.each_pair { |key, value| total += key.size + 1 + value.size + 1 }

    out = Buffer.of_size(total)
    out.write_int32_network(total)
    out.write_int32_network(@proto_version)
    @params.each_pair do |key, value|
      out.write_cstring(key)
      out.write_cstring(value)
    end
    out.write_byte(0)

    raise DumpError unless out.at_end?
    out.content
  end

  # Parses +buffer+, starting after the 4-byte length, into +proto_version+
  # and the +params+ hash.
  #
  # Raises ParseError if the terminating byte is not 0 or bytes are left over.
  def parse(buffer)
    buffer.position = 4

    @proto_version = buffer.read_int32_network
    @params = {}

    # Everything up to the final byte is key/value pairs.
    until buffer.position >= buffer.size-1
      key = buffer.read_cstring
      @params[key] = buffer.read_cstring
    end

    terminator = buffer.read_byte
    raise ParseError unless terminator == 0
    raise ParseError unless buffer.at_end?
  end
end
513
+
514
# Typecode-less message: a 32-bit length followed by a 32-bit request code.
class SSLRequest < Message
  fields :ssl_request_code

  # Serializes the 8-byte message and returns its content.
  #
  # Raises DumpError unless the buffer was filled exactly.
  def dump
    total = 4 + 4
    out = Buffer.of_size(total)
    [total, @ssl_request_code].each { |value| out.write_int32_network(value) }
    raise DumpError unless out.at_end?
    out.content
  end

  # Reads the request code that follows the 4-byte length.
  #
  # Raises ParseError if bytes are left over.
  def parse(buffer)
    buffer.position = 4
    @ssl_request_code = buffer.read_int32_network
    raise ParseError unless buffer.at_end?
  end
end
532
+
533
+ =begin
534
+ # TODO: duplicate message-type, split into client/server messages
535
+ class Sync < Message
536
+ register_message_type 'S'
537
+ end
538
+ =end
539
+
540
# Body-less message registered for typecode 'X'.
class Terminate < Message
  register_message_type('X')
end
543
+
544
+ end # module PostgresPR
@@ -0,0 +1,140 @@
1
+ # This is a compatibility layer for using the pure Ruby postgres-pr instead of
2
+ # the C interface of postgres.
3
+
4
+ require 'postgres-pr/connection'
5
+
6
# Connection class mimicking the interface of the C postgres binding on top
# of PostgresPR::Connection.
class PGconn
  class << self
    alias connect new
  end

  attr_reader :host, :db, :user

  # Opens a connection. +options+ and +tty+ are accepted but not used.
  #
  # A nil +host+ passes a nil URI to the connection; a host starting with
  # "/" is treated as a unix socket directory; anything else is a TCP host.
  def initialize(host, port, options, tty, database, user, auth)
    uri = nil
    unless host.nil?
      uri = if host[0] != ?/
              "tcp://#{ host }:#{ port }"
            else
              "unix:#{ host }/.s.PGSQL.#{ port }"
            end
    end

    @host, @db, @user = host, database, user
    @conn = PostgresPR::Connection.new(database, user, auth, uri)
  end

  # Closes the underlying connection.
  def close
    @conn.close
  end

  # Runs +sql+ and wraps the outcome in a PGresult.
  def query(sql)
    raw = @conn.query(sql)
    PGresult.new(raw)
  end

  alias exec query

  # Delegates to the underlying connection.
  def transaction_status
    @conn.transaction_status
  end

  # Doubles every single quote and every backslash in +str+.
  def self.escape(str)
    quoted = str.gsub("'","''")
    # The replacement string is itself backslash-processed by gsub, so these
    # four backslashes emit two.
    quoted.gsub("\\", "\\\\\\\\")
  end
end
47
+
48
# Result class mimicking the interface of the C postgres binding. Wraps a
# result object that responds to +fields+, +rows+ and +cmd_tag+.
class PGresult
  include Enumerable

  EMPTY_QUERY = 0
  COMMAND_OK = 1
  TUPLES_OK = 2
  COPY_OUT = 3
  COPY_IN = 4
  BAD_RESPONSE = 5
  NONFATAL_ERROR = 6
  FATAL_ERROR = 7

  # Iterates over the rows.
  def each(&block)
    @result.each(&block)
  end

  # Returns the row at +index+.
  def [](index)
    @result[index]
  end

  # res - result object; its field descriptions must respond to +name+.
  def initialize(res)
    @res = res
    @fields = @res.fields.map {|f| f.name}
    @result = @res.rows
  end

  # TODO: getlength

  attr_reader :result, :fields

  # Number of rows.
  def num_tuples
    @result.size
  end

  # Number of fields.
  def num_fields
    @fields.size
  end

  # Name of the field at +index+.
  def fieldname(index)
    @fields[index]
  end

  # Index of the field called +name+, or nil.
  def fieldnum(name)
    @fields.index(name)
  end

  # Type OID of the field at +index+.
  def type(index)
    # TODO: correct?
    @res.fields[index].type_oid
  end

  # Not implemented; always raises PGError.
  def size(index)
    # TODO: returning @res.fields[index].typlen was considered here, but it
    # is unclear whether that is the right value.
    raise PGError, 'size not implemented'
  end

  # Value of field +field_num+ in row +tup_num+.
  def getvalue(tup_num, field_num)
    @result[tup_num][field_num]
  end

  # Returns TUPLES_OK when the result has rows or field descriptions, and
  # COMMAND_OK otherwise.
  #
  # A row-returning query that matched nothing still has field descriptions,
  # so it reports TUPLES_OK (assumes +fields+ is empty for commands that
  # return no rows; confirm against the connection layer).
  def status
    if num_tuples > 0 || num_fields > 0
      TUPLES_OK
    else
      COMMAND_OK
    end
  end

  # Command tag of the result, or '' when there is none.
  def cmdstatus
    @res.cmd_tag || ''
  end

  # free the result set
  def clear
    @res = @fields = @result = nil
  end

  # Returns the number of rows affected by the SQL command, or nil when the
  # command tag is missing or carries no row count.
  def cmdtuples
    case @res.cmd_tag
    when /^INSERT\s+(\d+)\s+(\d+)$/, /^(DELETE|UPDATE|MOVE|FETCH)\s+(\d+)$/
      # In both patterns the row count is the second capture group.
      Regexp.last_match(2).to_i
    end
  end
end
139
+
140
+ PGError = PostgresPR::PGError
@@ -0,0 +1,18 @@
1
+ require 'test/unit'
2
+ require 'conv'
3
+ require 'array'
4
+ require 'bytea'
5
+
6
# Unit tests for the helpers mixed in from Postgres::Conversion.
class TC_Conversion < Test::Unit::TestCase
  include Postgres::Conversion

  # decode_array turns an array literal into (nested) Ruby arrays of strings.
  def test_decode_array
    nested = decode_array("{ abcdef , hallo, { 1, 2} }")
    assert_equal(["abcdef ", "hallo", ["1", "2"]], nested)

    assert_equal([""], decode_array("{ }")) # TODO: Correct?
    assert_equal([], decode_array("{}"))
    assert_equal(["hallo", ""], decode_array("{hallo,}"))
  end

  # Placeholder: no bytea assertions have been written yet.
  def test_bytea
  end
end