Revised db.select to be completely general

This commit is contained in:
Doug Blank 2016-01-13 08:20:50 -05:00
parent dc698782b9
commit f9930c1bcf

View File

@ -204,66 +204,105 @@ class BasicPrimaryObject(TableObject, PrivacyBase, TagBase):
def _follow_field_path(self, chain, db=None, ignore_errors=False): def _follow_field_path(self, chain, db=None, ignore_errors=False):
""" """
Follow a list of items. Return endpoint. Follow a list of items. Return endpoint(s) only.
With the db argument, can do joins across tables. With the db argument, can do joins across tables.
self - current object self - current object
returns - None, endpoint, of recursive list of endpoints
""" """
from .handle import HandleClass from .handle import HandleClass
# start with [self, self, chain, path_to=[]]
# results = []
# expand when you reach multiple answers [obj, chain_left, []]
# if you get to an endpoint, put results
# go until nothing left to expand
current = self current = self
path_to = [] todo = [(self, current, chain, [])]
parent = self results = []
for part in chain: while todo:
path_to.append(part) parent, current, chain, path_to = todo.pop()
keep_going = True
p = 0
while p < len(chain) and keep_going:
part = chain[p]
if hasattr(current, part): # attribute if hasattr(current, part): # attribute
current = getattr(current, part) current = getattr(current, part)
elif part.isdigit(): # index into list path_to.append(part)
if int(part) < len(current): # need to consider current+part if current is list:
current = current[int(part)]
continue
elif ignore_errors:
return
else:
raise Exception("Can't get position %s of %s" % (part, current))
elif isinstance(current, (list, tuple)): elif isinstance(current, (list, tuple)):
current = [getattr(attr, part) for attr in current] if part.isdigit():
# followed by index, so continue here
current = current[int(part)]
path_to.append(part)
else: # else branch! in middle, split paths
for i in range(len(current)):
todo.append([self, current, [str(i)] + chain[p:], path_to])
current = None
keep_going = False
else: # part not found on this self else: # part not found on this self
# current is a handle # current is a handle
# part is something on joined object # part is something on joined object
ptype = parent.__class__.get_field_type(".".join(path_to[:-1])) if parent:
ptype = parent.__class__.get_field_type(".".join(path_to))
if isinstance(ptype, HandleClass): if isinstance(ptype, HandleClass):
if db: if db:
# start over here: # start over here:
try: obj = ptype.join(db, current)
parent = ptype.join(db, current) if part == "self":
current = getattr(parent, part) current = obj
elif obj:
current = getattr(obj, part)
if current:
path_to = [] path_to = []
continue todo.append([obj, current, chain[p + 1:], path_to])
except: current = None
if ignore_errors: keep_going = False
return
else:
raise
else: else:
raise Exception("Can't join without database") raise Exception("Can't join without database")
if ignore_errors: elif part == "self":
return pass
elif ignore_errors:
pass
else: else:
raise Exception("%s is not a valid field of %s; use %s" % raise Exception("%s is not a valid field of %s; use %s" %
(part, current, dir(current))) (part, current, dir(current)))
return current current = None
keep_going = False
p += 1
if keep_going:
results.append(current)
if len(results) == 1:
return results[0]
elif len(results) == 0:
return None
else:
return results
def set_field(self, field, value, db=None, ignore_errors=False): def set_field(self, field, value, db=None, ignore_errors=False):
""" """
Set the value of a basic field (str, int, float, or bool). Set the value of a basic field (str, int, float, or bool).
value can be a string or actual value. value can be a string or actual value.
Returns number of items changed.
""" """
field = self.__class__.get_field_alias(field) field = self.__class__.get_field_alias(field)
chain = field.split(".") chain = field.split(".")
path = self._follow_field_path(chain[:-1], db, ignore_errors) path = self._follow_field_path(chain[:-1] + ["self"], db, ignore_errors)
ftype = self.get_field_type(field) ftype = self.get_field_type(field)
# ftype is str, bool, float, or int # ftype is str, bool, float, or int
value = (value in ['True', True]) if ftype is bool else value value = (value in ['True', True]) if ftype is bool else value
setattr(path, chain[-1], ftype(value)) return self._set_fields(path, chain[-1], value, ftype)
def _set_fields(self, path, attr, value, ftype):
"""
Helper function to handle recursive lists of items.
"""
if isinstance(path, (list, tuple)):
count = 0
for item in path:
count += self._set_fields(item, attr, value, ftype)
else:
setattr(path, attr, ftype(value))
count = 1
return count
def set_gramps_id(self, gramps_id): def set_gramps_id(self, gramps_id):
""" """