rhubarb 0.2.0

Sign up to get free protection for your applications and to get access to all the features.
@@ -0,0 +1,545 @@
1
+ # Rhubarb is a simple persistence layer for Ruby objects and SQLite.
2
+ # For now, see the test cases for example usage.
3
+ #
4
+ # Copyright (c) 2009 Red Hat, Inc.
5
+ #
6
+ # Author: William Benton (willb@redhat.com)
7
+ #
8
+ # Licensed under the Apache License, Version 2.0 (the "License");
9
+ # you may not use this file except in compliance with the License.
10
+ # You may obtain a copy of the License at
11
+ #
12
+ # http://www.apache.org/licenses/LICENSE-2.0
13
+
14
+ require 'rubygems'
15
+ require 'set'
16
+ require 'time'
17
+ require 'sqlite3'
18
+
19
+ module Rhubarb
20
+
21
# Small utility helpers used by the persistence layer.
module SQLBUtil
  # Returns +tm+ as an integer number of microseconds
  # (<tt>tv_sec * 1_000_000 + tv_usec</tt>).
  # +tm+:: a Time; defaults to the current time when +nil+ or omitted.
  def self.timestamp(tm=nil)
    tm ||= Time.now.utc
    (tm.tv_sec * 1_000_000) + tm.tv_usec
  end
end
27
+
28
+
29
# Registry of open database handles, keyed by an arbitrary name
# (+:default+ unless stated otherwise).
module Persistence
  # A Hash of database handles. Every non-nil handle stored through +[]=+
  # has +results_as_hash+ and +type_translation+ switched on first.
  class DbCollection < Hash
    alias orig_set []=

    # Stores +v+ under +k+, configuring +v+ (when it is not nil/false) to
    # return rows as hashes and to translate column types.
    def []=(k, v)
      if v
        v.results_as_hash = true
        v.type_translation = true
      end
      orig_set(k, v)
    end
  end

  # Held in a module-level instance variable (rather than a class variable)
  # so the registry is owned by this module alone; access it through +dbs+.
  @dbs = DbCollection.new

  # Opens the SQLite database at +filename+ and registers it under +which+.
  def self.open(filename, which=:default)
    dbs[which] = SQLite3::Database.new(filename)
  end

  # Closes and unregisters the handle stored under +which+.
  # Does nothing when no handle is registered under that key.
  def self.close(which=:default)
    handle = dbs[which]
    return unless handle

    handle.close
    dbs.delete(which)
  end

  # Returns the handle registered under +:default+ (or +nil+).
  def self.db
    dbs[:default]
  end

  # Registers +d+ as the +:default+ handle.
  def self.db=(d)
    dbs[:default] = d
  end

  # Returns the DbCollection of all registered handles.
  def self.dbs
    @dbs
  end
end
65
+
66
# Models one column of a table declaration: a name, a kind (SQL type) and
# a list of qualifiers.
class Column
  attr_reader :name

  # +name+:: the column name
  # +kind+:: the column type, emitted verbatim in the DDL
  # +quals+:: a list of qualifiers. Only Symbol qualifiers are kept; they
  #           are rendered with underscores replaced by spaces
  #           (e.g. +:primary_key+ becomes "primary key").
  #
  # NOTE(review): non-Symbol qualifiers (such as Reference objects) have
  # always been left out of the column DDL by this class; that is preserved
  # here, but they are now skipped outright instead of being stored as nil,
  # which used to leave stray spaces in the generated text.
  def initialize(name, kind, quals)
    @name, @kind = name, kind
    @quals = quals.grep(Symbol).map {|q| q.to_s.tr("_", " ") }
  end

  # Returns the column's fragment of a CREATE TABLE statement:
  # "name kind" followed by any qualifiers, separated by single spaces.
  def to_s
    ([@name, @kind] + @quals).join(" ")
  end
end
84
+
85
# Models a foreign-key relationship to another table.
class Reference
  attr_reader :referent, :column, :options

  # Creates a new Reference object, modeling a foreign-key relationship to another table. +klass+ is a class that includes Persisting; +options+ is a hash of options, which include
  # +:column => + _name_:: specifies the name of the column to reference in +klass+ (defaults to row id)
  # +:on_delete => :cascade+:: specifies that deleting the referenced row in +klass+ will delete all rows referencing that row through this reference
  #
  # The options hash is copied, so the caller's hash is never modified.
  def initialize(klass, options={})
    @referent = klass
    @options = options.dup
    @options[:column] ||= "row_id"
    @column = @options[:column]
  end

  # Returns the "references ..." clause for this relationship, with
  # " on delete cascade" appended when +:on_delete+ is +:cascade+.
  def to_s
    trigger = options[:on_delete] == :cascade ? " on delete cascade" : ""
    "references #{@referent}(#{@column})#{trigger}"
  end

  # True when the referent is a class that includes Persisting. Strings,
  # and anything else that has no +ancestors+ list, are never managed.
  def managed_ref?
    # XXX?
    return false if referent.is_a?(String)
    return false unless referent.respond_to?(:ancestors)

    referent.ancestors.include? Persisting
  end
