# Source code for iodata.test.common

# IODATA is an input and output module for quantum chemistry.
# Copyright (C) 2011-2019 The IODATA Development Team
#
# This file is part of IODATA.
#
# IODATA is free software; you can redistribute it and/or
# modify it under the terms of the GNU General Public License
# as published by the Free Software Foundation; either version 3
# of the License, or (at your option) any later version.
#
# IODATA is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
# GNU General Public License for more details.
#
# You should have received a copy of the GNU General Public License
# along with this program; if not, see <http://www.gnu.org/licenses/>
# --
"""Utilities for unit tests."""

import os
from contextlib import contextmanager

import numpy as np
from numpy.testing import assert_equal, assert_allclose
import pytest

from ..api import load_one
from ..overlap import compute_overlap
from ..basis import convert_conventions
from ..utils import FileFormatWarning

try:
    from importlib_resources import path
except ImportError:
    from importlib.resources import path


# Names exported by a star import of this module.  Note that truncated_file
# is defined in this file but is not part of this list, so it has to be
# imported by name.
__all__ = [
    'compute_mulliken_charges',
    'compute_1rdm',
    'compare_mols',
    'check_orthonormal',
    'load_one_warning',
]


def compute_1rdm(iodata):
    """Build the one-particle reduced density matrix from the orbitals.

    Parameters
    ----------
    iodata
        Object providing the orbital coefficients ``mo.coeffs`` and the
        occupation numbers ``mo.occs``.

    Returns
    -------
    dm
        The matrix product of ``coeffs * occs`` with ``coeffs.T``.

    """
    orbitals = iodata.mo
    # The occupations broadcast along the last axis of the coefficients,
    # i.e. every orbital column is scaled by its own occupation number.
    weighted = orbitals.coeffs * orbitals.occs
    return np.dot(weighted, orbitals.coeffs.T)
def compute_mulliken_charges(iodata):
    """Compute Mulliken charges.

    Parameters
    ----------
    iodata
        Object providing ``mo``, ``obasis``, ``atcoords``, ``natom`` and
        ``atcorenums``.

    Returns
    -------
    charges
        ``atcorenums`` minus the population assigned to each atom.

    """
    density = compute_1rdm(iodata)
    overlap = compute_overlap(iodata.obasis, iodata.atcoords)
    # Population of every basis function: row sums of the element-wise
    # product of the density and overlap matrices.
    basis_pops = np.sum(np.multiply(density, overlap), axis=1)
    # For every basis function, the index of the center its shell sits on.
    centers = np.array([
        shell.icenter
        for shell in iodata.obasis.shells
        for _ in range(shell.nbasis)
    ])
    # Sum the basis-function populations center by center.
    atom_pops = np.array([
        np.sum(basis_pops[centers == iatom])
        for iatom in range(iodata.natom)
    ])
    return iodata.atcorenums - atom_pops
@contextmanager
def truncated_file(fn_orig, nline, nadd, tmpdir):
    """Provide a shortened temporary copy of a text file.

    Parameters
    ----------
    fn_orig : str
        The file to be truncated.
    nline : int
        The number of leading lines copied from the original.
    nadd : int
        The number of empty lines appended after the copied lines.
    tmpdir : str
        The directory in which the truncated copy is written.

    Yields
    ------
    fn_truncated : str
        The path of the truncated copy.  Both files are closed before the
        path is handed to the caller; the copy is not removed afterwards.

    """
    basename = os.path.basename(fn_orig)
    fn_truncated = '%s/truncated_%i_%s' % (tmpdir, nline, basename)
    with open(fn_orig) as source, open(fn_truncated, 'w') as target:
        # Pairing with range(nline) stops the copy after nline lines, or
        # earlier when the original file is shorter.
        for _, text in zip(range(nline), source):
            target.write(text)
        padding = nadd
        while padding > 0:
            target.write('\n')
            padding -= 1
    yield fn_truncated
