# IODATA is an input and output module for quantum chemistry.
# Copyright (C) 2011-2019 The IODATA Development Team
#
# This file is part of IODATA.
#
# IODATA is free software; you can redistribute it and/or
# modify it under the terms of the GNU General Public License
# as published by the Free Software Foundation; either version 3
# of the License, or (at your option) any later version.
#
# IODATA is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
# GNU General Public License for more details.
#
# You should have received a copy of the GNU General Public License
# along with this program; if not, see <http://www.gnu.org/licenses/>
# --
"""Utility functions module."""
from io import TextIOBase
from pathlib import Path
from typing import Optional, TextIO, Union
import attrs
import numpy as np
import scipy.constants as spc
from numpy.typing import NDArray
from scipy.linalg import eigh
from .attrutils import validate_shape
# Names exported by ``from <this module> import *``.
# NOTE(review): "FileFormatError" and "WriteInputError" are listed below, but no
# definition for them is visible in this module. A star import raises AttributeError
# for any listed name that does not exist, so confirm they are defined or drop them.
__all__ = (
    "LineIterator",
    "FileFormatError",
    "LoadError",
    "DumpError",
    "PrepareDumpError",
    "WriteInputError",
    "LoadWarning",
    "DumpWarning",
    "PrepareDumpWarning",
    "Cube",
    "set_four_index_element",
    "volume",
    "derive_naturals",
    "check_dm",
    "strtobool",
)
# Unit conversion factors: each constant is the value of one unit expressed in
# atomic units. They can be used as follows:
# - Conversion to atomic units: distance = 5*angstrom
# - Conversion from atomic units: print(distance/angstrom)
# One angstrom expressed in bohr.
angstrom: float = spc.angstrom / spc.value("atomic unit of length")
# One electronvolt expressed in hartree.
electronvolt: float = 1 / spc.value("hartree-electron volt relationship")
# Length and time units (used for Gromacs gro files).
# One meter expressed in bohr.
meter: float = 1 / spc.value("Bohr radius")
nanometer: float = 1e-9 * meter
# One second expressed in atomic units of time.
second: float = 1 / spc.value("atomic unit of time")
picosecond: float = 1e-12 * second
# Atomic mass unit (not atomic unit of mass!): 1 g/mol divided by the Avogadro
# constant, expressed in electron masses.
amu: float = 1e-3 / (spc.value("electron mass") * spc.value("Avogadro constant"))
# Molar energies converted to a per-particle energy in hartree.
kcalmol: float = 1e3 * spc.calorie / spc.value("Avogadro constant") / spc.value("Hartree energy")
calmol: float = spc.calorie / spc.value("Avogadro constant") / spc.value("Hartree energy")
kjmol: float = 1e3 / spc.value("Avogadro constant") / spc.value("Hartree energy")
class LineIterator:
    """Iterator class for looping over lines and keeping track of the line number.

    Use this class as a context manager, similar to the built-in ``open`` function:

    .. code-block:: python

        with LineIterator("filename.ext") as lit:
            for line in lit:
                ...

    Lines can be pushed back with :meth:`back`. Pushed-back lines are returned
    again (last in, first out) before any further line is read from the file.
    """

    def __init__(self, filename: str):
        """Initialize a LineIterator.

        Parameters
        ----------
        filename
            The file that will be read. It is only opened in ``__enter__``.
        """
        self.filename = filename
        # File handle, opened by __enter__ and closed by __exit__.
        self.fh = None
        # Number of the line most recently returned (1 for the first line, 0 before reading).
        self.lineno = 0
        # Lines pushed back with back(), popped again by __next__ in LIFO order.
        self.stack = []

    def __enter__(self):
        self.fh = open(self.filename)
        return self

    def __exit__(self, exc_type, exc_value, traceback):
        self.fh.close()

    def __iter__(self):
        return self

    def __next__(self):
        """Return the next line and increase the lineno attribute by one.

        When the file is exhausted, StopIteration propagates and ``lineno`` is left
        unchanged, so it keeps referring to the last line that was actually returned.
        """
        # Fetch first: if next() raises StopIteration, the counter must not advance.
        line = self.stack.pop() if self.stack else next(self.fh)
        self.lineno += 1
        return line

    def back(self, line):
        """Go back one line in the file and decrease the lineno attribute by one.

        Parameters
        ----------
        line
            The line to push back. It is returned by the next call to ``next``.
        """
        self.stack.append(line)
        self.lineno -= 1
