"""gustaf/gustaf/helpers/data.py.
Helps helpee to manage data. Some useful data structures.
"""
from collections import namedtuple
from functools import wraps
import numpy as np
from gustaf._base import GustafBase
[docs]
class TrackedArray(np.ndarray):
"""Taken from nice implementations of `trimesh` (see LICENSE.txt).
`https://github.com/mikedh/trimesh/blob/main/trimesh/caching.py`. Minor
adaption, since we don't have hashing functionalities.
All the inplace functions will set modified flag and if some operations
has potential to cause un-trackable behavior, writeable flags will be set
to False.
Note, if you really really want, it is possible to change the tracked
array without setting modified flag.
"""
__slots__ = ("_modified", "_source")
def __array_finalize__(self, obj):
"""Sets default flags for any arrays that maybe generated based on
tracked array."""
self._modified = True
self._source = int(0)
if isinstance(obj, type(self)):
if isinstance(obj._source, int):
self._source = obj
else:
self._source = obj._source
@property
def mutable(self):
return self.flags["WRITEABLE"]
@mutable.setter
def mutable(self, value):
self.flags.writeable = value
def _set_modified(self):
"""set modified flags to itself and to the source."""
self._modified = True
if isinstance(self._source, type(self)):
self._source._modified = True
[docs]
def copy(self, *args, **kwargs):
"""copy gives np.ndarray.
no more tracking.
"""
return np.array(self, copy=True)
[docs]
def view(self, *args, **kwargs):
"""Set writeable flags to False for the view."""
v = super(self.__class__, self).view(*args, **kwargs)
v.flags.writeable = False
return v
def __iadd__(self, *args, **kwargs):
self._set_modified()
return super(self.__class__, self).__iadd__(*args, **kwargs)
def __isub__(self, *args, **kwargs):
self._set_modified()
return super(self.__class__, self).__isub__(*args, **kwargs)
def __imul__(self, *args, **kwargs):
self._set_modified()
return super(self.__class__, self).__imul__(*args, **kwargs)
def __idiv__(self, *args, **kwargs):
self._set_modified()
return super(self.__class__, self).__idiv__(*args, **kwargs)
def __itruediv__(self, *args, **kwargs):
self._set_modified()
return super(self.__class__, self).__itruediv__(*args, **kwargs)
def __imatmul__(self, *args, **kwargs):
self._set_modified()
return super(self.__class__, self).__imatmul__(*args, **kwargs)
def __ipow__(self, *args, **kwargs):
self._set_modified()
return super(self.__class__, self).__ipow__(*args, **kwargs)
def __imod__(self, *args, **kwargs):
self._set_modified()
return super(self.__class__, self).__imod__(*args, **kwargs)
def __ifloordiv__(self, *args, **kwargs):
self._set_modified()
return super(self.__class__, self).__ifloordiv__(*args, **kwargs)
def __ilshift__(self, *args, **kwargs):
self._set_modified()
return super(self.__class__, self).__ilshift__(*args, **kwargs)
def __irshift__(self, *args, **kwargs):
self._set_modified()
return super(self.__class__, self).__irshift__(*args, **kwargs)
def __iand__(self, *args, **kwargs):
self._set_modified()
return super(self.__class__, self).__iand__(*args, **kwargs)
def __ixor__(self, *args, **kwargs):
self._set_modified()
return super(self.__class__, self).__ixor__(*args, **kwargs)
def __ior__(self, *args, **kwargs):
self._set_modified()
return super(self.__class__, self).__ior__(*args, **kwargs)
def __setitem__(self, *args, **kwargs):
self._set_modified()
super(self.__class__, self).__setitem__(*args, **kwargs)
def __setslice__(self, *args, **kwargs):
self._set_modified()
super(self.__class__, self).__setslice__(*args, **kwargs)
def __getslice__(self, *args, **kwargs):
self._set_modified()
"""
return slices I am pretty sure np.ndarray does not have __*slice__
"""
slices = super(self.__class__, self).__getitem__(*args, **kwargs)
if isinstance(slices, np.ndarray):
slices.flags.writeable = False
return slices
[docs]
def make_tracked_array(array, dtype=None, copy=True):
"""Taken from nice implementations of `trimesh` (see LICENSE.txt).
`https://github.com/mikedh/trimesh/blob/main/trimesh/caching.py`.
``Properly subclass a numpy ndarray to track changes.
Avoids some pitfalls of subclassing by forcing contiguous
arrays and does a view into a TrackedArray.``
Factory-like wrapper function for TrackedArray.
Parameters
------------
array: array- like object
To be turned into a TrackedArray
dtype: np.dtype
Which dtype to use for the array
copy: bool
Default is True. copy if True.
Returns
------------
tracked : TrackedArray
Contains input array data
"""
# if someone passed us None, just create an empty array
if array is None:
array = []
# make sure it is contiguous then view it as our subclass
tracked = np.ascontiguousarray(array, dtype=dtype)
if copy:
tracked = tracked.copy().view(TrackedArray)
else:
tracked = tracked.view(TrackedArray)
# should always be contiguous here
assert tracked.flags["C_CONTIGUOUS"]
return tracked
[docs]
class DataHolder(GustafBase):
__slots__ = (
"_helpee",
"_saved",
)
def __init__(self, helpee):
"""Base class for any data holder. Behaves similar to dict.
Parameters
-----------
helpee: object
GustafBase objects would probably make the most sense here.
"""
self._helpee = helpee
self._saved = dict()
def __setitem__(self, key, value):
"""Raise Error to disable direct value setting.
Parameters
-----------
key: str
value: object
"""
raise NotImplementedError(
"Sorry, you can't set items directly for "
f"{type(self).__qualname__}"
)
def __getitem__(self, key):
"""Returns stored item if the key exists.
Parameters
-----------
key: str
Returns
--------
value: object
"""
if key in self._saved.keys():
return self._saved[key]
else:
raise KeyError(f"`{key}` is not stored for {type(self._helpee)}")
[docs]
def pop(self, key, default=None):
"""
Applied pop() to saved data
Parameters
----------
key: str
default: object
Returns
-------
value: object
"""
return self._saved.pop(key, default)
[docs]
def clear(self):
"""
Clears saved data by reassigning new dict
"""
self._saved = dict()
[docs]
def get(self, key, default_values=None):
"""Returns stored item if the key exists. Else, given default value. If
the key exist, default value always exists, since it is initialized
that way.
Parameters
-----------
key: str
default_values: object
Returns
--------
value: object
"""
if key in self._saved.keys():
return self._saved[key]
else:
return default_values
[docs]
def keys(self):
"""Returns keys of data holding dict.
Returns
--------
keys: dict_keys
"""
return self._saved.keys()
[docs]
def values(self):
"""Returns values of data holding dict.
Returns
--------
values: dict_values
"""
return self._saved.values()
[docs]
def items(self):
"""Returns items of data holding dict.
Returns
--------
values: dict_values
"""
return self._saved.items()
[docs]
class ComputedData(DataHolder):
_depends = None
_inv_depends = None
__slots__ = ()
def __init__(self, helpee, **kwargs):
"""Stores last computed values.
Keys are expected to be the same as helpee's function that computes the
value.
Parameters
-----------
helpee: GustafBase
"""
super().__init__(helpee)
[docs]
@classmethod
def depends_on(cls, var_names, make_property=False):
"""Decorator as classmethod.
checks if the key should be computed. Two cases, where the answer is
yes:
1. there's modification on arrays that the key depend on.
->erases all other
2. is corresponding value None?
Supports multi-dependency
Parameters
-----------
var_name: list
make_property:
"""
def inner(func):
# followings are done once while modules are loaded
# just subclass this class to make a special helper
# for each helpee class.
assert isinstance(var_names, list), "var_names should be a list"
# initialize property
# _depends is dict(str: list)
if cls._depends is None:
cls._depends = dict()
if cls._depends.get(func.__name__, None) is None:
cls._depends[func.__name__] = list()
# add dependency info
cls._depends[func.__name__].extend(var_names)
# _inv_depends is dict(str: list)
if cls._inv_depends is None:
cls._inv_depends = dict()
# add inverse dependency
for vn in var_names:
if cls._inv_depends.get(vn, None) is None:
cls._inv_depends[vn] = list()
cls._inv_depends[vn].append(func.__name__)
@wraps(func)
def compute_or_return_saved(*args, **kwargs):
"""Check if the key should be computed,"""
# extract some related info
self = args[0] # the helpee itself
# explicitly settable kwargs.
# unless recompute flag is set False,
# it will always recompute and save them
# if you call the same function without kwargs
# the last one with kwargs will be returned
recompute = False
if kwargs:
recompute = kwargs.get("recompute", True)
# computed arrays are called _computed.
# loop over dependees and check if they are modified
for dependee_str in cls._depends[func.__name__]:
dependee = getattr(self, dependee_str)
# is modified?
if dependee._modified:
for inv in cls._inv_depends[dependee_str]:
self._computed._saved[inv] = None
# is saved / want to recompute?
# recompute is added for computed values that accepts params.
saved = self._computed._saved.get(func.__name__, None)
if saved is not None and not recompute:
return saved
# we've reached this point because we have to compute this
computed = func(*args, **kwargs)
if isinstance(computed, np.ndarray):
computed.flags.writeable = False # configurable?
self._computed._saved[func.__name__] = computed
# so, all fresh. we can press NOT-modified button
for dependee_str in cls._depends[func.__name__]:
dependee = getattr(self, dependee_str)
dependee._modified = False
return computed
if make_property:
return property(compute_or_return_saved)
else:
return compute_or_return_saved
return inner
[docs]
class VertexData(DataHolder):
"""
Minimal manager for vertex data. Checks input array size, transforms
data on request. __setitem__ and __getitem__ will perform length checks.
key(), values(), items(), and get() will return whatever is currently
stored.
gustaf supports two kinds of data representation: scalar-data with cmap
and vector-data with arrows.
"""
__slots__ = ()
def __init__(self, helpee):
"""Checks if helpee has vertices as attr beforehand.
Parameters
----------
helpee: Vertices
Vertices and its derived classes.
"""
if not hasattr(helpee, "vertices"):
raise AttributeError("Helpee does not have `vertices`.")
super().__init__(helpee)
def _validate_len(self, value=None, raise_=True):
"""Checks if given value is a valid vertex_data based of its length.
If raise_, throws error, else, deletes all incompatible values.
Only checks len(). If array has (1, len) shape, this will still return
False.
Parameters
----------
value: array-like
Default is None. If None, checks all existing values.
raise_: bool
Default is True, If True, raises in case of incompatibility.
Returns
-------
validity: bool
If raise_ is False.
"""
valid = True
helpee_len = len(self._helpee.vertices)
if value is not None:
if len(value) != helpee_len:
valid = False
if raise_ and not valid:
raise ValueError(
f"Expected ({helpee_len}) length data, "
f"Given ({len(value)})"
)
return valid
# here, check all saved values.
to_pop = []
for key, value in self._saved.items():
if len(value) != helpee_len:
valid = False
if not valid:
if raise_:
raise ValueError(
f"`{key}`-data len ({len(value)}) doesn't match "
f"expected len ({helpee_len})"
)
else:
self._logd(
f"`{key}`-data len ({len(value)}) doesn't match "
f"expected len ({helpee_len}). Deleting `{key}`."
)
# pop invalid data
to_pop.append(key)
to_pop.append(key + "__norm")
# pop if needed
for tp in to_pop:
self._saved.pop(tp, None)
return valid
def __setitem__(self, key, value):
"""
Performs len() based check before storing vertex_data.
Parameters
----------
key: str
value: object
Returns
-------
None
"""
self._validate_len(value, raise_=True)
# we are here because this is valid
self._saved[key] = make_tracked_array(value, copy=False).reshape(
len(self._helpee.vertices), -1
)
def __getitem__(self, key):
"""
Validates data length before returning item.
Parameters
----------
key: str
Returns
-------
data: array-like
"""
value = super().__getitem__(key) # raises KeyError
valid = self._validate_len(value, raise_=False)
if valid:
return value
else:
raise KeyError(
"Either requested data is not stored or deleted due to "
"changes in number of vertices."
)
[docs]
def as_scalar(self, key, default=None):
"""
Returns scalar version of requested data. If it is already a scalar,
will return as is. Else, will return a norm. using `np.linalg.norm()`.
Parameters
----------
key: str
default: object
Returns
-------
data_as_scalar: (n, 1) np.ndarray
"""
if key not in self.keys():
return default
# interpret scalar as norm
# save the norm once it is called.
if "__norm" not in key:
norm_key = key + "__norm"
else:
norm_key = key
key = key.replace("__norm", "")
if norm_key in self.keys():
saved = self[norm_key] # performs len check
# return if original is not modified
if not self[key]._modified: # check if original data is modified
return saved
else:
self._saved.pop(norm_key)
# we are here because we have to compute norm. let's save norm
value = self[key]
if value.shape[1] == 1:
value_norm = value
else:
value_norm = np.linalg.norm(value, axis=1).reshape(-1, 1)
# save norm
self[norm_key] = value_norm
# considered not modified
self[key]._modified = False
return value_norm
[docs]
def as_arrow(self, key, default=None, raise_=True):
"""
Returns an array as is, only if it is showable as arrow.
Parameters
----------
key: str
default: object
raise_: bool
Returns
-------
data: (n, d) np.ndarray
"""
if key not in self.keys():
return default
value = self[key]
if value.shape[1] == 1:
self._logd(f"as_arrow() requested data ({key}) is 1D data.")
if raise_:
raise ValueError(
f"`{key}`-data is 1D and cannot be represented as arrows."
)
return value
Unique2DFloats = namedtuple(
"Unique2DFloats", ["values", "ids", "inverse", "intersection"]
)
Unique2DFloats.__doc__ = """
namedtuple to hold unique information of float type arrays.
Note that for float types, "close enough" might be a better name than unique.
This way, all tracked arrays, as long as they are 2D, have a dot separated
syntax to access unique info. For example, `mesh.unique_vertices.ids`.
"""
Unique2DFloats.values.__doc__ = """`(n, d) np.ndarray`
Field number 0"""
Unique2DFloats.ids.__doc__ = """`(n, d) np.ndarray`
Field number 1"""
Unique2DFloats.inverse.__doc__ = """`(n, d) np.ndarray`
Field number 2"""
Unique2DFloats.intersection.__doc__ = """`(m) list of list`
given original array's index, returns overlapping arrays, including itself.
Field number 3
"""
Unique2DIntegers = namedtuple(
"Unique2DIntegers", ["values", "ids", "inverse", "counts"]
)
Unique2DIntegers.__doc__ = """
namedtuple to hold unique information of integer type arrays.
Similar approach to Unique2DFloats.
"""
Unique2DIntegers.values.__doc__ = """`(n, d) np.ndarray`
Field number 0"""
Unique2DIntegers.ids.__doc__ = """`(n) np.ndarray`
Field number 1"""
Unique2DIntegers.inverse.__doc__ = """`(m) np.ndarray`
Field number 2"""
Unique2DIntegers.counts.__doc__ = """`(n) np.ndarray`
Field number 3"""
[docs]
class ComputedMeshData(ComputedData):
"""A class to hold computed-mesh-data.
Subclassed to keep its own dependency info.
"""
pass