# Source code for reacnetgenerator.utils

# SPDX-License-Identifier: LGPL-3.0-or-later
# cython: language_level=3
# cython: linetrace=True
"""Provide utils for ReacNetGenerator."""

import asyncio
import hashlib
import itertools
import os
import pickle
import shutil
from contextlib import ExitStack
from multiprocessing import Pool, Semaphore
from typing import (
    IO,
    TYPE_CHECKING,
    Any,
    AnyStr,
    BinaryIO,
    Callable,
    Generator,
    Iterable,
    List,
    Optional,
    Tuple,
    Union,
)

import lz4.frame
import numpy as np
import requests
from requests.adapters import HTTPAdapter
from tqdm.auto import tqdm

from ._logging import logger

if TYPE_CHECKING:
    import multiprocessing.pool
    import multiprocessing.synchronize

    import reacnetgenerator


class WriteBuffer:
    """Store a buffer for writing files.

    It is expensive to write to a file, so contents are collected in memory
    and written in batches.

    Parameters
    ----------
    f: fileObject
        The file object to write. Unless `sep` is given, its ``mode`` must be
        ``"w"`` or ``"wb"``.
    linenumber: int, default: 1200
        The number of contents to store in the buffer. The buffer is flushed
        as soon as it holds more than this number of items.
    sep: str or bytes, default: None
        The separator for contents. If None (default), there is no
        separator: an empty str (mode ``"w"``) or empty bytes (mode ``"wb"``)
        is used.

    Raises
    ------
    RuntimeError
        If `sep` is None and the file mode is neither ``"w"`` nor ``"wb"``.
    """

    def __init__(
        self, f: IO, linenumber: int = 1200, sep: Optional[AnyStr] = None
    ) -> None:
        self.f = f
        if sep is not None:
            self.sep = sep
        elif f.mode == "w":
            self.sep = ""
        elif f.mode == "wb":
            self.sep = b""
        else:
            raise RuntimeError("File mode should be w or wb!")
        self.linenumber = linenumber
        self.buff = []
        self.name = self.f.name

    def append(self, text: AnyStr) -> None:
        """Append a text.

        Parameters
        ----------
        text : str or bytes
            The text to be appended.
        """
        self.buff.append(text)
        self.check()

    def extend(self, text: Iterable[AnyStr]) -> None:
        """Extend texts.

        Parameters
        ----------
        text : list of strs or bytes
            Texts to be extended.
        """
        self.buff.extend(text)
        self.check()

    def check(self) -> None:
        """Flush the buffer if it holds more than ``linenumber`` items."""
        if len(self.buff) > self.linenumber:
            self.flush()

    def flush(self) -> None:
        """Write buffered contents, joined and followed by ``sep``, then clear."""
        if self.buff:
            # The trailing separator keeps consecutive flushes separated.
            self.f.writelines([self.sep.join(self.buff), self.sep])
            self.buff[:] = []

    def __enter__(self) -> "WriteBuffer":
        """Enter the context."""
        return self

    def __exit__(self, exc_type, exc_value, traceback):
        """Flush remaining contents and exit the wrapped file's context.

        The wrapped file's ``__exit__`` runs even if the final flush raises,
        so the file handle is not leaked.
        """
        try:
            self.flush()
        finally:
            self.f.__exit__(exc_type, exc_value, traceback)
def appendIfNotNone(f: Union[WriteBuffer, ExitStack], wbytes: Optional[AnyStr]) -> None:
    """Append ``wbytes`` to ``f`` unless it is None.

    Parameters
    ----------
    f : WriteBuffer
        The buffered file that receives the line.
    wbytes : str or bytes
        The line to write; None means "nothing to write".
    """
    if wbytes is None:
        return
    # Narrows the union: only a WriteBuffer (never an ExitStack) may be
    # written to here.
    assert not isinstance(f, ExitStack)
    f.append(wbytes)
def produce(
    semaphore: "multiprocessing.synchronize.Semaphore",
    plist: Iterable[Any],
    parameter: Any,
) -> Generator[Tuple[Any, Any], None, None]:
    """Yield items from `plist`, acquiring `semaphore` before each one.

    Bounding the number of in-flight items prevents large memory usage
    when the consumer (e.g. slow IO) falls behind.

    Parameters
    ----------
    semaphore : multiprocessing.Semaphore
        Acquired once per item; the consumer is expected to release it.
    plist : iterable of objects
        The items to pass on.
    parameter : object
        Paired with every item. If None, items are yielded bare.

    Yields
    ------
    object
        ``(item, parameter)`` when `parameter` is not None, otherwise ``item``.
    """
    attach = parameter is not None
    for element in plist:
        semaphore.acquire()
        yield (element, parameter) if attach else element
def compress(x: Union[str, bytes]) -> bytes:
    """Compress the line.

    This function reduces IO overhead to speed up the program. lz4 is used
    for its speed. The compressed stream format is
    size + data + size + data + ..., where each size is the byte length of
    the following lz4 frame, encoded as a 64-byte (not 64-bit)
    little-endian unsigned integer. ``decompress`` and
    ``read_compressed_block`` rely on this fixed 64-byte header.

    Parameters
    ----------
    x : str or bytes
        The line to compress. A str is encoded with ``str.encode()`` first.

    Returns
    -------
    bytes
        The 64-byte length header followed by the lz4 frame. No line break
        is appended.
    """
    if isinstance(x, str):
        x = x.encode()
    compress_block = lz4.frame.compress(x, compression_level=0)
    # The header width (64 bytes) is part of the stored format; the readers
    # hard-code it, so it must not change.
    length_bytes = len(compress_block).to_bytes(64, byteorder="little")
    return length_bytes + compress_block
def decompress(x: bytes, isbytes: bool = False) -> Union[str, bytes]:
    """Decompress one block produced by ``compress``.

    Parameters
    ----------
    x : bytes
        A block consisting of the 64-byte length header and an lz4 frame.
    isbytes : bool, optional, default: False
        Return raw bytes when True; otherwise decode the result to str.

    Returns
    -------
    str or bytes
        The decompressed line.
    """
    # The 64-byte length header is not needed here; skip it.
    payload = lz4.frame.decompress(x[64:])
    return payload if isbytes else payload.decode()
def listtobytes(x: Any) -> bytes:
    """Pickle an object and compress it into one block.

    Parameters
    ----------
    x : object
        Any picklable object, such as a numpy.ndarray.

    Returns
    -------
    bytes
        The compressed block (see ``compress``).
    """
    serialized = pickle.dumps(x)
    return compress(serialized)
def read_compressed_block(f: BinaryIO) -> Generator[bytes, None, None]:
    """Iterate over blocks of a stream laid out as size + data + size + data + ...

    Each size is a 64-byte little-endian integer giving the length of the
    data that follows it.

    Parameters
    ----------
    f : fileObject
        The binary file object to read.

    Yields
    ------
    data: bytes
        One block, header included, suitable for ``decompress``.
    """
    header = f.read(64)
    while header:
        length = int.from_bytes(header, byteorder="little")
        yield header + f.read(length)
        header = f.read(64)
def bytestolist(x: bytes) -> Any:
    """Decompress a block and unpickle the object it contains.

    Parameters
    ----------
    x : bytes
        A block produced by ``listtobytes``.

    Returns
    -------
    object
        The restored object.
    """
    payload = decompress(x, isbytes=True)
    assert isinstance(payload, bytes)
    # NOTE(review): pickle.loads can execute arbitrary code; only pass data
    # written by this program, never untrusted input.
    return pickle.loads(payload)
def listtostirng(
    l: Union[str, list, tuple, np.ndarray], sep: Union[List[str], Tuple[str, ...]]
) -> str:
    """Flatten a (possibly nested) list into a string for storage.

    Parameters
    ----------
    l : str or array-like
        The value to convert; lists, tuples and arrays may be nested to any
        depth.
    sep : list of strs
        One separator per nesting level, outermost first.

    Returns
    -------
    str
        The joined string. Strings pass through unchanged and other scalars
        go through ``str()``.
    """
    if isinstance(l, (list, tuple, np.ndarray)):
        head, tail = sep[0], sep[1:]
        return head.join([listtostirng(item, tail) for item in l])
    if isinstance(l, str):
        return l
    return str(l)
