Skip to content
Draft
Show file tree
Hide file tree
Changes from 5 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
50 changes: 40 additions & 10 deletions src/plum/_function.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@
from copy import copy
from functools import wraps
from types import MethodType
from typing import Any, Protocol, TypeVar
from typing import Any, Protocol, TypeVar, overload
from typing_extensions import Self

from ._method import Method, MethodList
Expand Down Expand Up @@ -454,11 +454,17 @@ def wrapped_method(*args: Any, **kw: Any) -> Any:

return wrapped_method

def __get__(self, instance: object, owner: type, /) -> "Function | MethodType":
if instance is not None:
return MethodType(_BoundFunction(self, instance), instance)
else:
@overload
def __get__(self, instance: None, owner: type, /) -> "Function": ...
@overload
def __get__(self, instance: object, owner: type, /) -> MethodType: ...

def __get__(
self, instance: object | None, owner: type, /
) -> "Function | MethodType":
if instance is None:
return self
return MethodType(_BoundFunction(self, instance), instance)

def __repr__(self) -> str:
return (
Expand Down Expand Up @@ -499,6 +505,30 @@ def __call__(
) -> Self | Callable[[Callable[..., Any]], Self]: ...


class _BoundFunctionProto(Protocol):
"""Subset of :class:`Function`'s interface required by :class:`_BoundFunction`.

Declaring ``_BoundFunction._f`` with this Protocol rather than :class:`Function`
directly prevents mypy from applying ``Function.__get__``'s descriptor protocol
when resolving instance-attribute accesses of ``_f``.
Comment thread
nstarman marked this conversation as resolved.
Outdated
"""

_f: Callable[..., Any]

def __call__(self, *args: object, **kw: object) -> object: ...

def invoke(self, *types: TypeHint) -> Callable[..., Any]: ...

@property
def methods(self) -> MethodList: ...

def dispatch(
self,
method: Callable[..., Any] | None = None,
precedence: int = 0,
) -> Any: ...


class _BoundFunction:
"""A bound instance of `.function.Function`.

Expand All @@ -507,7 +537,7 @@ class _BoundFunction:
instance (object): Instance to which the function is bound.
"""

_f: "Function"
_f: "_BoundFunctionProto"
Comment thread
wesselb marked this conversation as resolved.
_instance: object

def __init__(self, f: "Function", instance: object) -> None:
Expand All @@ -532,10 +562,10 @@ def __call__(self, _: object, *args: object, **kw: object) -> object:
def invoke(self, *types: TypeHint) -> Callable[..., Any]:
"""See :meth:`.Function.invoke`."""

@wraps(self._f._f) # type: ignore[union-attr]
@wraps(self._f._f)
def wrapped_method(*args: Any, **kw: Any) -> Any:
# TODO: Can we do this without `type` here?
method = self._f.invoke(type(self._instance), *types) # type: ignore[union-attr]
method = self._f.invoke(type(self._instance), *types)
return method(self._instance, *args, **kw)

# We set `f.__wrapped_by_plum__` for :func:`Function.invoke`, but here
Expand All @@ -549,9 +579,9 @@ def wrapped_method(*args: Any, **kw: Any) -> Any:
@property
def methods(self) -> MethodList:
"""list[:class:`.method.Method`]: All available methods."""
return self._f.methods # type: ignore[union-attr]
return self._f.methods

@property
def dispatch(self) -> _DispatchFunction:
"""See :meth:`.Function.dispatch`."""
return self._f.dispatch # type: ignore[union-attr, return-value]
return self._f.dispatch
94 changes: 50 additions & 44 deletions src/plum/_parametric.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,8 @@
)

import contextlib
from typing import TypeVar, final
from collections.abc import Callable
from typing import Any, TypeVar, cast, final
from typing_extensions import deprecated

import beartype.door
Expand All @@ -35,7 +36,7 @@ class ParametricTypeMeta(type):
`Type[type(Arg1), type(Arg2)](Arg1, Arg2, **kw_args)`.
"""

def __getitem__(cls, p):
def __getitem__(cls, p: TypeHint | tuple[TypeHint, ...]) -> type:
if not cls.concrete:
# Initialise the type parameters. This can perform, e.g., validation.
p = p if isinstance(p, tuple) else (p,) # Ensure that it is a tuple.
Expand All @@ -46,7 +47,7 @@ def __getitem__(cls, p):
else:
raise TypeError("Cannot specify type parameters. This type is concrete.")

def __concrete_class__(cls, *args, **kw_args):
def __concrete_class__(cls, *args: object, **kw_args: object) -> type:
"""If `cls` is not a concrete class, infer the type parameters and return a
concrete class. If `cls` is already a concrete class, simply return it.

Expand All @@ -62,7 +63,7 @@ def __concrete_class__(cls, *args, **kw_args):
cls = cls[type_parameter]
return cls

