* src/RelLib.py: added get_backlink_handles method to PrimaryObject

* test/GrampsDbBase_Test.py: factored out common db test methods
* test/GrampsDbTestBase.py: new base class for unittests that
   need to create database records
* test/RelLib_Test.py: unittest for the
* get_backlink_handles method


svn: r5583
This commit is contained in:
Richard Taylor 2005-12-19 13:45:05 +00:00
parent 05a4cf032a
commit bd1b437256
5 changed files with 251 additions and 122 deletions

View File

@ -1,3 +1,10 @@
2005-12-19 Richard Taylor <rjt-gramps@thegrindstone.me.uk>
* src/RelLib.py: added get_backlink_handles method to PrimaryObject
* test/GrampsDbBase_Test.py: factored out common db test methods
* test/GrampsDbTestBase.py: new base class for unittests that need to
create database records
* test/RelLib_Test.py: unittest for the get_backlink_handles method
2005-12-17 Alex Roitman <shura@gramps-project.org>
* src/GrampsBSDDB.py (gramps_upgrade_9): Switch to using keys in
upgrade. When using DB cusrsor, modifying the record sometimes

View File

@ -347,6 +347,24 @@ class PrimaryObject(BaseObject):
def _replace_handle_reference(self,classname,old_handle,new_handle):
pass
def get_backlink_handles(self,db,include_classes=None):
"""Get a list of all primary objects that make some reference to this
primary object, either directly or via a child object.
Returns an iterator over tuples each of the form (class_name,handle).
To get a list use:
references = [ ref for ref in obj.get_backlink_handles() ]
@param db: a object with the find_backlink_handles method
@type db: usually a instance of a class derived from GrampsDbBase.
@param include_classes: the primary classes to include in the result.
@type: tuple of primary class names as strings, or None for all classes.
"""
return db.find_backlink_handles(self.get_handle(),include_classes)
def set_marker(self,marker):
self.marker = marker

View File

@ -19,128 +19,12 @@ import RelLib
logger = logging.getLogger('Gramps.GrampsDbBase_Test')
class ReferenceMapTest (unittest.TestCase):
def setUp(self):
self._tmpdir = tempfile.mkdtemp()
self._filename = os.path.join(self._tmpdir,'test.grdb')
self._db = GrampsBSDDB.GrampsBSDDB()
self._db.load(self._filename,
None, # callback
"w")
def tearDown(self):
shutil.rmtree(self._tmpdir)
def _populate_database(self,
num_sources = 1,
num_persons = 0,
num_families = 0,
num_events = 0,
num_places = 0,
num_media_objects = 0,
num_links = 1):
# start with sources
sources = []
for i in xrange(0,num_sources):
sources.append(self._add_source())
# now for each of the other tables. Give each entry a link
# to num_link sources, sources are chosen on a round robin
# basis
for num, add_func in ((num_persons, self._add_person_with_sources),
(num_families, self._add_family_with_sources),
(num_events, self._add_event_with_sources),
(num_places, self._add_place_with_sources),
(num_media_objects, self._add_media_object_with_sources)):
source_idx = 1
for person_idx in xrange(0,num):
# Get the list of sources to link
lnk_sources = set()
for i in xrange(0,num_links):
lnk_sources.add(sources[source_idx-1])
source_idx = (source_idx+1) % len(sources)
try:
add_func(lnk_sources)
except:
print "person_idx = ", person_idx
print "lnk_sources = ", repr(lnk_sources)
raise
return
def _add_source(self):
# Add a Source
tran = self._db.transaction_begin()
source = RelLib.Source()
self._db.add_source(source,tran)
self._db.commit_source(source,tran)
self._db.transaction_commit(tran, "Add Source")
return source
def _add_object_with_source(self,sources,object_class,add_method,commit_method):
object = object_class()
for source in sources:
src_ref = RelLib.SourceRef()
src_ref.set_base_handle(source.get_handle())
object.add_source_reference(src_ref)
tran = self._db.transaction_begin()
add_method(object,tran)
commit_method(object,tran)
self._db.transaction_commit(tran, "Add Object")
return object
def _add_person_with_sources(self,sources):
return self._add_object_with_source(sources,
RelLib.Person,
self._db.add_person,
self._db.commit_person)
def _add_family_with_sources(self,sources):
return self._add_object_with_source(sources,
RelLib.Family,
self._db.add_family,
self._db.commit_family)
def _add_event_with_sources(self,sources):
return self._add_object_with_source(sources,
RelLib.Event,
self._db.add_event,
self._db.commit_event)
def _add_place_with_sources(self,sources):
return self._add_object_with_source(sources,
RelLib.Place,
self._db.add_place,
self._db.commit_place)
def _add_media_object_with_sources(self,sources):
return self._add_object_with_source(sources,
RelLib.MediaObject,
self._db.add_object,
self._db.commit_media_object)
from GrampsDbTestBase import GrampsDbBaseTest
class ReferenceMapTest (GrampsDbBaseTest):
"""Test methods on the GrampsDbBase class that are related to the reference_map
index implementation."""
def test_simple_lookup(self):
"""insert a record and a reference and check that
@ -295,7 +179,8 @@ class ReferenceMapTest (unittest.TestCase):
assert with_reference_map < (without_reference_map / 10), "Reference_map should an order of magnitude faster."
def testSuite():
return unittest.makeSuite(ReferenceMapTest,'test')
suite = unittest.makeSuite(ReferenceMapTest,'test')
return suite
def perfSuite():
return unittest.makeSuite(ReferenceMapTest,'perf')

144
test/GrampsDbTestBase.py Normal file
View File

@ -0,0 +1,144 @@
import unittest
import logging
import os
import tempfile
import shutil
import time
import traceback
import sys
sys.path.append('../src')
try:
set()
except NameError:
from sets import Set as set
import GrampsBSDDB
import RelLib
logger = logging.getLogger('Gramps.GrampsDbTestBase')
class GrampsDbBaseTest(unittest.TestCase):
"""Base class for unittest that need to be able to create
test databases."""
def setUp(self):
self._tmpdir = tempfile.mkdtemp()
self._filename = os.path.join(self._tmpdir,'test.grdb')
self._db = GrampsBSDDB.GrampsBSDDB()
self._db.load(self._filename,
None, # callback
"w")
def tearDown(self):
shutil.rmtree(self._tmpdir)
def _populate_database(self,
num_sources = 1,
num_persons = 0,
num_families = 0,
num_events = 0,
num_places = 0,
num_media_objects = 0,
num_links = 1):
# start with sources
sources = []
for i in xrange(0,num_sources):
sources.append(self._add_source())
# now for each of the other tables. Give each entry a link
# to num_link sources, sources are chosen on a round robin
# basis
for num, add_func in ((num_persons, self._add_person_with_sources),
(num_families, self._add_family_with_sources),
(num_events, self._add_event_with_sources),
(num_places, self._add_place_with_sources),
(num_media_objects, self._add_media_object_with_sources)):
source_idx = 1
for person_idx in xrange(0,num):
# Get the list of sources to link
lnk_sources = set()
for i in xrange(0,num_links):
lnk_sources.add(sources[source_idx-1])
source_idx = (source_idx+1) % len(sources)
try:
add_func(lnk_sources)
except:
print "person_idx = ", person_idx
print "lnk_sources = ", repr(lnk_sources)
raise
return
def _add_source(self):
# Add a Source
tran = self._db.transaction_begin()
source = RelLib.Source()
self._db.add_source(source,tran)
self._db.commit_source(source,tran)
self._db.transaction_commit(tran, "Add Source")
return source
def _add_object_with_source(self,sources,object_class,add_method,commit_method):
object = object_class()
for source in sources:
src_ref = RelLib.SourceRef()
src_ref.set_base_handle(source.get_handle())
object.add_source_reference(src_ref)
tran = self._db.transaction_begin()
add_method(object,tran)
commit_method(object,tran)
self._db.transaction_commit(tran, "Add Object")
return object
def _add_person_with_sources(self,sources):
return self._add_object_with_source(sources,
RelLib.Person,
self._db.add_person,
self._db.commit_person)
def _add_family_with_sources(self,sources):
return self._add_object_with_source(sources,
RelLib.Family,
self._db.add_family,
self._db.commit_family)
def _add_event_with_sources(self,sources):
return self._add_object_with_source(sources,
RelLib.Event,
self._db.add_event,
self._db.commit_event)
def _add_place_with_sources(self,sources):
return self._add_object_with_source(sources,
RelLib.Place,
self._db.add_place,
self._db.commit_place)
def _add_media_object_with_sources(self,sources):
return self._add_object_with_source(sources,
RelLib.MediaObject,
self._db.add_object,
self._db.commit_media_object)

75
test/RelLib_Test.py Normal file
View File

@ -0,0 +1,75 @@
import unittest
import logging
import os
import tempfile
import shutil
import time
import traceback
import sys
sys.path.append('../src')
try:
set()
except NameError:
from sets import Set as set
import RelLib
logger = logging.getLogger('Gramps.RelLib_Test')
from GrampsDbTestBase import GrampsDbBaseTest
class PrimaryObjectTest (GrampsDbBaseTest):
"""Test methods on the PrimaryObject class"""
def test_get_backlink_handles(self):
"""Check that backlink lookup works."""
source = self._add_source()
person = self._add_person_with_sources([source])
references = [ ref for ref in source.get_backlink_handles(self._db) ]
assert len(references) == 1
assert references[0] == (RelLib.Person.__name__,person.get_handle())
def test_get_backlink_handles_with_class_list(self):
"""Check backlink lookup with class list."""
source = self._add_source()
person = self._add_person_with_sources([source])
self._add_family_with_sources([source])
self._add_event_with_sources([source])
self._add_place_with_sources([source])
self._add_media_object_with_sources([source])
references = [ ref for ref in source.get_backlink_handles(self._db) ]
# make sure that we have the correct number of references (one for each object)
references = [ ref for ref in source.get_backlink_handles(self._db) ]
assert len(references) == 5, "len(references) == %s " % str(len(references))
# should just return the person reference
references = [ ref for ref in source.get_backlink_handles(self._db,(RelLib.Person.__name__,)) ]
assert len(references) == 1, "len(references) == %s " % str(len(references))
assert references[0][0] == RelLib.Person.__name__, "references = %s" % repr(references)
# should just return the person and event reference
references = [ ref for ref in source.get_backlink_handles(self._db,(RelLib.Person.__name__,
RelLib.Event.__name__)) ]
assert len(references) == 2, "len(references) == %s " % str(len(references))
assert references[0][0] == RelLib.Person.__name__, "references = %s" % repr(references)
assert references[1][0] == RelLib.Event.__name__, "references = %s" % repr(references)
def testSuite():
suite = unittest.makeSuite(PrimaryObjectTest,'test')
return suite
if __name__ == '__main__':
unittest.TextTestRunner().run(testSuite())