end
110
+
111
+
112
# Methods mixed in to the class object of a persisting class
module PersistingClassMixins
  # Returns the name of the database table modeled by this class.
  # Defaults to the name of the class (sans module names)
  def table_name
    @table_name ||= self.name.split("::").pop.downcase
  end

  # Enables setting the table name to a custom name
  def declare_table_name(nm)
    @table_name = nm
  end

  # Models a foreign-key relationship. +options+ is a hash of options, which include
  # +:column => + _name_:: specifies the name of the column to reference in +klass+ (defaults to row id)
  # +:on_delete => :cascade+:: specifies that deleting the referenced row in +klass+ will delete all rows referencing that row through this reference
  def references(table, options={})
    Reference.new(table, options)
  end

  # Models a CHECK constraint.
  def check(condition)
    "check (#{condition})"
  end

  # Returns an object corresponding to the row with the given ID, or +nil+ if no such row exists.
  def find(id)
    tup = self.find_tuple(id)
    tup ? self.new(tup) : nil
  end

  alias find_by_id find

  # Returns a list of objects for the rows whose columns equal the values in
  # +arg_hash+ (a hash from column symbols to values), ordered by row id.
  # Values that respond to +row_id+ are replaced by their row id. Keys that
  # are not columns of this table are ignored and are not bound to the
  # statement, matching what find_freshest does with +:select_by+.
  def find_by(arg_hash)
    ensure_accessors
    valid_cols = self.colnames.intersection arg_hash.keys
    select_criteria = valid_cols.map {|col| "#{col} = #{col.inspect}"}.join(" AND ")

    # Bind only the columns named in the WHERE clause.
    params = {}
    valid_cols.each do |col|
      val = arg_hash[col]
      params[col] = val.respond_to?(:row_id) ? val.row_id : val
    end

    self.db.execute("select * from #{table_name} where #{select_criteria} order by row_id", params).map {|tup| self.new(tup) }
  end

  # args contains the following keys
  # * :group_by maps to a list of columns to group by (mandatory)
  # * :select_by maps to a hash mapping from column symbols to values (optional)
  # * :version maps to some version to be considered "current" for the purposes of this query; that is, all rows later than the "current" version will be disregarded (optional, defaults to latest version)
  def find_freshest(args)
    ensure_accessors
    args = args.dup

    args[:version] ||= SQLBUtil::timestamp
    args[:select_by] ||= {}

    query_params = {:version => args[:version]}
    select_clauses = ["created <= :version"]

    valid_cols = self.colnames.intersection args[:select_by].keys

    valid_cols.each do |col|
      select_clauses << "#{col} = #{col.inspect}"
      val = args[:select_by][col]
      val = val.row_id if val.respond_to? :row_id
      query_params[col] = val
    end

    group_by_clause = "GROUP BY " + args[:group_by].join(", ")
    where_clause = "WHERE " + select_clauses.join(" AND ")
    projection = self.colnames - [:created]
    join_clause = projection.map do |column|
      "__fresh.#{column} = __freshest.#{column}"
    end

    # The newest "created" value per group identifies the freshest row.
    projection << "MAX(created) AS __current_version"
    join_clause << "__fresh.__current_version = __freshest.created"

    query = "
    SELECT __freshest.* FROM (
      SELECT #{projection.to_a.join(', ')} FROM (
        SELECT * from #{table_name} #{where_clause}
      ) #{group_by_clause}
    ) as __fresh INNER JOIN #{table_name} as __freshest ON
      #{join_clause.join(' AND ')}
      ORDER BY row_id
    "

    self.db.execute(query, query_params).map {|tup| self.new(tup) }
  end

  # Does what it says on the tin. Since this will allocate an object for each row, it isn't recommended for huge tables.
  def find_all
    self.db.execute("SELECT * from #{table_name}").map {|tup| self.new(tup)}
  end

  # Deletes every row of the table backing this class.
  def delete_all
    self.db.execute("DELETE from #{table_name}")
  end

  # Declares a query method named +name+ and adds it to this class. The query method returns a list of objects corresponding to the rows returned by executing "+SELECT * FROM+ _table_ +WHERE+ _query_" on the database.
  def declare_query(name, query)
    klass = (class << self; self end)
    klass.class_eval do
      define_method name.to_s do |*args|
        # handle reference parameters
        args = args.map {|x| (x.row_id if x.class.ancestors.include? Persisting) or x}

        res = self.db.execute("select * from #{table_name} where #{query}", args)
        res.map {|row| self.new(row)}
      end
    end
  end

  # Declares a custom query method named +name+, and adds it to this class. The custom query method returns a list of objects corresponding to the rows returned by executing +query+ on the database. +query+ should select all fields (with +SELECT *+). If +query+ includes the string +\_\_TABLE\_\_+, it will be expanded to the table name. Typically, you will want to use +declare\_query+ instead; this method is most useful for self-joins.
  def declare_custom_query(name, query)
    klass = (class << self; self end)
    klass.class_eval do
      define_method name.to_s do |*args|
        # handle reference parameters
        args = args.map {|x| (x.row_id if x.class.ancestors.include? Persisting) or x}

        res = self.db.execute(query.gsub("__TABLE__", "#{self.table_name}"), args)
        # XXX: should freshen each row?
        res.map {|row| self.new(row) }
      end
    end
  end

  # Registers an index over +fields+, to be created when create_table runs.
  # Does nothing when no fields are given. May be called before any column
  # has been declared.
  def declare_index_on(*fields)
    # The callback list may not exist yet if no column has been declared.
    ensure_accessors
    return if fields.empty?

    self.creation_callbacks << Proc.new do
      idx_name = "idx_#{self.table_name}__#{fields.join('__')}__#{@creation_callbacks.size}"
      creation_cmd = "create index #{idx_name} on #{self.table_name} (#{fields.join(', ')})"
      self.db.execute(creation_cmd)
    end
  end

  # Adds a column named +cname+ to this table declaration, and adds the following methods to the class:
  # * accessors for +cname+, called +cname+ and +cname=+
  # * +find\_by\_cname+ and +find\_first\_by\_cname+ methods, which return a list of rows and the first row that have the given value for +cname+, respectively
  # If the column references a column in another table (given via a +references(...)+ argument to +quals+), then add triggers to the database to ensure referential integrity and cascade-on-delete (if specified)
  def declare_column(cname, kind, *quals)
    ensure_accessors

    find_method_name = "find_by_#{cname}".to_sym
    find_first_method_name = "find_first_by_#{cname}".to_sym

    get_method_name = "#{cname}".to_sym
    set_method_name = "#{cname}=".to_sym

    # does this column reference another table?
    rf = quals.find {|q| q.class == Reference}
    self.refs[cname] = rf if rf

    # add finders for this column (class methods)
    klass = (class << self; self end)
    klass.class_eval do
      define_method find_method_name do |arg|
        res = self.db.execute("select * from #{table_name} where #{cname} = ?", arg)
        res.map {|row| self.new(row)}
      end

      define_method find_first_method_name do |arg|
        res = self.db.execute("select * from #{table_name} where #{cname} = ?", arg)
        res.size > 0 ? self.new(res[0]) : nil
      end
    end

    self.colnames.merge([cname])
    self.columns << Column.new(cname, kind, quals)

    # add accessors
    define_method get_method_name do
      freshen
      @tuple ? @tuple["#{cname}"] : nil
    end

    if not rf
      define_method set_method_name do |arg|
        @tuple["#{cname}"] = arg
        update cname, arg
      end
    else
      # this column references another table; create a set
      # method that can handle either row objects or row IDs
      define_method set_method_name do |arg|
        freshen

        arg_id = nil

        # Integer rather than Fixnum: Fixnum no longer exists in current
        # Ruby, and Integer covers the same values on older versions.
        if arg.is_a?(Integer)
          arg_id = arg
          arg = rf.referent.find arg_id
        else
          arg_id = arg.row_id
        end
        @tuple["#{cname}"] = arg

        update cname, arg_id
      end

      # Finally, add appropriate triggers to ensure referential integrity.
      # If rf has an on_delete trigger, also add the necessary
      # triggers to cascade deletes.
      # Note that we do not support update triggers, since the API does
      # not expose the capacity to change row IDs.

      self.creation_callbacks << Proc.new do
        @ccount ||= 0

        insert_trigger_name = "ri_insert_#{self.table_name}_#{@ccount}_#{rf.referent.table_name}"
        delete_trigger_name = "ri_delete_#{self.table_name}_#{@ccount}_#{rf.referent.table_name}"

        self.db.execute_batch("CREATE TRIGGER #{insert_trigger_name} BEFORE INSERT ON \"#{self.table_name}\" WHEN new.\"#{cname}\" IS NOT NULL AND NOT EXISTS (SELECT 1 FROM \"#{rf.referent.table_name}\" WHERE new.\"#{cname}\" == \"#{rf.column}\") BEGIN SELECT RAISE(ABORT, 'constraint #{insert_trigger_name} (#{rf.referent.table_name} missing foreign key row) failed'); END;")

        self.db.execute_batch("CREATE TRIGGER #{delete_trigger_name} BEFORE DELETE ON \"#{rf.referent.table_name}\" WHEN EXISTS (SELECT 1 FROM \"#{self.table_name}\" WHERE old.\"#{rf.column}\" == \"#{cname}\") BEGIN DELETE FROM \"#{self.table_name}\" WHERE \"#{cname}\" = old.\"#{rf.column}\"; END;") if rf.options[:on_delete] == :cascade

        @ccount = @ccount + 1
      end
    end
  end

  # Declares a constraint. Only check constraints are supported; see
  # the check method.
  def declare_constraint(cname, kind, *details)
    ensure_accessors
    info = details.join(" ")
    @constraints << "constraint #{cname} #{kind} #{info}"
  end

  # Creates a new row in the table with the supplied column values.
  # May throw a SQLite3::SQLException.
  #
  # The first argument is a hash from column symbols to values (an empty
  # hash is assumed when it is omitted). The hash is copied, so the caller's
  # hash is never modified; keys that are not columns of this table are
  # ignored. Values whose class includes Persisting are stored as their
  # row ids. Returns the object for the newly inserted row.
  def create(*args)
    ensure_accessors
    new_row = (args[0] || {}).dup
    new_row[:created] = new_row[:updated] = SQLBUtil::timestamp

    cols = colnames.intersection new_row.keys
    colspec = cols.map {|col| col.to_s}.join(", ")
    valspec = cols.map {|col| col.inspect}.join(", ")
    res = nil

    # bind only real columns, resolving any references in the args
    params = {}
    cols.each do |col|
      val = new_row[col]
      val = val.row_id if val.class.ancestors.include? Persisting
      params[col] = val
    end

    self.db.transaction do |db|
      stmt = "insert into #{table_name} (#{colspec}) values (#{valspec})"
      db.execute(stmt, params)
      res = find(db.last_insert_row_id)
    end
    res
  end

  # Returns a string consisting of the DDL statement to create a table
  # corresponding to this class.
  def table_decl
    ensure_accessors
    cols = columns.join(", ")
    consts = constraints.join(", ")
    if consts.size > 0
      "create table #{table_name} (#{cols}, #{consts});"
    else
      "create table #{table_name} (#{cols});"
    end
  end

  # Creates a table in the database corresponding to this class, then runs
  # the registered creation callbacks (indexes and triggers).
  #
  # +dbkey+ names the handle in Persistence::dbs to use. A key other than
  # +:default+ is remembered as this class's database unless one has already
  # been set with +db=+. With +:default+ nothing is remembered, so the
  # class keeps following whatever Persistence::db currently is.
  def create_table(dbkey=:default)
    ensure_accessors
    # Assign @db directly: going through the +db+ reader would see the
    # default handle and never honour a non-default +dbkey+.
    @db ||= Persistence::dbs[dbkey] unless dbkey == :default
    self.db.execute(table_decl)
    self.creation_callbacks.each {|func| func.call}
  end

  # Returns the database handle for this class: the one set with +db=+ (or
  # remembered by create_table), falling back to Persistence::db.
  def db
    @db || Persistence::db
  end

  # Sets the database handle used by this class.
  def db=(d)
    @db = d
  end

  # Ensure that all the necessary accessors on our class instance are defined
  # and that all metaclass fields have the appropriate values
  def ensure_accessors
    # Define singleton accessors
    if not self.respond_to? :columns
      class << self
        # Arrays of columns, column names, and column constraints.
        # Note that colnames does not contain row_id: the API purposefully
        # does not expose the ability to create a row with a given id.
        # It does contain created and updated, whose values are
        # maintained automatically by the API.
        attr_accessor :columns, :colnames, :constraints, :dirtied, :refs, :creation_callbacks
      end
    end

    # Ensure singleton fields are initialized
    self.columns ||= [Column.new(:row_id, :integer, [:primary_key]), Column.new(:created, :integer, []), Column.new(:updated, :integer, [])]
    self.colnames ||= Set.new [:created, :updated]
    self.constraints ||= []
    self.dirtied ||= {}
    self.refs ||= {}
    self.creation_callbacks ||= []
  end

  # Returns the number of rows in the table backing this class
  def count
    row = self.db.execute("select count(row_id) from #{table_name}")[0]
    # Handles are configured to return rows as hashes, so accept either a
    # Hash row (its only column is the count) or an Array row.
    value = row.is_a?(Hash) ? row.values.first : row[0]
    value.to_i
  end

  # Returns the raw tuple for the row with the given id, or +nil+ if no
  # such row exists.
  def find_tuple(id)
    res = self.db.execute("select * from #{table_name} where row_id = ?", id)
    res.size == 0 ? nil : res[0]
  end
end
436
+
437
# Mixin for classes whose instances are backed by rows of a database table.
# Including it extends the including class with PersistingClassMixins and
# gives instances +row_id+, +created+ and +updated+ readers.
module Persisting
  def self.included(other)
    class << other
      include PersistingClassMixins
    end

    other.class_eval do
      attr_reader :row_id, :created, :updated
    end
  end

  # Returns the database handle of this object's class.
  def db
    self.class.db
  end

  # Returns true if the row backing this object has been deleted from the database
  def deleted?
    freshen
    not @tuple
  end

  # Initializes a new instance backed by a tuple of values. Do not call this directly.
  # Create new instances with the create or find methods.
  def initialize(tup)
    @backed = true
    @tuple = tup
    mark_fresh
    @row_id = @tuple["row_id"]
    @created = @tuple["created"]
    @updated = @tuple["updated"]
    resolve_referents
    # Record when this row was last known to be dirty, unless another
    # object backed by the same row has already done so.
    self.class.dirtied[@row_id] ||= @expired_after
    self
  end

  # Deletes the row corresponding to this object from the database;
  # invalidates =self= and any other objects backed by this row
  def delete
    self.db.execute("delete from #{self.class.table_name} where row_id = ?", @row_id)
    mark_dirty
    @tuple = nil
    @row_id = nil
  end

  ## Begin private methods

  private

  # Fetches updated attribute values from the database if necessary
  def freshen
    return unless needs_refresh?

    @tuple = self.class.find_tuple(@row_id)
    if @tuple
      @updated = @tuple["updated"]
    else
      # the row is gone; forget everything we knew about it
      @row_id = @updated = @created = nil
    end
    mark_fresh
    resolve_referents
  end

  # True if the underlying row in the database is inconsistent with the state
  # of this object, whether because the row has changed, or because this object has no row id
  def needs_refresh?
    if not @row_id
      @tuple != nil
    else
      @expired_after < self.class.dirtied[@row_id]
    end
  end

  # Mark this row as dirty so that any other objects backed by this row will
  # update from the database before their attributes are inspected
  def mark_dirty
    self.class.dirtied[@row_id] = SQLBUtil::timestamp
  end

  # Mark this row as consistent with the underlying database as of now
  def mark_fresh
    @expired_after = SQLBUtil::timestamp
  end

  # Helper method to update the row in the database when one of our fields changes
  def update(attr_name, value)
    mark_dirty
    self.db.execute("update #{self.class.table_name} set #{attr_name} = ?, updated = ? where row_id = ?", value, SQLBUtil::timestamp, @row_id)
  end

  # Resolve any fields that reference other tables, replacing row ids with referred objects
  def resolve_referents
    # freshen calls this even when the row has been deleted and there is
    # no tuple left to resolve.
    return unless @tuple

    self.class.refs.each do |c,r|
      c = c.to_s
      if r.referent == self.class and @tuple[c] == row_id
        # a row that refers to itself
        @tuple[c] = self
      elsif !@tuple[c].nil?
        # a NULL reference has nothing to look up
        row = r.referent.find @tuple[c]
        @tuple[c] = row if row
      end
    end
  end
end
544
+
545
+ end
data/test/helper.rb ADDED
@@ -0,0 +1,7 @@
1
# Test helper: puts the test directory and ../lib on the load path, then
# loads the test framework and the library under test.
$LOAD_PATH.unshift(File.dirname(__FILE__))
$LOAD_PATH.unshift(File.join(File.dirname(__FILE__), '..', 'lib'))

# Test::Unit::TestCase is reopened below, so it has to be loaded first.
require 'test/unit'
require 'rhubarb/rhubarb'

# Reopened as a place to add shared test helpers (none yet).
class Test::Unit::TestCase
end