Created
June 16, 2022 02:36
-
-
Save Wybxc/fcc041987e6cbd61bb8600ff5c7bda99 to your computer and use it in GitHub Desktop.
Schema Type: 尝试用附加 pyi 的方式为动态方法签名
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| from schema_gen import schema_gen | |
def add(x: int, y: int):
    """calculate x + y"""
    # The docstring text is left untouched: the generator reads ``__doc__``
    # and copies it verbatim into the emitted stub.
    # NOTE(review): there is no return annotation here, unlike ``sub``;
    # presumably this exercises the ``typing.Any`` fallback -- confirm.
    total = x + y
    return total
async def sub(x: int, y: int) -> int:
    """calculate x - y"""
    # The docstring text is left untouched: the generator reads ``__doc__``
    # and copies it verbatim into the emitted stub.
    difference = x - y
    return difference
# Names the generated schema should describe: a plain function, a coroutine
# function and a non-callable value.
env = {"add": add, "sub": sub, "value": 42}

if __name__ == "__main__":
    # Write the stub as UTF-8 explicitly: docstrings are copied into it and may
    # contain non-ASCII text, while the platform default encoding may not be
    # able to represent them.
    with open("custom_schema.pyi", "w", encoding="utf-8") as f:
        f.write(schema_gen('CustomSchema', env))
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| from __future__ import annotations | |
| import inspect | |
| from typing import Any, Callable, Coroutine, Dict, List, Type, TypeVar, cast | |
| from typing_extensions import ParamSpec | |
class ArgsCollector:
    """Fluent setter for one keyword argument of a :class:`ProxyObj`.

    Calling the collector stores the value in ``origin.kwargs`` under
    ``name``.  Accessing any other (non-dunder) attribute yields a new
    collector for that attribute name on the same origin, so calls chain as
    ``proxy.sub.x(1).y(2)``.  Awaiting a collector awaits its origin.

    Limitation: ``name`` and ``origin`` are ordinary instance attributes, so
    parameters with those names cannot be reached by chaining from a collector.
    """

    def __init__(self, name: str, origin: ProxyObj):
        self.name = name
        self.origin = origin

    def __call__(self, value: Any, /) -> ArgsCollector:
        """Record ``value`` for this collector's argument and return ``self``."""
        self.origin.kwargs[self.name] = value
        return self

    def __getattr__(self, name: str) -> ArgsCollector:
        """Return a collector for ``name`` that shares this collector's origin.

        Raises:
            AttributeError: for dunder names.  Protocol probes such as
                ``__setstate__`` or ``__wrapped__`` must see "missing" instead
                of a collector; on an uninitialised instance they would
                otherwise recurse forever while looking up ``origin``.
        """
        if name.startswith("__") and name.endswith("__"):
            raise AttributeError(name)
        return ArgsCollector(name, self.origin)

    def __await__(self):
        """Delegate awaiting to the origin."""
        return self.origin.__await__()
class ProxyObj:
    """Deferred call to a coroutine function.

    Positional and keyword arguments are accumulated, either by calling the
    object or through chained :class:`ArgsCollector` setters.  The wrapped
    callable is only invoked when the object is awaited.

    Limitation: ``obj``, ``args`` and ``kwargs`` are ordinary instance
    attributes, so parameters with those names cannot be set through
    attribute chaining.
    """

    def __init__(self, obj: Callable[..., Coroutine[Any, Any, Any]]):
        self.obj = obj
        self.args: List[Any] = []
        self.kwargs: Dict[str, Any] = {}

    def __call__(self, *args: Any, **kwargs: Any) -> ProxyObj:
        """Append ``args`` / merge ``kwargs`` into the pending call; return ``self``."""
        self.args.extend(args)
        self.kwargs.update(kwargs)
        return self

    def __getattr__(self, name: str) -> ArgsCollector:
        """Return a setter for the keyword argument ``name``.

        Raises:
            AttributeError: for dunder names, so that protocol probes
                (``copy``/``pickle`` hooks, ``inspect.unwrap`` looking for
                ``__wrapped__``, ...) see the attribute as missing instead of
                receiving a collector.
        """
        if name.startswith("__") and name.endswith("__"):
            raise AttributeError(name)
        return ArgsCollector(name, self)

    def __await__(self):
        """Invoke the wrapped callable with the collected arguments and await it."""
        return self.obj(*self.args, **self.kwargs).__await__()
Schema = TypeVar("Schema")


class _UnknownNameError(AttributeError, KeyError):
    """A name that is missing from a :class:`Proxy` environment.

    It is an ``AttributeError`` so ``hasattr``/``getattr``/``copy`` behave, and
    also a ``KeyError`` so callers that caught the raw lookup error still work.
    """


class Proxy:
    """Expose the entries of an environment mapping as awaitable attributes."""

    def __init__(self, environment: Dict[str, Any]):
        self.environment = environment

    def __getattr__(self, name: str) -> ProxyObj:
        """Wrap ``environment[name]`` in a :class:`ProxyObj`.

        Coroutine functions are used as they are, plain functions are wrapped
        with ``async_``, and any other value becomes a zero-argument callable
        returning that value.

        Raises:
            AttributeError: (also a ``KeyError``) when ``name`` is not in the
                environment, or when ``environment`` itself is not set yet.
        """
        # Bypass __getattr__ for the lookup of ``environment``: on an instance
        # created without __init__ (e.g. by ``copy``) a plain ``self.environment``
        # would re-enter this method and recurse without end.
        environment = object.__getattribute__(self, "environment")
        try:
            obj = environment[name]
        except KeyError:
            raise _UnknownNameError(
                f"{name!r} is not defined in the proxy environment"
            ) from None
        if inspect.iscoroutinefunction(obj):
            return ProxyObj(obj)
        if inspect.isfunction(obj):
            return ProxyObj(async_(obj))
        return ProxyObj(async_(lambda: obj))

    def with_schema(self, schema: Type[Schema]) -> Schema:
        """Return this same proxy, typed as ``schema`` for static checkers.

        ``schema`` is not used at runtime; the call is only a ``cast``.
        """
        return cast(Schema, self)
| Params = ParamSpec("Params") | |
| Res = TypeVar("Res", covariant=True) | |
| def async_(func: Callable[Params, Res]) -> Callable[Params, Coroutine[Any, Any, Res]]: | |
| async def wrapped(*args: Params.args, **kwargs: Params.kwargs): | |
| return func(*args, **kwargs) | |
| return wrapped |
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| import asyncio | |
| from lib import * | |
| from schema import CustomSchema | |
def add(x: int, y: int) -> int:
    """calculate x + y"""
    # Plain (non-async) function: the proxy has to wrap it before awaiting.
    total = x + y
    return total
async def sub(x: int, y: int) -> int:
    """calculate x - y"""
    # Already a coroutine function, so the proxy can await it directly.
    difference = x - y
    return difference
async def main():
    """Demonstrate the three ways of using a schema-typed proxy."""
    environment = {"add": add, "sub": sub, "value": 42}
    proxy = Proxy(environment).with_schema(CustomSchema)

    # Direct call with positional arguments.
    total = await proxy.add(1, 2)
    print(total)

    # Arguments supplied one by one through chained setters.
    difference = await proxy.sub.x(1).y(2)
    print(difference)

    # A plain value is awaited like everything else.
    print(await proxy.value)


if __name__ == "__main__":
    asyncio.run(main())
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
class CustomSchema:
    # Intentionally empty at runtime: the demo only hands this class to
    # ``Proxy.with_schema`` as a type for static checkers.
    # NOTE(review): the member declarations presumably live in the
    # accompanying ``.pyi`` stub -- confirm the file names match.
    ...
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| import typing | |
# Static description of the attributes available on a schema-typed proxy.
# NOTE(review): the layout matches what ``schema_gen('CustomSchema', env)``
# emits for ``{"add": ..., "sub": ..., "value": 42}``, so this looks like
# generated output -- regenerate rather than hand-edit; confirm.
class CustomSchema:
    # ``add``: awaitable when called directly, or configurable argument by
    # argument through the ``x`` / ``y`` setters.
    class _Proxy_add(typing.Protocol):
        async def __call__(self, x: int, y: int) -> typing.Any: ...
        def x(self, value: int, /) -> CustomSchema._Proxy_add_collector: ...
        def y(self, value: int, /) -> CustomSchema._Proxy_add_collector: ...
    # Result of a setter call: awaitable, and chainable with further setters.
    class _Proxy_add_collector(typing.Protocol):
        def __await__(self) -> typing.Generator[typing.Any, typing.Any, typing.Any]: ...
        def x(self, value: int, /) -> CustomSchema._Proxy_add_collector: ...
        def y(self, value: int, /) -> CustomSchema._Proxy_add_collector: ...
    add: _Proxy_add
    """calculate x + y"""
    # ``sub``: same shape as ``add``, with ``int`` as the awaited result type.
    class _Proxy_sub(typing.Protocol):
        async def __call__(self, x: int, y: int) -> int: ...
        def x(self, value: int, /) -> CustomSchema._Proxy_sub_collector: ...
        def y(self, value: int, /) -> CustomSchema._Proxy_sub_collector: ...
    class _Proxy_sub_collector(typing.Protocol):
        def __await__(self) -> typing.Generator[typing.Any, typing.Any, int]: ...
        def x(self, value: int, /) -> CustomSchema._Proxy_sub_collector: ...
        def y(self, value: int, /) -> CustomSchema._Proxy_sub_collector: ...
    sub: _Proxy_sub
    """calculate x - y"""
    # Non-callable entry: exposed as an awaitable of the value's type.
    value: typing.Awaitable[int]
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| import inspect | |
| from typing import Any, Dict, List | |
| import textwrap | |
def schema_gen(name: str, environment: Dict[str, Any]) -> str:
    """Return the source of a ``.pyi`` stub class describing ``environment``.

    Args:
        name: Name of the generated class.
        environment: Mapping of attribute name to value.  Functions (as
            recognised by ``inspect.isfunction``) get a call protocol from
            ``schema_gen_function``; every other value is declared as
            ``typing.Awaitable[<type name>]``.

    Returns:
        The stub text, starting with ``import typing``.  An empty environment
        yields a class whose body is ``...`` so the stub stays valid.
    """
    indent = " " * 4
    members: List[str] = []
    for key, value in environment.items():
        if inspect.isfunction(value):
            block = schema_gen_function(
                name, key, inspect.signature(value), value.__doc__ or ""
            )
            members.append("\n" + textwrap.indent(block, indent) + "\n")
        else:
            # ``NoneType`` is not a usable name inside a stub; spell it ``None``.
            type_name = "None" if value is None else type(value).__name__
            # Same 4-space indent as the function blocks, otherwise a stub that
            # mixes both kinds of member has inconsistent indentation.
            members.append(f"\n{indent}{key}: typing.Awaitable[{type_name}]\n")
    if not members:
        # A class statement needs a body.
        members.append(f"\n{indent}...\n")
    return f"import typing\nclass {name}:" + "".join(members)
def schema_gen_function(base: str, name: str, signature: inspect.Signature, doc: str):
    """Return the stub fragment describing one proxied function.

    The fragment is un-indented and consists of a ``_Proxy_<name>`` protocol
    (awaitable ``__call__`` plus one setter per parameter), a
    ``_Proxy_<name>_collector`` protocol (``__await__`` plus the same setters),
    the attribute declaration ``<name>: _Proxy_<name>`` and the docstring.

    Args:
        base: Name of the enclosing schema class; used to qualify the
            collector type in setter return annotations.
        name: Attribute name of the function.
        signature: Signature of the function.
        doc: Docstring to attach to the attribute (may be empty).

    Raises:
        ValueError: if the function has a positional-only parameter.
    """
    indent = " " * 4
    collector_type = f"{base}._Proxy_{name}_collector"
    variadic = (inspect.Parameter.VAR_POSITIONAL, inspect.Parameter.VAR_KEYWORD)

    setters: List[str] = []
    for arg, param in signature.parameters.items():
        if param.kind == inspect.Parameter.POSITIONAL_ONLY:
            raise ValueError(f"{arg} is positional only")
        if param.kind in variadic:
            # ``*args`` / ``**kwargs`` are not single named arguments, so a
            # per-name setter for them would be meaningless.
            continue
        setters.append(
            f"def {arg}(self, value: {format_type(param.annotation)}, /) -> {collector_type}: ..."
        )
    setters_block = "\n" + textwrap.indent("\n".join(setters), prefix=indent) + "\n"

    return_type = format_type(signature.return_annotation)
    # Let ``inspect`` render the parameter list so that ``*`` (keyword-only
    # marker), ``*args`` and ``**kwargs`` appear exactly as in the original.
    bare_signature = signature.replace(return_annotation=inspect.Signature.empty)
    call_params = ", ".join(filter(None, ("self", str(bare_signature)[1:-1])))

    return "".join(
        [
            f"class _Proxy_{name}(typing.Protocol):\n",
            f"{indent}async def __call__({call_params}) -> {return_type}: ...",
            setters_block,
            f"class _Proxy_{name}_collector(typing.Protocol):\n",
            f"{indent}def __await__(self) -> typing.Generator[typing.Any, typing.Any, {return_type}]: ...",
            setters_block,
            f"{name}: _Proxy_{name}",
            "\n" + _quote_docstring(doc),
        ]
    )


def _quote_docstring(doc: str) -> str:
    """Render ``doc`` as a triple-quoted literal that is safe inside generated source."""
    escaped = doc.replace("\\", "\\\\").replace('"""', '\\"\\"\\"')
    if escaped.endswith('"'):
        # A trailing quote would otherwise merge with the closing delimiter.
        escaped = escaped[:-1] + '\\"'
    return f'"""{escaped}"""'
def format_type(type_: Any) -> str:
    """Render an annotation as text for the generated stub.

    * a missing annotation (``inspect.Signature.empty``) -> ``"typing.Any"``
    * ``None`` or ``NoneType`` -> ``"None"``
    * a plain class -> its ``__name__``
    * anything else (typing constructs, parameterized generics, string
      annotations) -> ``str(type_)``
    """
    # Public spelling of the sentinel; ``inspect.Parameter.empty`` is the same object.
    if type_ is inspect.Signature.empty:
        return "typing.Any"
    if type_ is None or type_ is type(None):
        # ``NoneType`` is not a name that can be used inside a stub.
        return "None"
    # On Python 3.9/3.10 ``isinstance(list[int], type)`` is true; the
    # ``__origin__`` check keeps such aliases from collapsing to ``"list"``.
    if isinstance(type_, type) and getattr(type_, "__origin__", None) is None:
        return type_.__name__
    return str(type_)
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment