Last active
March 20, 2022 01:18
-
-
Save ProgramRipper/a87be486bc86d2c70c3775fe83352530 to your computer and use it in GitHub Desktop.
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
from __future__ import annotations

from abc import get_cache_token
from functools import partial
from itertools import chain
from numbers import Number
from typing import Any, Callable, Iterable, Iterator, cast, get_type_hints
| class Overload: | |
| """ | |
| ```python | |
| # if bypass is True, Overload will run all covariant overloads | |
| from numbers import Number | |
| overload = Overload(bypass=True) | |
| @overload.overload | |
| def func1(a: Number) -> None: | |
| pass | |
| @overload.overload | |
| def func2(a: int) -> None: | |
| pass | |
| ... | |
| assert tuple(overload(1)) == (func1(1), func2(1)) | |
| assert overload.bypasses == { | |
| (int,): {(int,), (Number,)} | |
| } | |
| # else, Overload will run only the specific overloads | |
| overload = Overload(bypass=False) | |
| @overload.overload | |
| def func1(a: Number) -> None: | |
| pass | |
| @overload.overload | |
| def func2(a: int) -> None: | |
| pass | |
| ... | |
| assert tuple(overload(1)) == (func2(1),) | |
| ``` | |
| """ | |
| overloads: dict[tuple[type, ...], set[Callable[..., Any]]] | |
| bypasses: dict[tuple[type, ...], set[tuple[type, ...]]] | None = None | |
| _cache_token: int | None = None | |
| def __init__(self, bypass: bool = False) -> None: | |
| self.overloads = {} | |
| if bypass: | |
| self.bypasses = {} | |
| def overload( | |
| self, *args: tuple[type, ...] | Callable[..., Any] | |
| ) -> "Overload" | Callable[[Callable[..., Any]], "Overload"]: | |
| annos = cast(tuple[type, ...], args[:-1]) | |
| func = cast(Callable[..., Any], args[-1]) | |
| if not annos: # only one argument | |
| anno = func # if is an annotation | |
| if isinstance(anno, type): | |
| return partial(self.overload, anno) # type: ignore | |
| if not getattr(anno, "__annotations__", {}): | |
| raise TypeError( | |
| f"Invalid first argument to `overload()`: {anno!r}. " | |
| f"Use either `@overload(some_class, ...)` or plain `@overload` " | |
| f"on an annotated function." | |
| ) | |
| func = anno # if is a function | |
| annos = tuple(get_type_hints(func).values())[:-1] | |
| if annos not in self.overloads: | |
| funcs = self.overloads[annos] = cast(set[Callable[..., Any]], set()) | |
| if not self.bypasses: # don't need to clear if is None or empty | |
| self.bypasses.clear() | |
| else: | |
| funcs = self.overloads[annos] | |
| funcs.add(func) | |
| if self._cache_token is None and any( | |
| hasattr(anno, "__abstractmethods__") for anno in annos | |
| ): | |
| self._cache_token = cast(int, get_cache_token()) | |
| return self | |
| def __call__(self, *args: Any) -> tuple[Any]: | |
| annos = tuple(map(type, args)) | |
| if annos in self.overloads: | |
| funcs = iter(self.overloads[annos]) | |
| else: | |
| funcs: iter[Callable[..., Any]] = iter() | |
| if self.bypasses is not None: | |
| if self._cache_token is not None and self._cache_token != ( | |
| current_token := get_cache_token() | |
| ): | |
| self._cache_token = cast(int, current_token) | |
| self.bypasses.clear() | |
| keys = self.bypasses.setdefault( | |
| annos, | |
| { | |
| k | |
| for k in self.overloads | |
| if annos != k | |
| and len(annos) == len(k) | |
| and all(map(issubclass, annos, k)) | |
| }, | |
| ) | |
| funcs = chain(funcs, chain.from_iterable(self.overloads[k] for k in keys)) | |
| return (f(*args) for f in funcs) | |
if __name__ == "__main__":
    # Demo: covariant dispatch that reacts to ABC registration at runtime.
    o = Overload(True)
    overload = o.overload

    @overload
    def func1(a: Number, b: str) -> int:
        print(1)
        print(b)
        return 1

    @overload
    def func2(a: tuple, b: str) -> int:
        print(2)
        print(b)
        return 2

    # tuple is not (yet) a Number, so only the exact (tuple, str) overload runs.
    print(tuple(o((), "test")))
    # NOTE: this mutates the process-wide Number ABC registry; demo only.
    Number.register(tuple)
    # The ABC cache token changed, so the (Number, str) overload now also runs.
    print(tuple(o((), "test")))
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment