Remove schema field functionality
This commit is contained in:
@@ -1419,14 +1419,6 @@ class DbReadBase:
|
|||||||
"""
|
"""
|
||||||
raise NotImplementedError
|
raise NotImplementedError
|
||||||
|
|
||||||
def _hash_name(self, table, name):
|
|
||||||
"""
|
|
||||||
Used in SQL functions to eval expressions involving selected
|
|
||||||
data.
|
|
||||||
"""
|
|
||||||
name = self.get_table_func(table, "class_func").get_field_alias(name)
|
|
||||||
return name.replace(".", "__")
|
|
||||||
|
|
||||||
|
|
||||||
class DbWriteBase(DbReadBase):
|
class DbWriteBase(DbReadBase):
|
||||||
"""
|
"""
|
||||||
|
@@ -155,21 +155,6 @@ class TableObject(BaseObject):
|
|||||||
"""
|
"""
|
||||||
return {}
|
return {}
|
||||||
|
|
||||||
@classmethod
|
|
||||||
def field_aliases(cls):
|
|
||||||
"""
|
|
||||||
Return dictionary of alias to full field names
|
|
||||||
for this object class.
|
|
||||||
"""
|
|
||||||
return {}
|
|
||||||
|
|
||||||
@classmethod
|
|
||||||
def get_field_alias(cls, alias):
|
|
||||||
"""
|
|
||||||
Return full field name for an alias, if one.
|
|
||||||
"""
|
|
||||||
return cls.field_aliases().get(alias, alias)
|
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def get_schema(cls):
|
def get_schema(cls):
|
||||||
"""
|
"""
|
||||||
@@ -177,21 +162,6 @@ class TableObject(BaseObject):
|
|||||||
"""
|
"""
|
||||||
return {}
|
return {}
|
||||||
|
|
||||||
@classmethod
|
|
||||||
def get_extra_secondary_fields(cls):
|
|
||||||
"""
|
|
||||||
Return a list of full field names and types for secondary
|
|
||||||
fields that are not directly listed in the schema.
|
|
||||||
"""
|
|
||||||
return []
|
|
||||||
|
|
||||||
@classmethod
|
|
||||||
def get_index_fields(cls):
|
|
||||||
"""
|
|
||||||
Return a list of full field names for indices.
|
|
||||||
"""
|
|
||||||
return []
|
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def get_secondary_fields(cls):
|
def get_secondary_fields(cls):
|
||||||
"""
|
"""
|
||||||
@@ -210,193 +180,3 @@ class TableObject(BaseObject):
|
|||||||
schema_type,
|
schema_type,
|
||||||
value.get("maxLength")))
|
value.get("maxLength")))
|
||||||
return result
|
return result
|
||||||
|
|
||||||
@classmethod
|
|
||||||
def get_label(cls, field, _):
|
|
||||||
"""
|
|
||||||
Get the associated label given a field name of this object.
|
|
||||||
No index positions allowed on lists.
|
|
||||||
"""
|
|
||||||
chain = field.split(".")
|
|
||||||
path = cls._follow_schema_path(chain[:-1])
|
|
||||||
labels = path.get_labels(_)
|
|
||||||
if chain[-1] in labels:
|
|
||||||
return labels[chain[-1]]
|
|
||||||
else:
|
|
||||||
raise Exception("%s has no such label on %s: '%s'" %
|
|
||||||
(cls, path, field))
|
|
||||||
|
|
||||||
@classmethod
|
|
||||||
def get_field_type(cls, field):
|
|
||||||
"""
|
|
||||||
Get the associated label given a field name of this object.
|
|
||||||
No index positions allowed on lists.
|
|
||||||
"""
|
|
||||||
field = cls.get_field_alias(field)
|
|
||||||
chain = field.split(".")
|
|
||||||
ftype = cls._follow_schema_path(chain)
|
|
||||||
return ftype
|
|
||||||
|
|
||||||
@classmethod
|
|
||||||
def _follow_schema_path(cls, chain):
|
|
||||||
"""
|
|
||||||
Follow a list of schema items. Return endpoint.
|
|
||||||
"""
|
|
||||||
path = cls
|
|
||||||
for part in chain:
|
|
||||||
schema = path.get_schema()
|
|
||||||
if part.isdigit():
|
|
||||||
pass # skip over
|
|
||||||
elif part in schema.keys():
|
|
||||||
path = schema[part]
|
|
||||||
else:
|
|
||||||
raise Exception("No such field. Valid fields are: %s" % list(schema.keys()))
|
|
||||||
if isinstance(path, (list, tuple)):
|
|
||||||
path = path[0]
|
|
||||||
return path
|
|
||||||
|
|
||||||
def get_field(self, field, db=None, ignore_errors=False):
|
|
||||||
"""
|
|
||||||
Get the value of a field.
|
|
||||||
"""
|
|
||||||
field = self.__class__.get_field_alias(field)
|
|
||||||
chain = field.split(".")
|
|
||||||
path = self._follow_field_path(chain, db, ignore_errors)
|
|
||||||
return path
|
|
||||||
|
|
||||||
def _follow_field_path(self, chain, db=None, ignore_errors=False):
|
|
||||||
"""
|
|
||||||
Follow a list of items. Return endpoint(s) only.
|
|
||||||
With the db argument, can do joins across tables.
|
|
||||||
self - current object
|
|
||||||
returns - None, endpoint, of recursive list of endpoints
|
|
||||||
"""
|
|
||||||
from .handle import HandleClass
|
|
||||||
# start with [self, self, chain, path_to=[]]
|
|
||||||
# results = []
|
|
||||||
# expand when you reach multiple answers [obj, chain_left, []]
|
|
||||||
# if you get to an endpoint, put results
|
|
||||||
# go until nothing left to expand
|
|
||||||
todo = [(self, self, [], chain)]
|
|
||||||
results = []
|
|
||||||
while todo:
|
|
||||||
parent, current, path_to, chain = todo.pop()
|
|
||||||
#print("expand:", parent.__class__.__name__,
|
|
||||||
# current.__class__.__name__,
|
|
||||||
# path_to,
|
|
||||||
# chain)
|
|
||||||
keep_going = True
|
|
||||||
p = 0
|
|
||||||
while p < len(chain) and keep_going:
|
|
||||||
#print("while:", path_to, chain[p:])
|
|
||||||
part = chain[p]
|
|
||||||
if hasattr(current, part): # attribute
|
|
||||||
current = getattr(current, part)
|
|
||||||
path_to.append(part)
|
|
||||||
# need to consider current+part if current is list:
|
|
||||||
elif isinstance(current, (list, tuple)):
|
|
||||||
if part.isdigit():
|
|
||||||
# followed by index, so continue here
|
|
||||||
if int(part) < len(current):
|
|
||||||
current = current[int(part)]
|
|
||||||
path_to.append(part)
|
|
||||||
elif ignore_errors:
|
|
||||||
current = None
|
|
||||||
keeping_going = False
|
|
||||||
else:
|
|
||||||
raise Exception("invalid index position")
|
|
||||||
else: # else branch! in middle, split paths
|
|
||||||
for i in range(len(current)):
|
|
||||||
#print("split list:", self.__class__.__name__,
|
|
||||||
# current.__class__.__name__,
|
|
||||||
# path_to[:],
|
|
||||||
# [str(i)] + chain[p:])
|
|
||||||
todo.append([self, current, path_to[:], [str(i)] + chain[p:]])
|
|
||||||
current = None
|
|
||||||
keep_going = False
|
|
||||||
else: # part not found on this self
|
|
||||||
# current is a handle
|
|
||||||
# part is something on joined object
|
|
||||||
if parent:
|
|
||||||
ptype = parent.__class__.get_field_type(".".join(path_to))
|
|
||||||
if isinstance(ptype, HandleClass):
|
|
||||||
if db:
|
|
||||||
# start over here:
|
|
||||||
obj = None
|
|
||||||
if current:
|
|
||||||
try:
|
|
||||||
obj = ptype.join(db, current)
|
|
||||||
except HandleError:
|
|
||||||
if ignore_errors:
|
|
||||||
obj = None
|
|
||||||
else:
|
|
||||||
raise
|
|
||||||
if part == "self":
|
|
||||||
current = obj
|
|
||||||
path_to = []
|
|
||||||
#print("split self:", obj.__class__.__name__,
|
|
||||||
# current.__class__.__name__,
|
|
||||||
# path_to,
|
|
||||||
# chain[p + 1:])
|
|
||||||
todo.append([obj, current, path_to, chain[p + 1:]])
|
|
||||||
elif obj:
|
|
||||||
current = getattr(obj, part)
|
|
||||||
#print("split :", obj.__class__.__name__,
|
|
||||||
# current.__class__.__name__,
|
|
||||||
# [part],
|
|
||||||
# chain[p + 1:])
|
|
||||||
todo.append([obj, current, [part], chain[p + 1:]])
|
|
||||||
current = None
|
|
||||||
keep_going = False
|
|
||||||
else:
|
|
||||||
raise Exception("Can't join without database")
|
|
||||||
elif part == "self":
|
|
||||||
pass
|
|
||||||
elif ignore_errors:
|
|
||||||
pass
|
|
||||||
else:
|
|
||||||
raise Exception("%s is not a valid field of %s; use %s" %
|
|
||||||
(part, current, dir(current)))
|
|
||||||
current = None
|
|
||||||
keep_going = False
|
|
||||||
p += 1
|
|
||||||
if keep_going:
|
|
||||||
results.append(current)
|
|
||||||
if len(results) == 1:
|
|
||||||
return results[0]
|
|
||||||
elif len(results) == 0:
|
|
||||||
return None
|
|
||||||
else:
|
|
||||||
return results
|
|
||||||
|
|
||||||
def set_field(self, field, value, db=None, ignore_errors=False):
|
|
||||||
"""
|
|
||||||
Set the value of a basic field (str, int, float, or bool).
|
|
||||||
value can be a string or actual value.
|
|
||||||
Returns number of items changed.
|
|
||||||
"""
|
|
||||||
field = self.__class__.get_field_alias(field)
|
|
||||||
chain = field.split(".")
|
|
||||||
path = self._follow_field_path(chain[:-1], db, ignore_errors)
|
|
||||||
ftype = self.get_field_type(field)
|
|
||||||
# ftype is str, bool, float, or int
|
|
||||||
value = (value in ['True', True]) if ftype is bool else value
|
|
||||||
return self._set_fields(path, chain[-1], value, ftype)
|
|
||||||
|
|
||||||
def _set_fields(self, path, attr, value, ftype):
|
|
||||||
"""
|
|
||||||
Helper function to handle recursive lists of items.
|
|
||||||
"""
|
|
||||||
from .handle import HandleClass
|
|
||||||
if isinstance(path, (list, tuple)):
|
|
||||||
count = 0
|
|
||||||
for item in path:
|
|
||||||
count += self._set_fields(item, attr, value, ftype)
|
|
||||||
elif isinstance(ftype, HandleClass):
|
|
||||||
setattr(path, attr, value)
|
|
||||||
count = 1
|
|
||||||
else:
|
|
||||||
setattr(path, attr, ftype(value))
|
|
||||||
count = 1
|
|
||||||
return count
|
|
||||||
|
|
||||||
|
@@ -1,69 +0,0 @@
|
|||||||
#
|
|
||||||
# Gramps - a GTK+/GNOME based genealogy program
|
|
||||||
#
|
|
||||||
# Copyright (C) 2016 Gramps Development Team
|
|
||||||
#
|
|
||||||
# This program is free software; you can redistribute it and/or modify
|
|
||||||
# it under the terms of the GNU General Public License as published by
|
|
||||||
# the Free Software Foundation; either version 2 of the License, or
|
|
||||||
# (at your option) any later version.
|
|
||||||
#
|
|
||||||
# This program is distributed in the hope that it will be useful,
|
|
||||||
# but WITHOUT ANY WARRANTY; without even the implied warranty of
|
|
||||||
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
|
|
||||||
# GNU General Public License for more details.
|
|
||||||
#
|
|
||||||
# You should have received a copy of the GNU General Public License
|
|
||||||
# along with this program; if not, write to the Free Software
|
|
||||||
# Foundation, Inc., 51 Franklin Street, Fifth Floor, Boston, MA 02110-1301 USA.
|
|
||||||
#
|
|
||||||
|
|
||||||
""" Tests for using database fields """
|
|
||||||
|
|
||||||
import unittest
|
|
||||||
|
|
||||||
from gramps.gen.db import DbTxn
|
|
||||||
from gramps.gen.db.utils import make_database
|
|
||||||
|
|
||||||
from ..import (Person, Surname, Name, NameType, Family, FamilyRelType,
|
|
||||||
Event, EventType, Source, Place, PlaceName, Citation, Date,
|
|
||||||
Repository, RepositoryType, Media, Note, NoteType,
|
|
||||||
StyledText, StyledTextTag, StyledTextTagType, Tag,
|
|
||||||
ChildRef, ChildRefType, Attribute, MediaRef, AttributeType,
|
|
||||||
Url, UrlType, Address, EventRef, EventRoleType, RepoRef,
|
|
||||||
FamilyRelType, LdsOrd, MediaRef, PersonRef, PlaceType,
|
|
||||||
SrcAttribute, SrcAttributeType)
|
|
||||||
|
|
||||||
class FieldBaseTest(unittest.TestCase):
|
|
||||||
|
|
||||||
def setUp(self):
|
|
||||||
db = make_database("inmemorydb")
|
|
||||||
db.load(None)
|
|
||||||
with DbTxn("Test", db) as trans:
|
|
||||||
# Add some people:
|
|
||||||
person1 = Person()
|
|
||||||
person1.primary_name = Name()
|
|
||||||
person1.primary_name.surname_list.append(Surname())
|
|
||||||
person1.primary_name.surname_list[0].surname = "Smith"
|
|
||||||
person1.gramps_id = "I0001"
|
|
||||||
db.add_person(person1, trans) # person gets a handle
|
|
||||||
|
|
||||||
# Add some families:
|
|
||||||
family1 = Family()
|
|
||||||
family1.father_handle = person1.handle
|
|
||||||
family1.gramps_id = "F0001"
|
|
||||||
db.add_family(family1, trans)
|
|
||||||
self.db = db
|
|
||||||
|
|
||||||
def test_field_access01(self):
|
|
||||||
person = self.db.get_person_from_gramps_id("I0001")
|
|
||||||
self.assertEqual(person.get_field("primary_name.surname_list.0.surname"),
|
|
||||||
"Smith")
|
|
||||||
|
|
||||||
def test_field_join01(self):
|
|
||||||
family = self.db.get_family_from_gramps_id("F0001")
|
|
||||||
self.assertEqual(family.get_field("father_handle.primary_name.surname_list.0.surname", self.db),
|
|
||||||
"Smith")
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
|
||||||
unittest.main()
|
|
@@ -920,7 +920,6 @@ class DBAPI(DbGeneric):
|
|||||||
table_name = table.lower()
|
table_name = table.lower()
|
||||||
for field, schema_type, max_length in self.get_table_func(
|
for field, schema_type, max_length in self.get_table_func(
|
||||||
table, "class_func").get_secondary_fields():
|
table, "class_func").get_secondary_fields():
|
||||||
field = self._hash_name(table, field)
|
|
||||||
sql_type = self._sql_type(schema_type, max_length)
|
sql_type = self._sql_type(schema_type, max_length)
|
||||||
try:
|
try:
|
||||||
# test to see if it exists:
|
# test to see if it exists:
|
||||||
@@ -947,10 +946,8 @@ class DBAPI(DbGeneric):
|
|||||||
sets = []
|
sets = []
|
||||||
values = []
|
values = []
|
||||||
for field in fields:
|
for field in fields:
|
||||||
value = obj.get_field(field, self, ignore_errors=True)
|
|
||||||
field = self._hash_name(obj.__class__.__name__, field)
|
|
||||||
sets.append("%s = ?" % field)
|
sets.append("%s = ?" % field)
|
||||||
values.append(value)
|
values.append(getattr(obj, field))
|
||||||
|
|
||||||
# Derived fields
|
# Derived fields
|
||||||
if table == 'Person':
|
if table == 'Person':
|
||||||
|
Reference in New Issue
Block a user