def compare_mols(mol1, mol2, atol=1.0e-8, rtol=0.0):
    """Compare two IOData objects.

    All checks are assertions, so the first mismatch raises an
    ``AssertionError``.  Nothing is returned.

    Parameters
    ----------
    mol1, mol2
        The objects to compare.  The attributes ``title``, ``atnums``,
        ``atcorenums``, ``atcoords``, ``obasis``, ``mo``, ``one_ints``,
        ``two_ints`` and ``one_rdms`` are read.
    atol : float
        Absolute tolerance for basis-set parameters, Mulliken charges and
        orbital data.  Coordinates are always compared with ``atol=1e-10``.
    rtol : float
        Relative tolerance used together with ``atol``.

    Notes
    -----
    When neither object has an orbital basis, there is no basis-function
    convention to reconcile and the remaining data are compared as stored.
    Previously this situation always ended in an ``AttributeError`` because
    the conventions were read from ``mol2.obasis`` unconditionally.

    """
    assert mol1.title == mol2.title
    assert_equal(mol1.atnums, mol2.atnums)
    assert_equal(mol1.atcorenums, mol2.atcorenums)
    assert_allclose(mol1.atcoords, mol2.atcoords, atol=1e-10)

    # orbital basis
    if mol1.obasis is None:
        assert mol2.obasis is None
        # No basis: nothing to permute or to change sign of further below.
        perm = sgn = None
    else:
        assert mol2.obasis is not None
        # compare the shells one by one
        assert len(mol1.obasis.shells) == len(mol2.obasis.shells)
        for shell1, shell2 in zip(mol1.obasis.shells, mol2.obasis.shells):
            assert shell1.icenter == shell2.icenter
            assert_equal(shell1.angmoms, shell2.angmoms)
            assert shell1.kinds == shell2.kinds
            assert_allclose(shell1.exponents, shell2.exponents, atol=atol, rtol=rtol)
            assert_allclose(shell1.coeffs, shell2.coeffs, atol=atol, rtol=rtol)
        assert mol1.obasis.primitive_normalization == mol2.obasis.primitive_normalization
        # compute and compare Mulliken charges
        charges1 = compute_mulliken_charges(mol1)
        charges2 = compute_mulliken_charges(mol2)
        assert_allclose(charges1, charges2, atol=atol, rtol=rtol)
        # Permutation and sign changes that bring the basis functions of
        # mol1 into the conventions used by mol2.
        perm, sgn = convert_conventions(mol1.obasis, mol2.obasis.conventions)

    # wfn
    # NOTE(review): assumes a missing wavefunction is represented by
    # ``mo`` being None, mirroring how ``obasis`` is handled -- confirm.
    if mol1.mo is None:
        assert mol2.mo is None
    else:
        assert mol2.mo is not None
        assert mol1.mo.kind == mol2.mo.kind
        assert_allclose(mol1.mo.occs, mol2.mo.occs, atol=atol, rtol=rtol)
        coeffs1 = mol1.mo.coeffs
        if perm is not None:
            coeffs1 = coeffs1[perm] * sgn.reshape(-1, 1)
        assert_allclose(coeffs1, mol2.mo.coeffs, atol=atol, rtol=rtol)
        assert_allclose(mol1.mo.energies, mol2.mo.energies, atol=atol, rtol=rtol)
        assert_equal(mol1.mo.irreps, mol2.mo.irreps)

    # operators and density matrices
    cases = [
        ('one_ints', ['olp', 'kin_ao', 'na_ao']),
        ('two_ints', ['er_ao']),
        ('one_rdms', ['scf', 'scf_spin', 'post_scf', 'post_scf_spin']),
    ]
    for attrname, keys in cases:
        d1 = getattr(mol1, attrname)
        d2 = getattr(mol2, attrname)
        for key in keys:
            if key not in d1:
                assert key not in d2
                continue
            assert key in d2
            matrix1 = d1[key]
            if perm is not None:
                # Reorder and re-sign rows, then columns.
                # NOTE(review): only the first two axes are permuted, and
                # the sign factors broadcast against the trailing axes; for
                # a four-index array (e.g. 'er_ao') this does not transform
                # all indices consistently -- confirm this is intended.
                matrix1 = matrix1[perm] * sgn.reshape(-1, 1)
                matrix1 = matrix1[:, perm] * sgn
            assert_equal(matrix1, d2[key])
def check_orthonormal(mo_coeffs, ao_overlap, atol=1e-5):
    """Assert that a set of molecular orbitals is orthonormal.

    Parameters
    ----------
    mo_coeffs : np.ndarray, shape=(nbasis, mo_count)
        Molecular orbital coefficients.
    ao_overlap : np.ndarray, shape=(nbasis, nbasis)
        Atomic orbital overlap matrix.
    atol : float
        Largest absolute deviation from the identity matrix that is tolerated.

    Raises
    ------
    AssertionError
        When the overlap matrix in the orbital basis differs from the
        identity matrix by more than ``atol``.

    """
    # Transform the overlap matrix to the orbital basis in two steps.
    half_transformed = np.dot(ao_overlap, mo_coeffs)
    overlap_mo = np.dot(mo_coeffs.T, half_transformed)
    identity = np.eye(mo_coeffs.shape[1])
    assert_allclose(
        overlap_mo,
        identity,
        rtol=0.,
        atol=atol,
        err_msg='Molecular orbitals are not orthonormal!',
    )
def load_one_warning(filename: str, fmt: str = None, match: str = None, **kwargs):
    """Load a test data file with load_one, optionally expecting a warning.

    Parameters
    ----------
    filename
        The file in the unit test data directory to load.
    fmt
        The name of the file format module to use. When not given, it is
        guessed from the filename.
    match
        When given, loading the file must emit a FileFormatWarning whose
        message matches this string.  When None, the file is loaded without
        any check on warnings.
    **kwargs
        Keyword arguments are passed on to load_one.

    Returns
    -------
    out
        The result of load_one for the requested file.

    """
    with path('iodata.test.data', filename) as fn:
        fn_full = str(fn)
        if match is not None:
            # The load must take place inside the warns context so that the
            # expected warning is captured and verified.
            with pytest.warns(FileFormatWarning, match=match):
                return load_one(fn_full, fmt, **kwargs)
        return load_one(fn_full, fmt, **kwargs)