def __init_type_parameter__(cls, *ps):
def __init_type_parameter__(cls, *ps: TypeHint) -> TypeHint | tuple[TypeHint, ...]:
"""Function called to initialise the type parameters.

The default behaviour is to just return `ps`.
Expand All @@ -75,7 +76,9 @@ def __init_type_parameter__(cls, *ps):
"""
return ps

def __infer_type_parameter__(cls, *args, **kw_args):
def __infer_type_parameter__(
cls, *args: object, **kw_args: object
) -> type | tuple[type, ...]:
"""Function called when the constructor of this parametric type is called
before the parameters have been specified.

Expand All @@ -96,29 +99,27 @@ def __infer_type_parameter__(cls, *args, **kw_args):
return type_parameter

@property
def parametric(cls):
def parametric(cls) -> bool:
"""bool: Check whether the type is a parametric type."""
return getattr(cls, "_parametric", False)

@property
def concrete(cls):
def concrete(cls) -> bool:
"""bool: Check whether the parametric type is instantiated or not."""
if cls.parametric:
return getattr(cls, "_concrete", False)
else:
if not cls.parametric:
raise RuntimeError(
"Cannot check whether a non-parametric type is instantiated or not."
)
return getattr(cls, "_concrete", False)

@property
def type_parameter(cls):
def type_parameter(cls) -> object:
"""object: Get the type parameter. Parametric type must be instantiated."""
if cls.concrete:
return cls._type_parameter
else:
if not cls.concrete:
raise RuntimeError(
"Cannot get the type parameter of non-instantiated parametric type."
)
return cls._type_parameter


def _default_le_type_par(p_left: TypeHint | object, p_right: TypeHint | object) -> bool:
Expand All @@ -133,7 +134,7 @@ def _default_le_type_par(p_left: TypeHint | object, p_right: TypeHint | object)
class CovariantMeta(ParametricTypeMeta):
"""A metaclass that implements *covariance* of parametric types."""

def __subclasscheck__(cls, subclass):
def __subclasscheck__(cls, subclass: type) -> bool:
# Check that they are instances of the same parametric type.
if (
is_concrete(cls)
Expand All @@ -150,7 +151,7 @@ def __subclasscheck__(cls, subclass):
# Default behaviour to `type`s subclass check.
return type.__subclasscheck__(cls, subclass)

def __instancecheck__(cls, instance):
def __instancecheck__(cls, instance: object) -> bool:
# If `A` is a parametric type, then `A[T1]` and `A[T2]` are subclasses of
# `A`. With the implementation of `__subclasscheck__` above, we have that
# `issubclass(A[T1], A[T2])` whenever `issubclass(T1, T2)`. _However_,
Expand All @@ -162,7 +163,9 @@ def __instancecheck__(cls, instance):
# since it is fast and only gives true positives.
return type.__instancecheck__(cls, instance) or issubclass(type(instance), cls)

def __le_type_parameter__(cls, p_left, p_right):
def __le_type_parameter__(
cls, p_left: tuple[object, ...], p_right: tuple[object, ...]
) -> bool:
# Check that there are an equal number of parameters.
if len(p_left) != len(p_right):
return False
Expand All @@ -172,7 +175,7 @@ def __le_type_parameter__(cls, p_left, p_right):
)


def parametric(original_class=None):
def parametric(original_class: type, /) -> type:
"""A decorator for parametric classes.

When the constructor of this parametric type is called before the type parameter
Expand Down Expand Up @@ -200,37 +203,38 @@ def __le_type_parameter__(cls, left, right) -> bool:
...

"""

original_meta = type(original_class)

# Make a metaclass that derives from both the metaclass of `original_meta` and
# `CovariantMeta`, but make sure not to insert `CovariantMeta` twice, because that
# will error.

if CovariantMeta in original_meta.__mro__:
if CovariantMeta in cast(tuple[type, ...], original_meta.__mro__):
bases = (original_meta,)
name = original_meta.__name__
else:
bases = (CovariantMeta, original_meta)
name = f"CovariantMeta[{repr_short(original_meta)}]"

def __call__(cls, *args, **kw_args):
def __call__(cls: type, *args: object, **kw_args: object) -> object:
cls = cls.__concrete_class__(*args, **kw_args)
return original_meta.__call__(cls, *args, **kw_args)

def __instancecheck__(cls, instance):
def __instancecheck__(cls: type, instance: object) -> bool:
# An implementation of `__instancecheck__` is necessary to ensure that
# `isinstance(A[SubType](), A[Type])`. `CovariantMeta` comes first in the MRO,
# but the implementation of `__instancecheck__` should be taken from
# `original_meta` if it exists. The implementation of `CovariantMeta` should be
# used as a fallback. Note that `original_meta.__instancecheck__` always exists.
# We check that it is not equal to the default `type.__instancecheck__`.
if original_meta.__instancecheck__ != type.__instancecheck__:
return original_meta.__instancecheck__(cls, instance)
return cast(
Callable[[type, object], bool], original_meta.__instancecheck__
)(cls, instance)
Comment thread
nstarman marked this conversation as resolved.
Outdated
else:
return CovariantMeta.__instancecheck__(cls, instance)

meta = type(
meta: Any = type(
name,
bases,
{
Expand All @@ -239,14 +243,15 @@ def __instancecheck__(cls, instance):
},
)

subclasses = {}
subclasses: dict[tuple[object, ...], type] = {}

def __new__(cls, *ps):
def __new__(cls: type, *ps: object) -> type:
# Only create a new subclass if it doesn't exist already.
if ps not in subclasses:

def __new__(cls, *args, **kw_args):
return original_class.__new__(cls)
def __new__(cls: type, *args: object, **kw_args: object) -> object:
_new: Any = original_class.__new__
return _new(cls)

# Create subclass.
name = original_class.__name__
Expand All @@ -268,20 +273,21 @@ def __new__(cls, *args, **kw_args):
subclasses[ps] = subclass
return subclasses[ps]

def __init_subclass__(cls, **kw_args):
def __init_subclass__(cls: type, **kw_args: object) -> None:
cls._parametric = False
# If the subclass has the same `__new__` as `ParametricClass`, then we should
# replace it with the `__new__` of `Class`. If the user already defined another
# `__new__`, then everything is fine.
if cls.__new__ is __new__:

def class_new(cls, *args, **kw_args):
return original_class.__new__(cls)
def class_new(cls: type, *args: object, **kw_args: object) -> object:
_new: Any = original_class.__new__
return _new(cls)

cls.__new__ = class_new
super(original_class, cls).__init_subclass__(**kw_args)

def __class_nonparametric__(cls):
def __class_nonparametric__(cls: type) -> type:
"""Return the non-parametric type of an object.

:mod:`plum.parametric` produces parametric subtypes of classes. This
Expand Down Expand Up @@ -333,7 +339,7 @@ def __class_nonparametric__(cls):
"""
return original_class

def __class_unparametrized__(cls):
def __class_unparametrized__(cls: type) -> type:
"""Return the unparametrized type of an object.

:mod:`plum.parametric` produces parametric subtypes of classes. This
Expand Down Expand Up @@ -418,7 +424,7 @@ def __class_unparametrized__(cls):
return parametric_class


def is_concrete(t):
def is_concrete(t: object) -> bool:
"""Check if a type `t` is a concrete instance of a parametric type.

Args:
Expand Down Expand Up @@ -603,25 +609,25 @@ def type_unparametrized(q: T, /) -> type[T]:
parameter(s).
"""
typ = type(q)
return q.__class_unparametrized__() if isinstance(typ, ParametricTypeMeta) else typ
return q.__class_unparametrized__() if isinstance(typ, ParametricTypeMeta) else typ # type: ignore[redundant-expr]
Comment thread
wesselb marked this conversation as resolved.


def kind(SuperClass=object):
def kind(cls: type = object, /) -> type:
"""Create a parametric wrapper type for dispatch purposes.

Args:
SuperClass (type): Super class.
cls (type): Super class.

Returns:
object: New parametric type wrapper.
"""

@parametric
class Kind(SuperClass):
def __init__(self, *xs):
class Kind(cls):
def __init__(self, *xs: object) -> None:
self.xs = xs

def get(self):
def get(self) -> object:
return self.xs[0] if len(self.xs) == 1 else self.xs

return Kind
Expand All @@ -645,7 +651,7 @@ class Val:
"""

@classmethod
def __infer_type_parameter__(cls, *arg):
def __infer_type_parameter__(cls, *arg: Any) -> object:
"""Function called when the constructor of `Val` is called to determine the type
parameters."""
if len(arg) == 0:
Expand All @@ -654,14 +660,14 @@ def __infer_type_parameter__(cls, *arg):
raise ValueError("Too many values. `Val` accepts only one argument.")
return arg[0]

def __init__(self, val=None):
def __init__(self, val: Any = None) -> None:
"""Construct a value object with type `Val(arg)` that can be used to dispatch
based on values.

Args:
val (object): The value to be moved to the type domain.
"""
if type(self).concrete:
if type(self).concrete: # type: ignore[attr-defined]
if val is not None and type_parameter(self) != val:
raise ValueError("The value must be equal to the type parameter.")
else:
Expand All @@ -670,5 +676,5 @@ def __init__(self, val=None):
def __repr__(self) -> str:
return repr_short(type(self)).replace("._parametric", "") + "()"

def __eq__(self, other):
def __eq__(self, other: object) -> bool:
return type(self) is type(other)
Loading
Loading