def _interpret_file_lineno(
    file: Optional[Union[str, Path, LineIterator, TextIO]] = None, lineno: Optional[int] = None
) -> tuple[Optional[str], Optional[int]]:
    """Interpret the file and lineno arguments given to Error and Warning constructors.

    Parameters
    ----------
    file
        Object to deduce the filename (and optionally line number) from.
    lineno
        Line number, if known and not (correctly) included in the file object.

    Returns
    -------
    filename
        The filename associated with the file object, or None when it cannot be
        determined (e.g. an in-memory text stream without a ``name``).
    lineno
        The line number.

    Raises
    ------
    TypeError
        When a line number is given without a file, or when ``file`` has an
        unsupported type.
    """
    if isinstance(file, str):
        return file, lineno
    if isinstance(file, Path):
        return str(file), lineno
    if isinstance(file, LineIterator):
        # Fall back to the iterator's own line counter when no line number is given.
        if lineno is None:
            lineno = file.lineno
        return file.filename, lineno
    if isinstance(file, TextIOBase):
        # Not every text stream has a name: io.StringIO has none, and a TextIOWrapper
        # around an io.BytesIO raises AttributeError when .name is accessed.
        name = getattr(file, "name", None)
        return (None if name is None else str(name)), lineno
    if file is None:
        if lineno is not None:
            raise TypeError("A line number without a file is not supported.")
        return None, None
    raise TypeError(f"Types of file and lineno are not supported: {file}, {lineno}")
def _format_file_message(message: str, filename: Optional[str], lineno: Optional[int]) -> str:
"""Format the message of an exception.
Parameters
----------
message
The actual error or warning message, without filename or line number info.
filename
The filename to which the error or warning is related.
lineno
The line number associated with the error or warning.
Returns
-------
full_message
The error message formated with filename and line number info.
"""
if filename is None:
return message
if lineno is None:
return f"{message} ({filename})"
return f"{message} ({filename}:{lineno})"
class BaseFileError(Exception):
    """Common base class for errors raised while loading or dumping files.

    The filename and line number (when known) are kept in the ``filename`` and
    ``lineno`` attributes and are appended to the message by ``str()``.
    """

    def __init__(
        self,
        message,
        file: Optional[Union[str, Path, LineIterator, TextIO]] = None,
        lineno: Optional[int] = None,
    ):
        """Initialize the error.

        Parameters
        ----------
        message
            The error message, without location info.
        file
            Object from which the filename (and possibly the line number) is deduced.
        lineno
            Line number, if known and not (correctly) included in ``file``.
        """
        super().__init__(message)
        location = _interpret_file_lineno(file, lineno)
        self.filename, self.lineno = location

    def __str__(self):
        bare_message = super().__str__()
        return _format_file_message(bare_message, self.filename, self.lineno)
class LoadError(BaseFileError):
    """Error raised when something goes wrong while loading from a file.

    Accepts the same ``message``, ``file`` and ``lineno`` arguments as its base class.
    """
class DumpError(BaseFileError):
    """Error raised when something goes wrong while dumping to a file.

    Accepts the same ``message``, ``file`` and ``lineno`` arguments as its base class.
    """
class PrepareDumpError(BaseFileError):
    """Error raised when an IOData object is incompatible with a format, before dumping to a file.

    Accepts the same ``message``, ``file`` and ``lineno`` arguments as its base class.
    """