def multiopen(
    pool: "multiprocessing.pool.Pool",
    func: Callable,
    l: IO,
    semaphore: Optional["multiprocessing.synchronize.Semaphore"] = None,
    nlines: Optional[int] = None,
    unordered: bool = True,
    return_num: bool = False,
    start: int = 0,
    extra: Optional[Any] = None,
    interval: Optional[int] = None,
    bar: bool = True,
    desc: Optional[str] = None,
    unit: str = "it",
    total: Optional[int] = None,
) -> Iterable:
    """Return an iterator that processes a file with multiple processes.

    Parameters
    ----------
    pool : multiprocessing.Pool
        The pool for multiprocessing.
    func : function
        The function applied to each item.
    l : File object
        The file object (any iterable of lines works).
    semaphore : multiprocessing.Semaphore, optional, default: None
        If given, one slot is acquired per item before it is dispatched. If
        None (default), items are passed without flow control.
    nlines : int, optional, default: None
        Group this many consecutive lines into one tuple per item. The last
        tuple is padded with None. If None (default), single lines are passed.
    unordered : bool, optional, default: True
        Use ``imap_unordered`` instead of ``imap``.
    return_num : bool, optional, default: False
        If True, each item becomes ``(index, item)``.
    start : int, optional, default: 0
        The first index when `return_num` is True.
    extra : object, optional, default: None
        Paired with every item as ``(item, extra)``. Only applied when
        `semaphore` is given.
    interval : int, optional, default: None
        Keep only every `interval`-th item (the first one included) and drop
        the rest.
    bar : bool, optional, default: True
        If True, wrap the result in a tqdm progress bar.
    desc : str, optional, default: None
        The description shown in the bar.
    unit : str, optional, default: it
        The unit shown in the bar.
    total : int, optional, default: None
        The total number of iterations shown in the bar.

    Returns
    -------
    object
        An iterable over the results of `func`.
    """
    stream = l
    if nlines:
        stream = itertools.zip_longest(*([stream] * nlines))
    if interval:
        stream = itertools.islice(stream, 0, None, interval)
    if return_num:
        stream = enumerate(stream, start)
    if semaphore:
        stream = produce(semaphore, stream, extra)
    mapper = pool.imap_unordered if unordered else pool.imap
    # Items are handed to workers in chunks of 100.
    stream = mapper(func, stream, 100)
    if bar:
        stream = tqdm(stream, desc=desc, unit=unit, total=total, disable=None)
    return stream
class SCOUROPTIONS:
    """Options object for Scour (SVG optimization).

    Each class attribute is one Scour option flag.
    """

    # Drop document-level boilerplate.
    strip_xml_prolog = True
    remove_titles = True
    remove_descriptions = True
    remove_metadata = True
    remove_descriptive_elements = True
    strip_comments = True
    # Geometry / attribute simplification.
    enable_viewboxing = True
    strip_xml_space_attribute = True
    # ID handling.
    strip_ids = True
    shorten_ids = True
    # Output formatting.
    newlines = False
class SharedRNGData:
    """Share ReacNetGenerator data with a submodule object.

    Parameters
    ----------
    rng: reacnetgenerator.ReacNetGenerator
        The central ReacNetGenerator instance.
    usedRNGKeys: list of strs
        Attributes copied from `rng` onto this object.
    returnedRNGKeys: list of strs
        Attributes initialised to None here and copied back to `rng` by
        ``returnkeys``.
    extraNoneKeys: list of strs, optional, default: None
        Additional attributes initialised to None for the submodule's use.
    """

    def __init__(
        self,
        rng: "reacnetgenerator.ReacNetGenerator",
        usedRNGKeys: List[str],
        returnedRNGKeys: List[str],
        extraNoneKeys: Optional[List[str]] = None,
    ) -> None:
        self.rng = rng
        self.returnedRNGKeys = returnedRNGKeys
        for name in usedRNGKeys:
            setattr(self, name, getattr(rng, name))
        # Returned keys first, then extra keys, all reset to None; a key that
        # is both used and returned therefore ends up as None.
        extras = () if extraNoneKeys is None else extraNoneKeys
        for name in itertools.chain(returnedRNGKeys, extras):
            setattr(self, name, None)

    def returnkeys(self) -> None:
        """Copy every attribute named in ``returnedRNGKeys`` back onto ``rng``."""
        for name in self.returnedRNGKeys:
            setattr(self.rng, name, getattr(self, name))
def checksha256(filename: str, sha256_check: Union[str, List[str]]) -> bool:
    """Check whether the sha256 of a file matches.

    Parameters
    ----------
    filename : str
        The filename.
    sha256_check : str or list of strs
        The accepted sha256 hex digest(s).

    Returns
    -------
    bool
        True if the file's sha256 is one of `sha256_check`; False if it is
        not, or if `filename` is not an existing regular file.
    """
    if not os.path.isfile(filename):
        return False
    h = hashlib.sha256()
    b = bytearray(128 * 1024)
    mv = memoryview(b)
    # Hash in fixed-size chunks into a reused buffer to keep memory flat.
    with open(filename, "rb", buffering=0) as f:
        for n in iter(lambda: f.readinto(mv), 0):
            h.update(mv[:n])
    sha256 = h.hexdigest()
    logger.info(f"SHA256 of {filename}: {sha256}")
    if sha256 in must_be_list(sha256_check):
        return True
    logger.warning("SHA256 is not correct.")
    # Log a preview of the content to help diagnose the mismatch. Read bytes
    # and decode leniently so binary files cannot raise UnicodeDecodeError,
    # cap the size so large files do not flood the log, and close the handle.
    with open(filename, "rb") as f:
        preview = f.read(4096)
    logger.warning(preview.decode(errors="replace"))
    return False
