Source code for gustaf.helpers.data

"""gustaf/gustaf/helpers/data.py.

Helps helpee to manage data. Some useful data structures.
"""


from collections import namedtuple
from functools import wraps

import numpy as np

from gustaf._base import GustafBase


class TrackedArray(np.ndarray):
    """Taken from nice implementations of `trimesh` (see LICENSE.txt).
    `https://github.com/mikedh/trimesh/blob/main/trimesh/caching.py`.
    Minor adaption, since we don't have hashing functionalities.

    All the inplace functions will set modified flag and if some operations
    has potential to cause un-trackable behavior, writeable flags will be set
    to False.

    Note, if you really really want, it is possible to change the tracked
    array without setting modified flag.
    """

    # NOTE: all overrides use zero-argument ``super()``. The explicit form
    # ``super(self.__class__, self)`` resolves back to TrackedArray's own
    # method whenever ``self`` is an instance of a subclass, which recurses
    # forever.

    __slots__ = ("_modified", "_source")

    def __array_finalize__(self, obj):
        """Sets default flags for any arrays that maybe generated based on
        tracked array.

        Every new array starts flagged as modified. If it is derived from
        another tracked array (e.g. a slice or reshape), ``_source`` is set to
        the root tracked array, so writes through the derived array also flag
        the root. Root arrays keep the int sentinel ``0`` as ``_source``.
        """
        self._modified = True
        self._source = int(0)
        if isinstance(obj, type(self)):
            if isinstance(obj._source, int):
                # obj is a root array -> it becomes our source
                self._source = obj
            else:
                # obj is itself derived -> share its root
                self._source = obj._source

    @property
    def mutable(self):
        """Whether the array is writeable (``flags["WRITEABLE"]``)."""
        return self.flags["WRITEABLE"]

    @mutable.setter
    def mutable(self, value):
        self.flags.writeable = value

    def _set_modified(self):
        """set modified flags to itself and to the source."""
        self._modified = True
        if isinstance(self._source, type(self)):
            self._source._modified = True

    def copy(self, *args, **kwargs):
        """copy gives np.ndarray. no more tracking.

        Positional/keyword arguments are accepted for signature compatibility
        with ``np.ndarray.copy`` but are ignored.
        """
        return np.array(self, copy=True)

    def view(self, *args, **kwargs):
        """Set writeable flags to False for the view."""
        v = super().view(*args, **kwargs)
        v.flags.writeable = False
        return v

    # in-place operators: flag modification, then delegate to np.ndarray
    def __iadd__(self, *args, **kwargs):
        self._set_modified()
        return super().__iadd__(*args, **kwargs)

    def __isub__(self, *args, **kwargs):
        self._set_modified()
        return super().__isub__(*args, **kwargs)

    def __imul__(self, *args, **kwargs):
        self._set_modified()
        return super().__imul__(*args, **kwargs)

    def __idiv__(self, *args, **kwargs):
        # legacy Python 2 hook; np.ndarray has no ``__idiv__`` on Python 3
        self._set_modified()
        return super().__idiv__(*args, **kwargs)

    def __itruediv__(self, *args, **kwargs):
        self._set_modified()
        return super().__itruediv__(*args, **kwargs)

    def __imatmul__(self, *args, **kwargs):
        self._set_modified()
        return super().__imatmul__(*args, **kwargs)

    def __ipow__(self, *args, **kwargs):
        self._set_modified()
        return super().__ipow__(*args, **kwargs)

    def __imod__(self, *args, **kwargs):
        self._set_modified()
        return super().__imod__(*args, **kwargs)

    def __ifloordiv__(self, *args, **kwargs):
        self._set_modified()
        return super().__ifloordiv__(*args, **kwargs)

    def __ilshift__(self, *args, **kwargs):
        self._set_modified()
        return super().__ilshift__(*args, **kwargs)

    def __irshift__(self, *args, **kwargs):
        self._set_modified()
        return super().__irshift__(*args, **kwargs)

    def __iand__(self, *args, **kwargs):
        self._set_modified()
        return super().__iand__(*args, **kwargs)

    def __ixor__(self, *args, **kwargs):
        self._set_modified()
        return super().__ixor__(*args, **kwargs)

    def __ior__(self, *args, **kwargs):
        self._set_modified()
        return super().__ior__(*args, **kwargs)

    def __setitem__(self, *args, **kwargs):
        self._set_modified()
        super().__setitem__(*args, **kwargs)

    def __setslice__(self, *args, **kwargs):
        # legacy Python 2 hook; Python 3 routes slice assignment through
        # __setitem__ and never calls this.
        self._set_modified()
        super().__setslice__(*args, **kwargs)

    def __getslice__(self, *args, **kwargs):
        """Return slices as read-only arrays.

        Legacy Python 2 hook; Python 3 routes slicing through __getitem__
        and never calls this. Reading does not modify data, so the modified
        flag is left untouched.
        """
        slices = super().__getitem__(*args, **kwargs)
        if isinstance(slices, np.ndarray):
            slices.flags.writeable = False
        return slices
