Source code for iodata.attrutils

# IODATA is an input and output module for quantum chemistry.
# Copyright (C) 2011-2019 The IODATA Development Team
#
# This file is part of IODATA.
#
# IODATA is free software; you can redistribute it and/or
# modify it under the terms of the GNU General Public License
# as published by the Free Software Foundation; either version 3
# of the License, or (at your option) any later version.
#
# IODATA is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
# GNU General Public License for more details.
#
# You should have received a copy of the GNU General Public License
# along with this program; if not, see <http://www.gnu.org/licenses/>
# --
"""Utilities for building attr classes."""


import numpy as np


__all__ = ["convert_array_to", "validate_shape"]


def convert_array_to(dtype):
    """Return a function to convert arrays to the given type.

    Parameters
    ----------
    dtype
        The NumPy data type of the arrays produced by the converter.

    Returns
    -------
    converter
        A converter function for the attr library. It returns ``None`` when
        given ``None`` and otherwise converts its argument to a NumPy array
        with the requested ``dtype``. No copy is made when the argument is
        already an ndarray with that dtype.

    """

    def converter(array):
        if array is None:
            return None
        # np.asarray only copies when needed. The former call,
        # ``np.array(array, copy=False, dtype=dtype)``, raises ValueError
        # under NumPy >= 2.0 whenever a copy is required (e.g. list input
        # or a dtype change), because copy=False now means "never copy".
        return np.asarray(array, dtype=dtype)

    return converter


# pylint: disable=too-many-branches
def _expected_size(obj, item):
    """Resolve one ``shape_requirements`` item into an expected axis size.

    Returns the expected size, or None when the axis must not be checked.
    Raises TypeError or ValueError as documented in ``validate_shape``.
    """
    if item is None or isinstance(item, int):
        return item
    if isinstance(item, np.integer):
        # NumPy integer scalars are not instances of int; normalize them so
        # that error messages show plain integers.
        return int(item)
    if isinstance(item, str):
        return getattr(obj, item)
    if isinstance(item, tuple) and len(item) == 2:
        other_name, other_axis = item
        other = getattr(obj, other_name)
        if other is None:
            raise TypeError(f"Other attribute '{other_name}' is not set.")
        if other_axis == 0:
            return len(other)
        if other_axis >= other.ndim or other_axis < 0:
            raise TypeError(
                "Cannot get length along axis "
                f"{other_axis} of attribute {other_name} with ndim {other.ndim}."
            )
        return other.shape[other_axis]
    raise ValueError(f"Cannot interpret item in shape_requirements: {item}")


def _shapes_match(expected_shape, observed_shape):
    """Return True when both shapes have equal length and all checked sizes agree."""
    if len(expected_shape) != len(observed_shape):
        return False
    return not any(
        expected is not None and expected != observed
        for expected, observed in zip(expected_shape, observed_shape)
    )


def validate_shape(*shape_requirements: tuple):
    """Return a validator for the shape of an array or the length of an iterable.

    Parameters
    ----------
    shape_requirements
        Specifications for the required shape. Every item of the tuple
        describes the required size of the corresponding axis of an array.
        The number of items must also match the dimensionality of the array.
        When the validator is used for general iterables, this tuple should
        contain just one element. Possible values for each item are
        explained in the "Notes" section below.

    Returns
    -------
    validator
        A validator function for the attr library.

    Raises
    ------
    The returned validator raises:

    TypeError
        When the observed shape does not match the expected shape, when a
        referenced other attribute is None, or when a requested axis is out
        of range for the other array.
    ValueError
        When an item of ``shape_requirements`` cannot be interpreted.

    Notes
    -----
    Every element of ``shape_requirements`` defines the expected size of an
    array along the corresponding axis. An item in this tuple at position (or
    index) ``i`` can be one of the following:

    1. An integer (Python ``int`` or NumPy integer scalar), which is taken as
       the expected size along axis ``i``.
    2. None. In this case, the size of the array along axis ``i`` is not
       checked.
    3. A string, which should be the name of another integer attribute with
       the expected size along axis ``i``. The other attribute is always an
       attribute of the same object as the attribute being checked.
    4. A 2-tuple containing a name and an integer. In this case, the name
       refers to another attribute which is an array or an iterable. When the
       integer is 0, just the length of the other attribute is used. When the
       integer is non-zero, the other attribute must be an array and the
       integer selects an axis. The size of the other array along the
       selected axis is then used as the expected size of the array being
       checked along axis ``i``.

    """

    def validator(obj, attribute, value):
        # Build the expected shape first, following the rules above.
        expected_shape = tuple(_expected_size(obj, item) for item in shape_requirements)
        # Arrays report their full shape; other iterables only their length.
        if isinstance(value, np.ndarray):
            observed_shape = value.shape
        else:
            observed_shape = (len(value),)
        if not _shapes_match(expected_shape, observed_shape):
            raise TypeError(
                f"Expecting shape {expected_shape} for attribute {attribute.name}, "
                f"got {observed_shape}"
            )

    return validator