# Source code for triad.utils.class_extension

from typing import Callable, Dict, Optional, Type, get_type_hints, Any
import inspect
from triad.utils.assertion import assert_or_throw


def _get_first_arg_type(func: Callable) -> Any:
    sig = inspect.signature(func)
    annotations = get_type_hints(func)
    for k, w in sig.parameters.items():
        assert_or_throw(k != "self", ValueError(f"class method is not allowed {func}"))
        assert_or_throw(
            w.kind == w.POSITIONAL_OR_KEYWORD,
            ValueError(f"{w} is not a valid parameter in {func}"),
        )
        anno = annotations.get(k, w.annotation)
        assert_or_throw(
            anno != inspect.Parameter.empty,
            ValueError(f"the first argument must be annotated in {func}"),
        )
        return anno
    raise ValueError(f"{func} does not have any input parameter")


class _ClassExtension:
    def __init__(self, class_type: Type):
        self._class_type = class_type
        self._built_in = set(dir(class_type))
        self._ext: Dict[str, Callable] = {}

    def add_method(
        self, func: Callable, name: Optional[str] = None, on_dup: str = "error"
    ) -> None:
        assert_or_throw(
            name not in self._built_in, ValueError(f"{name} is a built in attribute")
        )
        if name is None:
            name = func.__name__
        if name in self._ext:
            if on_dup == "ignore":
                return
            if on_dup == "error":
                raise ValueError(f"{name} is already registered")
        self._ext[name] = func
        setattr(self._class_type, name, func)


class _ClassExtensions:
    """Registry of every extensible class and its ``_ClassExtension``."""

    def __init__(self):
        # registered class -> registry of that class's extension methods
        self._types: Dict[Type, _ClassExtension] = {}

    def register_type(self, tp: Type) -> None:
        """Start tracking extension methods for ``tp``.

        :param tp: the class to register
        :raises ValueError: if ``tp`` has already been registered
        """
        already_known = tp in self._types
        assert_or_throw(
            not already_known, ValueError(f"{tp} is already registered")
        )
        self._types[tp] = _ClassExtension(tp)

    def add_method(
        self,
        class_type: Type,
        func: Callable,
        name: Optional[str] = None,
        on_dup: str = "error",
    ) -> None:
        """Attach ``func`` to a previously registered ``class_type``.

        :param class_type: a class previously passed to ``register_type``
        :param func: the function to attach
        :param name: the method name; None lets ``_ClassExtension`` choose it
        :param on_dup: duplicate-name policy forwarded to
            ``_ClassExtension.add_method``
        :raises ValueError: if ``class_type`` is not registered, or as raised
            by ``_ClassExtension.add_method``
        """
        is_registered = class_type in self._types
        assert_or_throw(
            is_registered, ValueError(f"{class_type} is not registered")
        )
        extension = self._types[class_type]
        extension.add_method(func=func, name=name, on_dup=on_dup)


# Module-level registry shared by ``extensible_class`` and ``extension_method``.
_CLASS_EXTENSIONS: _ClassExtensions = _ClassExtensions()


def extensible_class(class_type: Type) -> Type:
    """The decorator making classes extensible by external methods

    :param class_type: the class under the decorator
    :return: the ``class_type``
    :raises ValueError: if ``class_type`` has already been registered

    .. admonition:: Examples

        .. code-block:: python

            @extensible_class
            class A:

                # It's recommended to implement __getattr__ so that
                # PyLint will not complain about the dynamically added methods
                def __getattr__(self, name):
                    raise NotImplementedError

            @extension_method
            def method(obj:A):
                return 1

            assert 1 == A().method()

    .. note::

        If the method name is already in the original class, a ValueError will
        be thrown. You can't modify any built-in attribute.
    """
    _CLASS_EXTENSIONS.register_type(class_type)
    return class_type
def extension_method(
    func: Optional[Callable] = None,
    class_type: Optional[Type] = None,
    name: Optional[str] = None,
    on_dup: str = "error",
) -> Callable:
    """The decorator to add functions as members of the correspondent classes.

    :param func: the function under the decorator
    :param class_type: the parent class type, defaults to None. If None then
        the annotation of the function's first argument is used as the class
    :param name: the specified class method name, defaults to None. If None
        then ``func.__name__`` will be used as the method name
    :param on_dup: action on name duplication, defaults to ``error``. ``error``
        will throw a ValueError; ``ignore`` will take no action; ``overwrite``
        will use the current method to overwrite.
    :return: the underlying function when used as ``@extension_method``,
        otherwise a decorator that registers and returns its function
    :raises ValueError: if the target class is not registered with
        ``extensible_class``, if the class type cannot be inferred from the
        first argument, or on a rejected method name

    .. admonition:: Examples

        .. code-block:: python

            @extensible_class
            class A:

                # It's recommended to implement __getattr__ so that
                # PyLint will not complain about the dynamically added methods
                def __getattr__(self, name):
                    raise NotImplementedError

            # The simplest way to use this decorator, the first argument of
            # the method must be annotated, and the annotated type is the
            # class type to add this method to.
            @extension_method
            def method1(obj:A):
                return 1

            assert 1 == A().method1()

            # Or you can be explicit of the class type and the name of the
            # method in the class. In this case, you don't have to annotate
            # the first argument.
            @extension_method(class_type=A, name="m3")
            def method2(obj, b):
                return 2 + b

            assert 5 == A().m3(3)

    .. note::

        If the method name is already in the original class, a ValueError will
        be thrown. You can't modify any built-in attribute.
    """

    def _register(fn: Callable) -> Callable:
        # Without an explicit class_type, the annotation of the first argument
        # decides which class receives the method.
        target = _get_first_arg_type(fn) if class_type is None else class_type
        _CLASS_EXTENSIONS.add_method(target, func=fn, name=name, on_dup=on_dup)
        return fn

    if func is not None:
        # bare usage: @extension_method
        return _register(func)
    # parameterized usage: @extension_method(...)
    return _register