async def download_file(
    urls: Union[str, List[str]], pathfilename: str, sha256: Optional[str]
) -> str:
    """Download a file from remote urls if it does not exist.

    Mirrors are tried in order until one succeeds.

    Parameters
    ----------
    urls: str or list of strs
        The url(s) that are available to download.
    pathfilename: str
        The downloading path of the file.
    sha256: str or None
        Sha256 of the file. If an existing file matches it (or it is None and
        the file exists), the download is skipped.

    Returns
    -------
    pathfilename: str
        The downloading path of the file.

    Raises
    ------
    RuntimeError
        If every url fails.

    Notes
    -----
    NOTE(review): this coroutine uses blocking ``requests`` calls and never
    awaits, so downloads gathered together still run one after another.
    """
    # download if not exists
    if os.path.isfile(pathfilename) and (
        sha256 is None or checksha256(pathfilename, sha256)
    ):
        return pathfilename
    with requests.Session() as s:
        s.mount("http://", HTTPAdapter(max_retries=3))
        s.mount("https://", HTTPAdapter(max_retries=3))
        # based on https://stackoverflow.com/questions/16694907
        for url in must_be_list(urls):
            logger.info(f"Try to download {pathfilename} from {url}")
            try:
                # The request itself is inside the try so that connection
                # errors fall through to the next mirror.
                with s.get(url, stream=True) as r:
                    # Do not save HTTP error pages (404, 500, ...) as the file.
                    r.raise_for_status()
                    with open(pathfilename, "wb") as f:
                        # iter_content decodes Content-Encoding and wraps
                        # transport errors as RequestException.
                        for chunk in r.iter_content(chunk_size=1024 * 1024):
                            f.write(chunk)
                break
            except requests.exceptions.RequestException as e:
                logger.warning(f"Request {pathfilename} Error.", exc_info=e)
        else:
            raise RuntimeError(f"Cannot download {pathfilename}.")
    return pathfilename
async def gather_download_files(urls: List[dict]) -> None:
    """Download every described file by gathering ``download_file`` calls.

    Each dict must have keys ``"url"`` and ``"fn"``, and may have
    ``"sha256"``. See download_multifiles for details.

    See Also
    --------
    download_multifiles
    """
    jobs = [
        download_file(entry["url"], entry["fn"], entry.get("sha256", None))
        for entry in urls
    ]
    await asyncio.gather(*jobs)
def download_multifiles(urls: List[dict]) -> None:
    """Download multiple files described by dicts.

    Parameters
    ----------
    urls : list of dicts
        The information of the files to download. Each dict contains the
        following keys:

        - url: str or list of strs
            The url(s) that are available to download.
        - fn: str
            The downloading path of the file.
        - sha256: str, optional, default: None
            Sha256 of the file. If not None and it matches an existing file,
            the download is skipped.
    """
    asyncio.run(gather_download_files(urls))
def run_mp(nproc: int, **kwargs: Any) -> Iterable[Any]:
    """Process a file with multiple processes.

    Parameters
    ----------
    nproc : int
        The number of processes to be used.
    **kwargs : dict, optional
        Other parameters can be found in the `multiopen` method.

    Yields
    ------
    object
        The yielded object from the `multiopen` method.

    See Also
    --------
    multiopen
    """
    pool = Pool(nproc, maxtasksperchild=1000)
    # Caps items in flight: multiopen acquires one slot per dispatched item,
    # and a slot is released here after each result is consumed.
    semaphore = Semaphore(nproc * 150)
    try:
        results = multiopen(pool=pool, semaphore=semaphore, **kwargs)
        for item in results:
            yield item
            semaphore.release()
    except GeneratorExit:
        # The consumer stopped iterating early (e.g. ``break`` or the
        # generator was closed); not an error, so stop workers quietly.
        pool.terminate()
        raise
    except BaseException:
        logger.exception("run_mp failed")
        pool.terminate()
        raise
    else:
        pool.close()
    finally:
        pool.join()
def must_be_list(obj: Union[Any, List[Any]]) -> List[Any]:
    """Wrap `obj` in a list unless it already is one.

    Parameters
    ----------
    obj : Object
        The object to convert.

    Returns
    -------
    obj: list
        `obj` itself (same object) if it is a list, otherwise ``[obj]``.
    """
    return obj if isinstance(obj, list) else [obj]