diff --git a/gramps/gen/lib/primaryobj.py b/gramps/gen/lib/primaryobj.py index 9c4f5f16c..1c48f2e64 100644 --- a/gramps/gen/lib/primaryobj.py +++ b/gramps/gen/lib/primaryobj.py @@ -204,66 +204,105 @@ class BasicPrimaryObject(TableObject, PrivacyBase, TagBase): def _follow_field_path(self, chain, db=None, ignore_errors=False): """ - Follow a list of items. Return endpoint. + Follow a list of items. Return endpoint(s) only. With the db argument, can do joins across tables. self - current object + returns - None, endpoint, of recursive list of endpoints """ from .handle import HandleClass + # start with [self, self, chain, path_to=[]] + # results = [] + # expand when you reach multiple answers [obj, chain_left, []] + # if you get to an endpoint, put results + # go until nothing left to expand current = self - path_to = [] - parent = self - for part in chain: - path_to.append(part) - if hasattr(current, part): # attribute - current = getattr(current, part) - elif part.isdigit(): # index into list - if int(part) < len(current): - current = current[int(part)] - continue - elif ignore_errors: - return - else: - raise Exception("Can't get position %s of %s" % (part, current)) - elif isinstance(current, (list, tuple)): - current = [getattr(attr, part) for attr in current] - else: # part not found on this self - # current is a handle - # part is something on joined object - ptype = parent.__class__.get_field_type(".".join(path_to[:-1])) - if isinstance(ptype, HandleClass): - if db: - # start over here: - try: - parent = ptype.join(db, current) - current = getattr(parent, part) - path_to = [] - continue - except: - if ignore_errors: - return + todo = [(self, current, chain, [])] + results = [] + while todo: + parent, current, chain, path_to = todo.pop() + keep_going = True + p = 0 + while p < len(chain) and keep_going: + part = chain[p] + if hasattr(current, part): # attribute + current = getattr(current, part) + path_to.append(part) + # need to consider current+part if current is list: + elif isinstance(current, (list, tuple)): + if part.isdigit(): + # followed by index, so continue here + current = current[int(part)] + path_to.append(part) + else: # else branch! in middle, split paths + for i in range(len(current)): + todo.append([self, current, [str(i)] + chain[p:], path_to]) + current = None + keep_going = False + else: # part not found on this self + # current is a handle + # part is something on joined object + if parent: + ptype = parent.__class__.get_field_type(".".join(path_to)) + if isinstance(ptype, HandleClass): + if db: + # start over here: + obj = ptype.join(db, current) + if part == "self": + current = obj + elif obj: + current = getattr(obj, part) + if current: + path_to = [] + todo.append([obj, current, chain[p + 1:], path_to]) + current = None + keep_going = False else: - raise - else: - raise Exception("Can't join without database") - if ignore_errors: - return - else: - raise Exception("%s is not a valid field of %s; use %s" % - (part, current, dir(current))) - return current + raise Exception("Can't join without database") + elif part == "self": + pass + elif ignore_errors: + pass + else: + raise Exception("%s is not a valid field of %s; use %s" % + (part, current, dir(current))) + current = None + keep_going = False + p += 1 + if keep_going: + results.append(current) + if len(results) == 1: + return results[0] + elif len(results) == 0: + return None + else: + return results def set_field(self, field, value, db=None, ignore_errors=False): """ Set the value of a basic field (str, int, float, or bool). value can be a string or actual value. + Returns number of items changed. """ field = self.__class__.get_field_alias(field) chain = field.split(".") - path = self._follow_field_path(chain[:-1], db, ignore_errors) + path = self._follow_field_path(chain[:-1] + ["self"], db, ignore_errors) ftype = self.get_field_type(field) # ftype is str, bool, float, or int value = (value in ['True', True]) if ftype is bool else value - setattr(path, chain[-1], ftype(value)) + return self._set_fields(path, chain[-1], value, ftype) + + def _set_fields(self, path, attr, value, ftype): + """ + Helper function to handle recursive lists of items. + """ + if isinstance(path, (list, tuple)): + count = 0 + for item in path: + count += self._set_fields(item, attr, value, ftype) + else: + setattr(path, attr, ftype(value)) + count = 1 + return count def set_gramps_id(self, gramps_id): """