Source code for fmflow.core.array.decorators

# coding: utf-8

# names exported by ``from <this module> import *``
__all__ = ["chunk", "xarrayfunc"]

# standard library
from concurrent.futures import ProcessPoolExecutor as Pool
from functools import wraps
from inspect import Parameter, signature, stack
from multiprocessing import cpu_count
from sys import _getframe as getframe

# dependent packages
import fmflow as fm
import numpy as np
import xarray as xr

# module constants
# number of chunks used when neither ``numchunk`` nor ``timechunk`` is given
DEFAULT_N_CHUNKS = 1

# default number of worker processes: leave one core free, but never go
# below one (a process pool cannot be created with zero workers)
try:
    MAX_WORKERS = max(cpu_count() - 1, 1)
except NotImplementedError:
    # the number of CPUs cannot be determined on this platform
    MAX_WORKERS = 1


# decorators
def xarrayfunc(func):
    """Make a function compatible with xarray.DataArray.

    Every positional argument that is an xarray.DataArray is replaced by its
    underlying ``.values`` before ``func`` is called, and the result is passed
    to ``fm.full_like`` together with the first positional argument.
    If no positional argument is a DataArray, ``func`` is called unchanged.
    Keyword arguments are always forwarded as they are.

    This function is intended to be used as a decorator like::

        >>> @fm.xarrayfunc
        >>> def func(array):
        ...     # do something
        ...     return newarray
        >>>
        >>> result = func(array)

    Args:
        func (function): A function to be wrapped. The first argument
            of the function must be an array to be processed.

    Returns:
        wrapper (function): A wrapped function.

    """

    @wraps(func)
    def wrapper(*args, **kwargs):
        # nothing to unwrap: behave exactly like the original function
        if not any(isinstance(arg, xr.DataArray) for arg in args):
            return func(*args, **kwargs)

        # hand plain values to the function instead of DataArrays
        unwrapped = [
            arg.values if isinstance(arg, xr.DataArray) else arg
            for arg in args
        ]
        output = func(*unwrapped, **kwargs)

        # the first positional argument serves as the template of the result
        return fm.full_like(args[0], output)

    return wrapper
def _module_globals(default):
    """Return the globals of the nearest ``<module>`` frame in the call stack.

    Args:
        default (dict): Value returned when the call stack contains
            no module-level frame.

    Returns:
        f_globals (dict): Globals of the first frame, searching outward
            from the caller, whose code name is ``<module>``;
            ``default`` if there is no such frame.

    """
    frame = getframe(1)

    while frame is not None:
        if frame.f_code.co_name == "<module>":
            return frame.f_globals

        frame = frame.f_back

    return default


def chunk(*argnames, concatfunc=None):
    """Make a function compatible with multicore chunk processing.

    The arguments named in ``argnames`` are split into chunks, the function
    is run once per chunk (in a process pool by default), and the per-chunk
    results are joined into one output.

    This function is intended to be used as a decorator like::

        >>> @fm.chunk('array')
        >>> def func(array):
        ...     # do something
        ...     return newarray
        >>>
        >>> result = func(array, timechunk=10)

    or you can set a global chunk parameter outside the function::

        >>> timechunk = 10
        >>> result = func(array)

    The wrapped function understands the following options, given either as
    keyword arguments or as module-level globals of the calling script
    (keyword arguments take precedence):

    - ``numchunk``: Number of chunks.
    - ``timechunk``: Length of one chunk; the number of chunks is
      ``round(len(first chunked argument) / timechunk)``, but at least one.
    - ``n_processes``: Number of worker processes.
    - ``multiprocess``: If False, chunks are processed sequentially
      in the current process.

    Args:
        argnames (str): Names of the arguments to be split into chunks.
        concatfunc (function, optional): Function that receives the list
            of per-chunk results and returns the final output. If None,
            results are joined by ``xr.concat(results, "t")``, falling back
            to ``np.concatenate(results, 0)`` on TypeError.

    Returns:
        decorator (function): A decorator which wraps a function.

    """

    def _chunk(func):
        # fall back to the function's own globals if no module-level frame
        # is found (NOTE(review): assumed to be an acceptable registry
        # for the copied function -- confirm)
        f_globals = _module_globals(getattr(func, "__globals__", {}))

        # original (unwrapped) function, registered by name in the globals
        # (presumably so that worker processes can look it up -- TODO confirm)
        orgname = "_original_" + func.__name__
        orgfunc = fm.utils.copy_function(func, orgname)
        f_globals[orgname] = orgfunc

        @wraps(func)
        def wrapper(*args, **kwargs):
            # globals of the calling script may carry chunk options;
            # without a module-level frame there are simply none
            f_globals = _module_globals({})

            # parse args and kwargs: turn everything into keyword arguments
            params = signature(func).parameters

            for i, (key, val) in enumerate(params.items()):
                if val.kind != Parameter.POSITIONAL_OR_KEYWORD:
                    break

                if i < len(args):
                    kwargs[key] = args[i]
                elif key not in kwargs:
                    if val.default is Parameter.empty:
                        # do not pass the ``empty`` sentinel to the function
                        raise TypeError(
                            f"{func.__name__}() missing required "
                            f"argument: '{key}'"
                        )

                    kwargs[key] = val.default

            # n_chunks and n_processes
            n_chunks = DEFAULT_N_CHUNKS
            n_processes = MAX_WORKERS
            multiprocess = True

            if argnames:
                length = len(kwargs[argnames[0]])

                # timechunk-derived counts are clamped to one so that an
                # array shorter than half a chunk is still processed
                if "numchunk" in kwargs:
                    n_chunks = kwargs.pop("numchunk")
                elif "timechunk" in kwargs:
                    n_chunks = max(1, round(length / kwargs.pop("timechunk")))
                elif "numchunk" in f_globals:
                    n_chunks = f_globals["numchunk"]
                elif "timechunk" in f_globals:
                    n_chunks = max(1, round(length / f_globals["timechunk"]))

            if "n_processes" in kwargs:
                n_processes = kwargs.pop("n_processes")
            elif "n_processes" in f_globals:
                n_processes = f_globals["n_processes"]

            if "multiprocess" in kwargs:
                multiprocess = kwargs.pop("multiprocess")
            elif "multiprocess" in f_globals:
                multiprocess = f_globals["multiprocess"]

            # make chunked args
            chunks = {}

            for name in argnames:
                arg = kwargs.pop(name)

                try:
                    chunks[name] = np.array_split(arg, n_chunks)
                except TypeError:
                    # not splittable: repeat the value once per chunk
                    chunks[name] = np.tile(arg, n_chunks)

            # keyword arguments of each call (chunked + common ones)
            calls = [
                {**{name: parts[i] for name, parts in chunks.items()}, **kwargs}
                for i in range(n_chunks)
            ]

            # run the function
            if multiprocess:
                with fm.utils.one_thread_per_process(), Pool(n_processes) as p:
                    futures = [p.submit(orgfunc, **call) for call in calls]
                    results = [future.result() for future in futures]
            else:
                results = [orgfunc(**call) for call in calls]

            # make an output
            if concatfunc is not None:
                return concatfunc(results)

            try:
                return xr.concat(results, "t")
            except TypeError:
                return np.concatenate(results, 0)

        return wrapper

    return _chunk