def make_tracked_array(array, dtype=None, copy=True):
    """Taken from nice implementations of `trimesh` (see LICENSE.txt).
    `https://github.com/mikedh/trimesh/blob/main/trimesh/caching.py`.

    ``Properly subclass a numpy ndarray to track changes.
    Avoids some pitfalls of subclassing by forcing contiguous
    arrays and does a view into a TrackedArray.``

    Factory-like wrapper function for TrackedArray.

    Parameters
    ------------
    array: array-like object
      To be turned into a TrackedArray. ``None`` gives an empty array;
      scalars are promoted to 1-d arrays.
    dtype: np.dtype
      Which dtype to use for the array
    copy: bool
      Default is True. If True, the result never shares memory with `array`.
      If False, memory is shared whenever `array` is already a C-contiguous
      array of the requested dtype.

    Returns
    ------------
    tracked : TrackedArray
      Contains input array data
    """
    # if someone passed us None, just create an empty array
    if array is None:
        array = []

    if copy:
        # one allocation: np.array(copy=True) always returns a fresh,
        # C-contiguous base-class ndarray. ndmin=1 mirrors
        # np.ascontiguousarray, which promotes scalars to 1-d.
        # (ascontiguousarray(...).copy() could copy twice.)
        contiguous = np.array(
            array, dtype=dtype, copy=True, order="C", ndmin=1
        )
    else:
        # reuse memory where possible, copy only if needed for contiguity
        contiguous = np.ascontiguousarray(array, dtype=dtype)

    # view it as our subclass
    tracked = contiguous.view(TrackedArray)

    # should always be contiguous here
    assert tracked.flags["C_CONTIGUOUS"]

    return tracked
class DataHolder(GustafBase):
    """Dict-like base container that stores data on behalf of a helpee.

    Values cannot be assigned through ``holder[key] = value`` on this base
    class; subclasses decide how (and whether) data is stored.
    """

    __slots__ = (
        "_helpee",
        "_saved",
    )

    def __init__(self, helpee):
        """Base class for any data holder. Behaves similar to dict.

        Parameters
        -----------
        helpee: object
          GustafBase objects would probably make the most sense here.
        """
        self._helpee = helpee
        self._saved = dict()

    def __setitem__(self, key, value):
        """Raise Error to disable direct value setting.

        Parameters
        -----------
        key: str
        value: object

        Raises
        -------
        NotImplementedError
          Always.
        """
        raise NotImplementedError(
            "Sorry, you can't set items directly for "
            f"{type(self).__qualname__}"
        )

    def __getitem__(self, key):
        """Returns stored item if the key exists.

        Parameters
        -----------
        key: str

        Returns
        --------
        value: object

        Raises
        -------
        KeyError
          If `key` is not stored.
        """
        try:
            return self._saved[key]
        except KeyError:
            raise KeyError(
                f"`{key}` is not stored for {type(self._helpee)}"
            ) from None

    def __contains__(self, key):
        """Returns True if `key` is stored.

        Without this, ``key in holder`` would fall back to the sequence
        protocol (``__getitem__(0)``, ...) and raise KeyError.

        Parameters
        -----------
        key: str

        Returns
        --------
        is_stored: bool
        """
        return key in self._saved

    def __iter__(self):
        """Iterates over stored keys, like a dict.

        Returns
        --------
        keys_iterator: iterator
        """
        return iter(self._saved)

    def pop(self, key, default=None):
        """Removes `key` from saved data and returns its value.

        Parameters
        ----------
        key: str
        default: object
          Returned if `key` is not stored.

        Returns
        -------
        value: object
        """
        return self._saved.pop(key, default)

    def clear(self):
        """Clears saved data by reassigning a new dict."""
        self._saved = dict()

    def get(self, key, default_values=None):
        """Returns stored item if the key exists. Else, given default value.

        Unlike ``self[key]``, this never raises for missing keys.

        Parameters
        -----------
        key: str
        default_values: object
          Returned if `key` is not stored.

        Returns
        --------
        value: object
        """
        return self._saved.get(key, default_values)

    def keys(self):
        """Returns keys of data holding dict.

        Returns
        --------
        keys: dict_keys
        """
        return self._saved.keys()

    def values(self):
        """Returns values of data holding dict.

        Returns
        --------
        values: dict_values
        """
        return self._saved.values()

    def items(self):
        """Returns items of data holding dict.

        Returns
        --------
        items: dict_items
        """
        return self._saved.items()
class ComputedData(DataHolder):
    """Cache of values computed by a helpee, with dependency tracking.

    Dependencies are registered per class through the ``depends_on``
    decorator:
    - ``_depends`` maps a function name to the attribute names it reads.
    - ``_inv_depends`` maps an attribute name to the function names that
      read it.
    """

    _depends = None
    _inv_depends = None

    __slots__ = ()

    def __init__(self, helpee, **kwargs):
        """Stores last computed values.

        Keys are expected to be the same as helpee's function that computes
        the value.

        Parameters
        -----------
        helpee: GustafBase
        kwargs:
          Accepted but unused.
        """
        super().__init__(helpee)

    @classmethod
    def depends_on(cls, var_names, make_property=False):
        """Decorator as classmethod that caches a helpee method's result.

        A cached value is recomputed when:
        1. one of the attributes in `var_names` (each expected to expose a
           ``_modified`` flag) was modified. This also clears every other
           cached value that depends on that attribute.
        2. no value is cached (``None``) for the function.
        3. the call has keyword arguments and ``recompute`` is not passed as
           a falsy value.

        Parameters
        -----------
        var_names: list
          Names of helpee attributes the decorated function depends on.
        make_property: bool
          If True, the wrapped function is returned as a ``property``.

        Returns
        --------
        decorator: callable
        """

        def inner(func):
            # runs once, at decoration time (module load)
            assert isinstance(var_names, list), "var_names should be a list"
            func_name = func.__name__

            # forward dependencies: function name -> attribute names
            if cls._depends is None:
                cls._depends = dict()
            cls._depends.setdefault(func_name, list()).extend(var_names)

            # inverse dependencies: attribute name -> function names
            if cls._inv_depends is None:
                cls._inv_depends = dict()
            for var_name in var_names:
                cls._inv_depends.setdefault(var_name, list()).append(
                    func_name
                )

            @wraps(func)
            def compute_or_return_saved(*args, **kwargs):
                """Check if the key should be computed,"""
                helpee = args[0]

                # Keyword arguments imply a fresh computation unless the
                # caller explicitly passes a falsy `recompute`.
                recompute = kwargs.get("recompute", True) if kwargs else False

                # Invalidate every cached value that shares a modified
                # dependee. The dicts are looked up on each call on purpose,
                # because later registrations may extend these lists.
                cache = helpee._computed._saved
                for dependee_name in cls._depends[func_name]:
                    if getattr(helpee, dependee_name)._modified:
                        for dependent in cls._inv_depends[dependee_name]:
                            cache[dependent] = None

                cached = cache.get(func_name, None)
                if cached is not None and not recompute:
                    return cached

                result = func(*args, **kwargs)
                if isinstance(result, np.ndarray):
                    result.flags.writeable = False  # configurable?
                cache[func_name] = result

                # everything this function reads is now up to date
                for dependee_name in cls._depends[func_name]:
                    getattr(helpee, dependee_name)._modified = False

                return result

            return (
                property(compute_or_return_saved)
                if make_property
                else compute_or_return_saved
            )

        return inner
class VertexData(DataHolder):
    """
    Minimal manager for vertex data.
    Checks input array size, transforms data on request.
    __setitem__ and __getitem__ will perform length checks.
    keys(), values(), items(), and get() will return whatever is currently
    stored.

    gustaf supports two kinds of data representation: scalar-data with cmap
    and vector-data with arrows.
    """

    __slots__ = ()

    def __init__(self, helpee):
        """Checks if helpee has vertices as attr beforehand.

        Parameters
        ----------
        helpee: Vertices
          Vertices and its derived classes.

        Raises
        ------
        AttributeError
          If `helpee` has no `vertices` attribute.
        """
        if not hasattr(helpee, "vertices"):
            raise AttributeError("Helpee does not have `vertices`.")

        super().__init__(helpee)

    def _validate_len(self, value=None, raise_=True):
        """Checks if given value is a valid vertex_data based of its length.

        If `value` is given, only that value is checked and nothing is
        deleted. Otherwise, all stored values are checked. In that case, if
        raise_ is False, only the incompatible entries (and their `__norm`
        companions) are deleted.
        Only checks len(). If array has (1, len) shape, this will still
        return False.

        Parameters
        ----------
        value: array-like
          Default is None. If None, checks all existing values.
        raise_: bool
          Default is True. If True, raises in case of incompatibility.

        Returns
        -------
        validity: bool
          If raise_ is False. False if any checked value is incompatible.

        Raises
        ------
        ValueError
          If raise_ is True and a checked value has the wrong length.
        """
        helpee_len = len(self._helpee.vertices)

        if value is not None:
            valid = len(value) == helpee_len
            if raise_ and not valid:
                raise ValueError(
                    f"Expected ({helpee_len}) length data, "
                    f"Given ({len(value)})"
                )
            return valid

        # check all saved values. Validity is decided per key, so one bad
        # entry does not cause later valid entries to be deleted.
        valid = True
        to_pop = []
        for key, saved_value in self._saved.items():
            if len(saved_value) == helpee_len:
                continue

            valid = False
            if raise_:
                raise ValueError(
                    f"`{key}`-data len ({len(saved_value)}) doesn't match "
                    f"expected len ({helpee_len})"
                )

            self._logd(
                f"`{key}`-data len ({len(saved_value)}) doesn't match "
                f"expected len ({helpee_len}). Deleting `{key}`."
            )
            # pop invalid data and its cached norm
            to_pop.append(key)
            to_pop.append(key + "__norm")

        # pop after iteration to avoid mutating the dict while looping
        for tp in to_pop:
            self._saved.pop(tp, None)

        return valid

    def __setitem__(self, key, value):
        """
        Performs len() based check before storing vertex_data.

        The value is stored as a TrackedArray reshaped to
        (n_vertices, -1), without an explicit copy.

        Parameters
        ----------
        key: str
        value: object

        Returns
        -------
        None

        Raises
        ------
        ValueError
          If len(value) differs from the number of vertices.
        """
        self._validate_len(value, raise_=True)
        # we are here because this is valid
        self._saved[key] = make_tracked_array(value, copy=False).reshape(
            len(self._helpee.vertices), -1
        )

    def __getitem__(self, key):
        """
        Validates data length before returning item.

        Parameters
        ----------
        key: str

        Returns
        -------
        data: array-like

        Raises
        ------
        KeyError
          If `key` is not stored or its length no longer matches the
          number of vertices.
        """
        value = super().__getitem__(key)  # raises KeyError
        valid = self._validate_len(value, raise_=False)
        if valid:
            return value
        else:
            raise KeyError(
                "Either requested data is not stored or deleted due to "
                "changes in number of vertices."
            )

    def as_scalar(self, key, default=None):
        """
        Returns scalar version of requested data. If it is already a scalar,
        will return as is. Else, will return a norm. using
        `np.linalg.norm()`. The norm is cached under `key + "__norm"` and
        reused until the original data is modified.

        Parameters
        ----------
        key: str
          Data key, or its `__norm` key.
        default: object
          Returned if `key` is not stored.

        Returns
        -------
        data_as_scalar: (n, 1) np.ndarray
        """
        if key not in self.keys():
            return default

        # interpret scalar as norm
        # save the norm once it is called.
        if "__norm" not in key:
            norm_key = key + "__norm"
        else:
            norm_key = key
            key = key.replace("__norm", "")

        if norm_key in self.keys():
            saved = self[norm_key]  # performs len check
            # return if original data is not modified
            if not self[key]._modified:
                return saved
            else:
                self._saved.pop(norm_key)

        # we are here because we have to compute norm. let's save norm
        value = self[key]
        if value.shape[1] == 1:
            value_norm = value
        else:
            value_norm = np.linalg.norm(value, axis=1).reshape(-1, 1)

        # save norm
        self[norm_key] = value_norm

        # considered not modified
        self[key]._modified = False

        return value_norm

    def as_arrow(self, key, default=None, raise_=True):
        """
        Returns an array as is, only if it is showable as arrow.

        Parameters
        ----------
        key: str
        default: object
          Returned if `key` is not stored.
        raise_: bool
          If True, raise for single-column (1D) data. Otherwise such data is
          logged and returned anyway.

        Returns
        -------
        data: (n, d) np.ndarray

        Raises
        ------
        ValueError
          If raise_ is True and the data has a single column.
        """
        if key not in self.keys():
            return default

        value = self[key]
        if value.shape[1] == 1:
            self._logd(f"as_arrow() requested data ({key}) is 1D data.")
            if raise_:
                raise ValueError(
                    f"`{key}`-data is 1D and cannot be represented as arrows."
                )

        return value
Unique2DFloats = namedtuple(
    "Unique2DFloats", ["values", "ids", "inverse", "intersection"]
)
Unique2DFloats.__doc__ = (
    " namedtuple to hold unique information of float type arrays."
    " Note that for float types, \"close enough\" might be a better name"
    " than unique. This way, all tracked arrays, as long as they are 2D,"
    " have a dot separated syntax to access unique info."
    " For example, `mesh.unique_vertices.ids`. "
)

Unique2DIntegers = namedtuple(
    "Unique2DIntegers", ["values", "ids", "inverse", "counts"]
)
Unique2DIntegers.__doc__ = (
    " namedtuple to hold unique information of integer type arrays."
    " Similar approach to Unique2DFloats. "
)

# Per-field documentation, applied in one pass. The loop variables are
# deleted afterwards so they do not leak into the module namespace.
for _tuple_cls, _field_docs in (
    (
        Unique2DFloats,
        {
            "values": "`(n, d) np.ndarray` Field number 0",
            "ids": "`(n, d) np.ndarray` Field number 1",
            "inverse": "`(n, d) np.ndarray` Field number 2",
            "intersection": (
                "`(m) list of list` given original array's index,"
                " returns overlapping arrays, including itself."
                " Field number 3 "
            ),
        },
    ),
    (
        Unique2DIntegers,
        {
            "values": "`(n, d) np.ndarray` Field number 0",
            "ids": "`(n) np.ndarray` Field number 1",
            "inverse": "`(m) np.ndarray` Field number 2",
            "counts": "`(n) np.ndarray` Field number 3",
        },
    ),
):
    for _field, _doc in _field_docs.items():
        getattr(_tuple_cls, _field).__doc__ = _doc

del _tuple_cls, _field_docs, _field, _doc
class ComputedMeshData(ComputedData):
    """Holder for computed mesh data.

    Exists as a separate subclass of ComputedData so that it keeps its own
    dependency info. ``depends_on`` assigns fresh ``_depends`` /
    ``_inv_depends`` dicts on the decorating class while the inherited
    values are still ``None``.
    """