class BaseFileWarning(Warning):
    """Common base class for warnings related to loading or dumping files.

    The filename and line number (when known) are folded into the warning message
    once, at construction time. They are not kept as separate attributes.
    """

    def __init__(
        self,
        message,
        file: Optional[Union[str, Path, LineIterator, TextIO]] = None,
        lineno: Optional[int] = None,
    ):
        """Initialize the warning.

        Parameters
        ----------
        message
            The warning message, without location info.
        file
            Object from which the filename (and possibly the line number) is deduced.
        lineno
            Line number, if known and not (correctly) included in ``file``.
        """
        where_file, where_line = _interpret_file_lineno(file, lineno)
        full_message = _format_file_message(message, where_file, where_line)
        super().__init__(full_message)
class LoadWarning(BaseFileWarning):
    """Warning issued when incorrect content is found and fixed while loading from a file.

    Accepts the same ``message``, ``file`` and ``lineno`` arguments as its base class.
    """
class DumpWarning(BaseFileWarning):
    """Warning issued when an IOData object is made compatible with a format while dumping to a file.

    Accepts the same ``message``, ``file`` and ``lineno`` arguments as its base class.
    """
class PrepareDumpWarning(BaseFileWarning):
    """Warning issued when an IOData object is made compatible with a format before dumping to a file.

    Accepts the same ``message``, ``file`` and ``lineno`` arguments as its base class.
    """
@attrs.define
class Cube:
    """Volumetric data on a uniform grid, as read from a cube (or similar) file."""

    origin: NDArray[float] = attrs.field(validator=validate_shape(3))
    """A 3D vector holding the origin of the axes frame."""

    axes: NDArray[float] = attrs.field(validator=validate_shape(3, 3))
    """
    A (3, 3) array whose rows are the displacements between two neighboring grid
    points along the first, second and third axis, respectively.
    """

    data: NDArray[float] = attrs.field(validator=validate_shape(None, None, None))
    """A (K, L, M) array with the data values on the uniform grid."""

    @property
    def shape(self) -> tuple[int, int, int]:
        """Shape of the rectangular grid, i.e. the shape of ``data``."""
        grid_shape = self.data.shape
        return grid_shape
def set_four_index_element(
    four_index_object: NDArray[float], i0: int, i1: int, i2: int, i3: int, value: float
):
    """Store one matrix element in all eight positions related by index symmetry.

    Physicists' notation is assumed, i.e. ``<i0 i1|i2 i3>``. The element is
    written to every index tuple that the 8-fold symmetry of real two-electron
    integrals maps onto ``(i0, i1, i2, i3)``.

    Parameters
    ----------
    four_index_object
        The four-index object, modified in place.
        shape=(nbasis, nbasis, nbasis, nbasis), dtype=float
    i0, i1, i2, i3
        The indices to assign to.
    value
        The value of the matrix element to store.
    """
    equivalent_positions = (
        (i0, i1, i2, i3),
        (i1, i0, i3, i2),  # swap electrons
        (i2, i1, i0, i3),  # swap bra/ket of electron 1
        (i0, i3, i2, i1),  # swap bra/ket of electron 2
        (i2, i3, i0, i1),  # swap bra/ket of both electrons
        (i3, i2, i1, i0),
        (i1, i2, i3, i0),
        (i3, i0, i1, i2),
    )
    for position in equivalent_positions:
        four_index_object[position] = value
def volume(cellvecs: NDArray[float]) -> float:
    """Calculate the (generalized) cell volume.

    Parameters
    ----------
    cellvecs
        A numpy matrix of shape (x, 3) with x in {1, 2, 3}, one cell vector per row.
        A single vector with shape (3,) is treated like one row.

    Returns
    -------
    For three vectors, the cell volume as the determinant of ``cellvecs``
    (negative for a left-handed set of vectors).
    For two vectors, the cell area.
    For one vector, the cell length.

    Raises
    ------
    ValueError
        When a matrix with a number of rows other than 1, 2 or 3 is given.
    """
    nvecs = cellvecs.shape[0]
    if cellvecs.ndim == 1 or nvecs == 1:
        # Length of the single cell vector.
        return np.linalg.norm(cellvecs)
    if nvecs == 2:
        # Area of the parallelogram spanned by the two cell vectors.
        normal = np.cross(cellvecs[0], cellvecs[1])
        return np.linalg.norm(normal)
    if nvecs != 3:
        raise ValueError("Argument cellvecs should be of shape (x, 3), where x is in {1, 2, 3}")
    # Signed volume of the parallelepiped spanned by the three cell vectors.
    return np.linalg.det(cellvecs)
