Skip to content

Instantly share code, notes, and snippets.

@ProgramRipper
Last active March 20, 2022 01:18
Show Gist options
  • Select an option

  • Save ProgramRipper/a87be486bc86d2c70c3775fe83352530 to your computer and use it in GitHub Desktop.

Select an option

Save ProgramRipper/a87be486bc86d2c70c3775fe83352530 to your computer and use it in GitHub Desktop.
from __future__ import annotations

from abc import get_cache_token
from functools import partial
from itertools import chain
from numbers import Number
from typing import Any, Callable, Iterator, cast, get_type_hints
class Overload:
    """Dispatch a call to every function registered for the argument types.

    Functions are registered with :meth:`overload` under a key that is the
    tuple of their positional parameter types. Calling the instance builds
    the key ``tuple(map(type, args))`` and returns a lazy iterator over the
    results of every function registered under exactly that key (functions
    sharing one key run in unspecified order).

    With ``bypass=True`` a call also runs every overload whose key has the
    same length and where each argument type is a subclass of the
    corresponding key type ("covariant" overloads). Those matching keys are
    cached per argument-type tuple in :attr:`bypasses`. The cache is cleared
    whenever a new key is registered. It is also cleared when
    :func:`abc.get_cache_token` changes (e.g. after ``SomeABC.register(cls)``),
    provided some registered key contains an ABC.

    ```python
    from numbers import Number

    # if bypass is True, Overload also runs all covariant overloads
    overload = Overload(bypass=True)

    @overload.overload
    def func1(a: Number) -> str:
        return "number"

    @overload.overload
    def func2(a: int) -> str:
        return "int"

    # NOTE: decorating returns the Overload itself, so func1 and func2 now
    # name `overload`, not the original functions.
    assert tuple(overload(1)) == ("int", "number")  # exact match first
    assert overload.bypasses == {(int,): {(Number,)}}

    # else, Overload runs only the overloads registered for the exact types
    overload = Overload(bypass=False)

    @overload.overload
    def func1(a: Number) -> str:
        return "number"

    @overload.overload
    def func2(a: int) -> str:
        return "int"

    assert tuple(overload(1)) == ("int",)
    ```
    """

    # key (parameter types) -> functions registered under that key
    overloads: dict[tuple[type, ...], set[Callable[..., Any]]]
    # call-time argument types -> covariant keys; None disables bypass mode
    bypasses: dict[tuple[type, ...], set[tuple[type, ...]]] | None = None
    # ABC cache token seen when bypasses were last valid; None = no ABC keys
    _cache_token: int | None = None

    def __init__(self, bypass: bool = False) -> None:
        """Create an empty registry; ``bypass`` enables covariant dispatch."""
        self.overloads = {}
        if bypass:
            self.bypasses = {}

    @staticmethod
    def _issubclass(cls: type, candidate: Any) -> bool:
        """Return ``issubclass(cls, candidate)``, treating non-classes as no match."""
        try:
            return issubclass(cls, candidate)
        except TypeError:
            # e.g. typing.Any or parameterized generics such as list[int],
            # which can never be matched exactly either.
            return False

    def overload(
        self, *args: type | Callable[..., Any]
    ) -> Overload | Callable[[Callable[..., Any]], Overload]:
        """Register a function, optionally under explicitly given types.

        Accepted forms:

        * ``overload(func)`` / ``@overload``: the key comes from ``func``'s
          resolved type hints, excluding the ``return`` hint.
        * ``overload(t1, ..., func)``: the key is ``(t1, ...)``.
        * ``overload(t1, ...)`` / ``@overload(t1, ...)``: every argument is
          a class. This form returns a decorator that registers the function
          it receives under ``(t1, ...)``.

        Returns ``self`` after registering, so decorated names are rebound to
        this Overload. In the last form it returns a
        :func:`functools.partial` decorator instead.

        Raises:
            TypeError: if called with no arguments, or if a lone non-class
                argument has no annotations.
        """
        if not args:
            raise TypeError("overload() expected at least one argument")
        annos = cast(tuple[type, ...], args[:-1])
        func = cast(Callable[..., Any], args[-1])
        if isinstance(func, type):
            # Only classes so far: wait for the function to be decorated.
            return partial(self.overload, *args)  # type: ignore
        if not annos:
            if not getattr(func, "__annotations__", {}):
                raise TypeError(
                    f"Invalid first argument to `overload()`: {func!r}. "
                    f"Use either `@overload(some_class, ...)` or plain `@overload` "
                    f"on an annotated function."
                )
            hints = get_type_hints(func)
            # The return hint is not part of the key. Removing it by name
            # (rather than dropping the last hint) keeps the last parameter
            # when no return type is annotated.
            hints.pop("return", None)
            annos = tuple(hints.values())
        funcs = self.overloads.get(annos)
        if funcs is None:
            funcs = self.overloads[annos] = set()
            # A new key may be a covariant match for argument types already
            # cached. Nothing to clear if bypass is disabled (None) or the
            # cache is empty.
            if self.bypasses:
                self.bypasses.clear()
        funcs.add(func)
        if self._cache_token is None and any(
            hasattr(anno, "__abstractmethods__") for anno in annos
        ):
            # ABC registrations can change issubclass() results later;
            # remember the token so __call__ can detect that.
            self._cache_token = cast(int, get_cache_token())
        return self

    def __call__(self, *args: Any) -> Iterator[Any]:
        """Lazily call every overload matching the runtime types of ``args``.

        Functions registered for exactly ``tuple(map(type, args))`` come
        first. When bypass is enabled, they are followed by the functions of
        every covariant key. The set of functions is fixed when this method
        is called, but each function runs only as the returned iterator is
        consumed. The iterator is empty when nothing matches.
        """
        annos = tuple(map(type, args))
        # Snapshot the registered sets so that registering more overloads
        # before the iterator is consumed cannot break iteration.
        funcs: list[Callable[..., Any]] = list(self.overloads.get(annos, ()))
        bypasses = self.bypasses
        if bypasses is not None:
            if self._cache_token is not None:
                current_token = get_cache_token()
                if current_token != self._cache_token:
                    # An ABC gained a virtual subclass; cached matches may
                    # be stale.
                    self._cache_token = cast(int, current_token)
                    bypasses.clear()
            keys = bypasses.get(annos)
            if keys is None:
                # Compute only on a cache miss.
                keys = bypasses[annos] = {
                    key
                    for key in self.overloads
                    if key != annos
                    and len(key) == len(annos)
                    and all(map(self._issubclass, annos, key))
                }
            funcs.extend(chain.from_iterable(self.overloads[key] for key in keys))
        return (f(*args) for f in funcs)
if __name__ == "__main__":
    # Demo: in bypass mode a covariant overload is picked up once tuple
    # becomes a virtual subclass of numbers.Number.
    o = Overload(bypass=True)
    overload = o.overload

    # Each decorator returns `o`, so func1/func2 end up naming `o` itself.
    @overload
    def func1(a: Number, b: str) -> int:
        print(1)
        print(b)
        return 1

    @overload
    def func2(a: tuple, b: str) -> int:
        print(2)
        print(b)
        return 2

    # Only the exact (tuple, str) overload matches: prints 2, test, (2,)
    print(tuple(o((), "test")))
    # Process-wide side effect: tuple becomes a virtual subclass of Number,
    # which also changes abc.get_cache_token().
    Number.register(tuple)
    # Exact match first, then the (Number, str) bypass:
    # prints 2, test, 1, test, (2, 1)
    print(tuple(o((), "test")))
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment