# IODATA is an input and output module for quantum chemistry.
# Copyright (C) 2011-2019 The IODATA Development Team
#
# This file is part of IODATA.
#
# IODATA is free software; you can redistribute it and/or
# modify it under the terms of the GNU General Public License
# as published by the Free Software Foundation; either version 3
# of the License, or (at your option) any later version.
#
# IODATA is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
# GNU General Public License for more details.
#
# You should have received a copy of the GNU General Public License
# along with this program; if not, see <http://www.gnu.org/licenses/>
# --
"""Functions to be used by end users."""
import os
import warnings
from collections.abc import Iterable, Iterator
from fnmatch import fnmatch
from importlib import import_module
from pkgutil import iter_modules
from types import ModuleType
from typing import Callable, Optional
from .iodata import IOData
from .utils import (
DumpError,
FileFormatError,
LineIterator,
LoadError,
PrepareDumpError,
WriteInputError,
)
__all__ = ("load_one", "load_many", "dump_one", "dump_many", "write_input")
def _find_format_modules():
"""Return all file-format modules found with importlib."""
result = {}
for module_info in iter_modules(import_module("iodata.formats").__path__):
if not module_info.ispkg:
format_module = import_module("iodata.formats." + module_info.name)
if hasattr(format_module, "PATTERNS"):
result[module_info.name] = format_module
return result
FORMAT_MODULES = _find_format_modules()
def _select_format_module(filename: str, attrname: str, fmt: Optional[str] = None) -> ModuleType:
"""Find a file format module with the requested attribute name.
Parameters
----------
filename
The file to load or dump.
attrname
The required attribute of the file format module.
fmt
The name of the file format module to use. When not given, it is guessed
from the filename.
Returns
-------
The module implementing the required file format.
Raises
------
FileFormatError
When no file format module can be found that has a member named ``attrname``.
"""
basename = os.path.basename(filename)
if fmt is None:
for format_module in FORMAT_MODULES.values():
if any(fnmatch(basename, pattern) for pattern in format_module.PATTERNS) and hasattr(
format_module, attrname
):
return format_module
raise FileFormatError(f"Cannot find file format with feature {attrname}", filename)
if fmt in FORMAT_MODULES:
format_module = FORMAT_MODULES[fmt]
if not hasattr(format_module, attrname):
raise FileFormatError(f"Format {fmt} does not support feature {attrname}", filename)
return format_module
raise FileFormatError(f"Unknown file format {fmt}", filename)
def _find_input_modules():
"""Return all input modules found with importlib."""
result = {}
for module_info in iter_modules(import_module("iodata.inputs").__path__):
if not module_info.ispkg:
input_module = import_module("iodata.inputs." + module_info.name)
if hasattr(input_module, "write_input"):
result[module_info.name] = input_module
return result
INPUT_MODULES = _find_input_modules()
def _select_input_module(filename: str, fmt: str) -> ModuleType:
"""Find an input module.
Parameters
----------
filename
The file to be written to, only used for error messages.
fmt
The name of the input module to use.
Returns
-------
The module implementing the required input format.
Raises
------
FileFormatError
When the format ``fmt`` does not exist.
"""
if fmt in INPUT_MODULES:
if not hasattr(INPUT_MODULES[fmt], "write_input"):
raise FileFormatError(f"{fmt} input module does not have write_input.", filename)
return INPUT_MODULES[fmt]
raise FileFormatError(f"Cannot find input format {fmt}.", filename)
def _reissue_warnings(func):
"""Correct stacklevel of warnings raised in functions called deeper in IOData.
This function should be used as a decorator of end-user API functions.
Adapted from https://stackoverflow.com/a/71635963/494584
"""
def inner(*args, **kwargs):
"""Wrapper for func that reissues warnings."""
warning_list = []
try:
with warnings.catch_warnings(record=True) as warning_list:
result = func(*args, **kwargs)
finally:
for warning in warning_list:
warnings.warn(warning.message, warning.category, stacklevel=2)
return result
return inner
[docs]
@_reissue_warnings
def load_one(filename: str, *, fmt: Optional[str] = None, **kwargs) -> IOData:
"""Load data from a file.
This function uses the extension or prefix of the filename to determine the
file format. When the file format is detected, a specialized load function
is called for the heavy lifting.
Parameters
----------
filename
The file to load data from.
fmt
The name of the file format module to use. When not given, it is guessed
from the filename.
**kwargs
Keyword arguments are passed on to the format-specific load_one function.
Returns
-------
The instance of IOData with data loaded from the input files.
"""
format_module = _select_format_module(filename, "load_one", fmt)
with LineIterator(filename) as lit:
try:
return IOData(**format_module.load_one(lit, **kwargs))
except LoadError:
raise
except StopIteration as exc:
raise LoadError("File ended before all data was read.", lit) from exc
except Exception as exc:
raise LoadError("Uncaught exception while loading file.", lit) from exc
[docs]
@_reissue_warnings
def load_many(filename: str, *, fmt: Optional[str] = None, **kwargs) -> Iterator[IOData]:
"""Load multiple IOData instances from a file.
This function uses the extension or prefix of the filename to determine the
file format. When the file format is detected, a specialized load function
is called for the heavy lifting.
Parameters
----------
filename
The file to load data from.
fmt
The name of the file format module to use. When not given, it is guessed
from the filename.
**kwargs
Keyword arguments are passed on to the format-specific load_many function.
Yields
------
IOData
An instance of IOData with data for one frame loaded for the file.
"""
format_module = _select_format_module(filename, "load_many", fmt)
with LineIterator(filename) as lit:
try:
for data in format_module.load_many(lit, **kwargs):
yield IOData(**data)
except StopIteration:
return
except LoadError:
raise
except Exception as exc:
raise LoadError("Uncaught exception while loading file.", lit) from exc
def _check_required(filename: str, data: IOData, dump_func: Callable):
"""Check that required attributes are not None before dumping to a file.
Parameters
----------
filename
The file to be dumped to, only used for error messages.
data
The data to be written.
dump_func
The dump_one or dump_many function that will write the file.
Raises
------
PrepareDumpError
When a required attribute is ``None``.
"""
for attr_name in dump_func.required:
if getattr(data, attr_name) is None:
raise PrepareDumpError(
f"Required attribute {attr_name}, for format {dump_func.fmt}, is None.", filename
)
[docs]
@_reissue_warnings
def dump_one(
data: IOData,
filename: str,
*,
fmt: Optional[str] = None,
allow_changes: bool = False,
**kwargs,
):
"""Write data to a file.
This routine uses the extension or prefix of the filename to determine
the file format. For each file format, a specialized function is
called that does the real work.
Parameters
----------
data
The object containing the data to be written.
filename
The file to write the data to.
fmt
The name of the file format module to use. When not given, it is guessed
from the filename.
allow_changes
Whether conversion of the IOData object to a compatible form is allowed or not.
**kwargs
Keyword arguments are passed on to the format-specific dump_one function.
Returns
-------
data
The given ``IOData`` object or a shallow copy with some new attributes if converted.
Raises
------
DumpError
When an error is encountered while dumping to a file.
If the output file already existed, it is (partially) overwritten.
PrepareDumpError
When the ``IOData`` object is not compatible with the file format,
e.g. due to missing attributes, and no conversion is available or allowed
to make it compatible.
If the output file already existed, it is not overwritten.
PrepareDumpWarning
When the ``IOData`` object is not compatible with the file format,
but it was converted to fix the compatibility issue.
"""
format_module = _select_format_module(filename, "dump_one", fmt)
try:
_check_required(filename, data, format_module.dump_one)
if hasattr(format_module, "prepare_dump"):
data = format_module.prepare_dump(data, allow_changes, filename)
except PrepareDumpError:
raise
except Exception as exc:
raise PrepareDumpError(
"Uncaught exception while preparing for dumping to a file.", filename
) from exc
with open(filename, "w") as f:
try:
format_module.dump_one(f, data, **kwargs)
except DumpError:
raise
except Exception as exc:
raise DumpError("Uncaught exception while dumping to a file", filename) from exc
return data
[docs]
@_reissue_warnings
def dump_many(
iter_data: Iterable[IOData],
filename: str,
*,
fmt: Optional[str] = None,
allow_changes: bool = False,
**kwargs,
):
"""Write multiple IOData instances to a file.
This routine uses the extension or prefix of the filename to determine
the file format. For each file format, a specialized function is
called that does the real work.
Parameters
----------
iter_data
An iterator over IOData instances.
filename
The file to write the data to.
fmt
The name of the file format module to use.
allow_changes
Whether conversion of the IOData object to a compatible form is allowed or not.
**kwargs
Keyword arguments are passed on to the format-specific dump_many function.
Raises
------
DumpError
When an error is encountered while dumping to a file.
If the output file already existed, it (partially) overwritten.
PrepareDumpError
When an ``IOData`` object is not compatible with the file format,
e.g. due to missing attributes, and no conversion is available or allowed
to make it compatible.
If the output file already existed, it is not overwritten when this error
is raised while processing the first ``IOData`` instance in the ``datas`` argument.
When the exception is raised in later iterations, any existing file is overwritten.
PrepareDumpWarning
When an ``IOData`` object is not compatible with the file format,
but it was converted to fix the compatibility issue.
"""
format_module = _select_format_module(filename, "dump_many", fmt)
# Make sure we have an iterator. For example, this turns a list into an iterator.
iter_data = iter(iter_data)
# Check the first item before creating the file.
# If the file already exists, this may prevent data loss:
# The file is not overwritten when it is clear that writing will fail.
try:
first = next(iter_data)
except StopIteration as exc:
raise DumpError("dump_many needs at least one IOData object.", filename) from exc
try:
_check_required(filename, first, format_module.dump_many)
if hasattr(format_module, "prepare_dump"):
first = format_module.prepare_dump(first, allow_changes, filename)
except PrepareDumpError:
raise
except Exception as exc:
raise PrepareDumpError(
"Uncaught exception while preparing for dumping to a file.", filename
) from exc
def checking_iterator():
"""Iterate over all data items, not checking the first."""
# The first one was already checked.
yield first
for other in iter_data:
_check_required(filename, other, format_module.dump_many)
yield (
format_module.prepare_dump(other, allow_changes, filename)
if hasattr(format_module, "prepare_dump")
else other
)
with open(filename, "w") as f:
try:
format_module.dump_many(f, checking_iterator(), **kwargs)
except (PrepareDumpError, DumpError):
raise
except Exception as exc:
raise DumpError("Uncaught exception while dumping to a file.", filename) from exc