def derive_naturals(
    dm: NDArray[float], overlap: NDArray[float]
) -> tuple[NDArray[float], NDArray[float]]:
    """Derive natural orbitals from a given density matrix.

    The natural orbitals solve the generalized symmetric eigenvalue problem
    ``(S^T D S) C = S C diag(n)``, with ``D`` the density matrix, ``S`` the overlap
    matrix, ``C`` the orbital coefficients and ``n`` the occupations.

    Parameters
    ----------
    dm
        The density matrix.
        shape=(nbasis, nbasis)
    overlap
        The overlap matrix.
        shape=(nbasis, nbasis)

    Returns
    -------
    coeffs
        Orbital coefficients, one natural orbital per column, normalized such that
        ``coeffs.T @ overlap @ coeffs`` is the identity matrix.
        shape=(nbasis, nbasis)
    occs
        Orbital occupations, in ascending order.
        shape=(nbasis, )
    """
    # Transform the density matrix to a Fock-like form, S^T D S.
    sds = np.dot(overlap.T, np.dot(dm, overlap))
    # scipy.linalg.eigh returns ascending eigenvalues and overlap-orthonormal
    # eigenvectors as columns, which are directly the occupations and coefficients.
    occs, coeffs = eigh(sds, overlap)
    return coeffs, occs
def check_dm(dm: NDArray[float], overlap: NDArray[float], eps: float = 1e-4, occ_max: float = 1.0):
    """Check if the density matrix has eigenvalues in the proper range.

    The eigenvalues considered are the natural-orbital occupations obtained from
    ``derive_naturals(dm, overlap)``. They must lie within ``[-eps, occ_max + eps]``.

    Parameters
    ----------
    dm
        The density matrix
        shape=(nbasis, nbasis), dtype=float
    overlap
        The overlap matrix
        shape=(nbasis, nbasis), dtype=float
    eps
        The threshold on the eigenvalue inequalities.
    occ_max
        The maximum occupation.

    Raises
    ------
    ValueError
        When the density matrix has wrong eigenvalues. The reported error is how far
        the offending eigenvalue lies below zero or above ``occ_max``.
    """
    # construct natural orbitals
    occupations = derive_naturals(dm, overlap)[1]
    lowest = occupations.min()
    highest = occupations.max()
    if lowest < -eps:
        raise ValueError(
            "The density matrix has eigenvalues considerably smaller than "
            f"zero. error={lowest:e}"
        )
    if highest > occ_max + eps:
        # Report the excess relative to occ_max (previously a hard-coded 1), so the
        # error is correct for any maximum occupation.
        raise ValueError(
            "The density matrix has eigenvalues considerably larger than "
            f"max. error={highest - occ_max:e}"
        )
# Accepted (lower-case) spellings of boolean values and their meaning.
STRTOBOOL = {
    **dict.fromkeys(("y", "yes", "t", "true", "on", "1"), True),
    **dict.fromkeys(("n", "no", "f", "false", "off", "0"), False),
}


def strtobool(value: str) -> bool:
    """Interpret a string as a boolean.

    The lookup is case-insensitive: ``value.lower()`` must be a key of ``STRTOBOOL``.

    Parameters
    ----------
    value
        The string to interpret, e.g. ``"yes"``, ``"Off"`` or ``"1"``.

    Returns
    -------
    The boolean that ``STRTOBOOL`` associates with ``value.lower()``.

    Raises
    ------
    ValueError
        When ``value`` is not a recognized boolean spelling.
    """
    normalized = value.lower()
    if normalized not in STRTOBOOL:
        raise ValueError(f"'{value}' cannot be converted to boolean")
    return STRTOBOOL[normalized]