diff --git a/pyrefly/lib/alt/class/class_field.rs b/pyrefly/lib/alt/class/class_field.rs index 34b7ad6661..8dd3522318 100644 --- a/pyrefly/lib/alt/class/class_field.rs +++ b/pyrefly/lib/alt/class/class_field.rs @@ -63,6 +63,7 @@ use crate::types::callable::Param; use crate::types::callable::Required; use crate::types::class::Class; use crate::types::class::ClassType; +use crate::types::display::TypeDisplayContext; use crate::types::keywords::DataclassFieldKeywords; use crate::types::literal::Lit; use crate::types::quantified::Quantified; @@ -158,6 +159,51 @@ impl ClassAttribute { } } +#[derive(Clone, Copy)] +enum OperatorCompatibilityKind { + Reflected, + InPlace, +} + +fn reflected_operator_forward_name(name: &Name) -> Option<&'static str> { + match name.as_str() { + "__radd__" => Some("__add__"), + "__rsub__" => Some("__sub__"), + "__rmul__" => Some("__mul__"), + "__rmatmul__" => Some("__matmul__"), + "__rtruediv__" => Some("__truediv__"), + "__rfloordiv__" => Some("__floordiv__"), + "__rmod__" => Some("__mod__"), + "__rdivmod__" => Some("__divmod__"), + "__rpow__" => Some("__pow__"), + "__rlshift__" => Some("__lshift__"), + "__rrshift__" => Some("__rshift__"), + "__rand__" => Some("__and__"), + "__rxor__" => Some("__xor__"), + "__ror__" => Some("__or__"), + _ => None, + } +} + +fn inplace_operator_forward_name(name: &Name) -> Option<&'static str> { + match name.as_str() { + "__iadd__" => Some("__add__"), + "__isub__" => Some("__sub__"), + "__imul__" => Some("__mul__"), + "__imatmul__" => Some("__matmul__"), + "__itruediv__" => Some("__truediv__"), + "__ifloordiv__" => Some("__floordiv__"), + "__imod__" => Some("__mod__"), + "__ipow__" => Some("__pow__"), + "__ilshift__" => Some("__lshift__"), + "__irshift__" => Some("__rshift__"), + "__iand__" => Some("__and__"), + "__ixor__" => Some("__xor__"), + "__ior__" => Some("__or__"), + _ => None, + } +} + #[derive(Debug, Clone, TypeEq, PartialEq, Eq, VisitMut)] pub struct Descriptor { /// The location of the property where the descriptor is bound, where we should raise @@ -1829,6 +1875,150 @@ impl<'a, Ans: LookupAnswer> AnswersSolver<'a, Ans> { true } + fn check_operator_compatibility_for_field( + &self, + cls: &Class, + field_name: &Name, + bases: &ClassBases, + range: TextRange, + errors: &ErrorCollector, + ) { + if let Some(forward) = reflected_operator_forward_name(field_name) { + self.check_operator_compatibility_with_forward( + OperatorCompatibilityKind::Reflected, + forward, + cls, + field_name, + bases, + range, + errors, + ); + } + if let Some(forward) = inplace_operator_forward_name(field_name) { + self.check_operator_compatibility_with_forward( + OperatorCompatibilityKind::InPlace, + forward, + cls, + field_name, + bases, + range, + errors, + ); + } + } + + fn check_operator_compatibility_with_forward( + &self, + kind: OperatorCompatibilityKind, + forward_str: &'static str, + cls: &Class, + field_name: &Name, + bases: &ClassBases, + range: TextRange, + errors: &ErrorCollector, + ) { + if bases.is_empty() { + return; + } + + let forward_name = Name::new_static(forward_str); + let subclass_class_type = self.as_class_type_unchecked(cls); + let subclass_type = Type::ClassType(subclass_class_type.clone()); + + for parent in bases.iter() { + let parent_cls = parent.class_object(); + if self.get_class_member(parent_cls, field_name).is_some() { + continue; + } + + let Some(_) = self.get_class_member(parent_cls, &forward_name) else { + continue; + }; + + let lhs_type = parent.clone().to_type(); + let mut forward_arg_ty = subclass_type.clone(); + let forward_args = [CallArg::ty(&forward_arg_ty, range)]; + let forward_errors = self.error_collector(); + let forward_ret = self.call_magic_dunder_method( + &lhs_type, + &forward_name, + range, + &forward_args, + &[], + &forward_errors, + None, + ); + if !forward_errors.is_empty() { + continue; + } + let Some(forward_ret) = forward_ret else { + continue; + }; + if forward_ret.is_error() { + continue; + } + + let method_receiver_ty = subclass_type.clone(); + let mut method_arg_ty = lhs_type.clone(); + let method_args = [CallArg::ty(&method_arg_ty, range)]; + let method_errors = self.error_collector(); + let method_ret = self.call_magic_dunder_method( + &method_receiver_ty, + field_name, + range, + &method_args, + &[], + &method_errors, + None, + ); + if !method_errors.is_empty() { + continue; + } + let Some(method_ret) = method_ret else { + continue; + }; + if method_ret.is_error() { + continue; + } + + if self.is_subset_eq(&method_ret, &forward_ret) { + continue; + } + + let method_ret_display = self.for_display(method_ret.clone()); + let forward_ret_display = self.for_display(forward_ret.clone()); + let ctx = TypeDisplayContext::new(&[&method_ret_display, &forward_ret_display]); + + let message = match kind { + OperatorCompatibilityKind::Reflected => format!( + "Class member `{}.{}` returns `{}` which is incompatible with the return type `{}` of parent operator `{}.{}` when the reflected operator may be selected", + cls.name().as_str(), + field_name.as_str(), + ctx.display(&method_ret_display), + ctx.display(&forward_ret_display), + parent.name().as_str(), + forward_name.as_str(), + ), + OperatorCompatibilityKind::InPlace => format!( + "Class member `{}.{}` returns `{}` which is incompatible with the return type `{}` of parent operator `{}.{}` used for augmented assignment", + cls.name().as_str(), + field_name.as_str(), + ctx.display(&method_ret_display), + ctx.display(&forward_ret_display), + parent.name().as_str(), + forward_name.as_str(), + ), + }; + + self.error( + errors, + range, + ErrorInfo::Kind(ErrorKind::BadOverride), + message, + ); + } + } + pub fn check_consistent_override_for_field( &self, cls: &Class, @@ -2001,6 +2191,8 @@ impl<'a, Ans: LookupAnswer> AnswersSolver<'a, Ans> { ), ); } + + self.check_operator_compatibility_for_field(cls, field_name, bases, range, errors); } /// For classes with multiple inheritance, check that fields inherited from multiple base classes are consistent. diff --git a/pyrefly/lib/alt/operators.rs b/pyrefly/lib/alt/operators.rs index 6663d5ef61..967e4bf1a4 100644 --- a/pyrefly/lib/alt/operators.rs +++ b/pyrefly/lib/alt/operators.rs @@ -31,6 +31,7 @@ use crate::error::context::ErrorInfo; use crate::error::context::TypeCheckContext; use crate::error::context::TypeCheckKind; use crate::graph::index::Idx; +use crate::types::class::ClassType; use crate::types::literal::Lit; use crate::types::tuple::Tuple; use crate::types::types::Type; @@ -210,6 +211,32 @@ impl<'a, Ans: LookupAnswer> AnswersSolver<'a, Ans> { } } + fn binop_call_type( + &self, + target: &Type, + method_name: &Name, + arg: &Type, + range: TextRange, + ) -> Option { + let errors = self.error_collector(); + let mut arg_ty = arg.clone(); + let call_args = [CallArg::ty(&arg_ty, range)]; + let ret = self.call_magic_dunder_method( + target, + method_name, + range, + &call_args, + &[], + &errors, + None, + ); + if errors.is_empty() { + ret.filter(|ty| !ty.is_error()) + } else { + None + } + } + pub fn binop_infer( &self, x: &ExprBinOp, @@ -224,14 +251,62 @@ impl<'a, Ans: LookupAnswer> AnswersSolver<'a, Ans> { self.for_display(rhs.clone()), ) }; - // Reflected operator implementation: This deviates from the runtime semantics by calling the reflected dunder if the regular dunder call errors. - // At runtime, the reflected dunder is called only if the regular dunder method doesn't exist or if it returns NotImplemented. - // This deviation is necessary, given that the typeshed stubs don't record when NotImplemented is returned - let calls_to_try = [ - (&Name::new_static(op.dunder()), lhs, rhs), - (&Name::new_static(op.reflected_dunder()), rhs, lhs), - ]; - self.try_binop_calls(&calls_to_try, range, errors, &context) + let class_type_of = |ty: &Type| -> Option { + match ty { + Type::ClassType(cls) => Some(cls.clone()), + Type::SelfType(cls) => Some(cls.clone()), + Type::Literal(Lit::Enum(lit_enum)) => Some(lit_enum.class.clone()), + _ => None, + } + }; + let lhs_cls = class_type_of(lhs); + let rhs_cls = class_type_of(rhs); + let strict_rhs_subclass = matches!((&rhs_cls, &lhs_cls), (Some(rhs_cls), Some(lhs_cls)) + if rhs_cls.class_object() != lhs_cls.class_object() + && self.has_superclass(rhs_cls.class_object(), lhs_cls.class_object())); + + let forward_name = Name::new_static(op.dunder()); + let reflected_name = Name::new_static(op.reflected_dunder()); + let calls_to_try = if strict_rhs_subclass { + [(&reflected_name, rhs, lhs), (&forward_name, lhs, rhs)] + } else { + [(&forward_name, lhs, rhs), (&reflected_name, rhs, lhs)] + }; + let mut result = self.try_binop_calls(&calls_to_try, range, errors, &context); + + if !strict_rhs_subclass { + let forward_ret = self.binop_call_type(lhs, &forward_name, rhs, range); + let reflected_ret = self.binop_call_type(rhs, &reflected_name, lhs, range); + let extra = match (forward_ret, reflected_ret) { + (Some(f), Some(r)) => { + if self.is_equal(&f, &r) && self.is_equal(&result, &f) { + None + } else { + Some(self.union(f, r)) + } + } + (Some(f), None) => { + if self.is_equal(&result, &f) { + None + } else { + Some(f) + } + } + (None, Some(r)) => { + if self.is_equal(&result, &r) { + None + } else { + Some(r) + } + } + (None, None) => None, + }; + if let Some(extra) = extra { + result = self.union(result, extra); + } + } + + result }; // If the expression is of the form [X] * Y where Y is a number, pass down the contextual // type hint when evaluating [X] @@ -304,12 +379,56 @@ impl<'a, Ans: LookupAnswer> AnswersSolver<'a, Ans> { self.for_display(rhs.clone()), ) }; - let calls_to_try = [ - (&Name::new_static(op.in_place_dunder()), lhs, rhs), - (&Name::new_static(op.dunder()), lhs, rhs), - (&Name::new_static(op.reflected_dunder()), rhs, lhs), - ]; - self.try_binop_calls(&calls_to_try, range, errors, &context) + let class_type_of = |ty: &Type| -> Option { + match ty { + Type::ClassType(cls) => Some(cls.clone()), + Type::SelfType(cls) => Some(cls.clone()), + Type::Literal(Lit::Enum(lit_enum)) => Some(lit_enum.class.clone()), + _ => None, + } + }; + let lhs_cls = class_type_of(lhs); + let rhs_cls = class_type_of(rhs); + let strict_rhs_subclass = matches!((&rhs_cls, &lhs_cls), (Some(rhs_cls), Some(lhs_cls)) + if rhs_cls.class_object() != lhs_cls.class_object() + && self.has_superclass(rhs_cls.class_object(), lhs_cls.class_object())); + + let inplace_name = Name::new_static(op.in_place_dunder()); + let forward_name = Name::new_static(op.dunder()); + let reflected_name = Name::new_static(op.reflected_dunder()); + + let calls_to_try = if strict_rhs_subclass { + [ + (&inplace_name, lhs, rhs), + (&reflected_name, rhs, lhs), + (&forward_name, lhs, rhs), + ] + } else { + [ + (&inplace_name, lhs, rhs), + (&forward_name, lhs, rhs), + (&reflected_name, rhs, lhs), + ] + }; + let mut result = self.try_binop_calls(&calls_to_try, range, errors, &context); + + let mut union_parts = Vec::new(); + if let Some(ret) = self.binop_call_type(lhs, &inplace_name, rhs, range) { + union_parts.push(ret); + } + if let Some(ret) = self.binop_call_type(lhs, &forward_name, rhs, range) { + union_parts.push(ret); + } + if let Some(ret) = self.binop_call_type(rhs, &reflected_name, lhs, range) { + union_parts.push(ret); + } + if !union_parts.is_empty() { + let extra = self.unions(union_parts); + if !self.is_equal(&result, &extra) { + result = self.union(result, extra); + } + } + result }; let base = self.expr_infer(&x.target, errors); let rhs = self.expr_infer(&x.value, errors); diff --git a/pyrefly/lib/test/operators.rs b/pyrefly/lib/test/operators.rs index 15bf37c856..f22e0d1fdc 100644 --- a/pyrefly/lib/test/operators.rs +++ b/pyrefly/lib/test/operators.rs @@ -413,6 +413,39 @@ assert_type(a, A) "#, ); +testcase!( + test_reflected_operator_incompatible_return, + r#" +from typing import assert_type + +class A: + def __add__(self, other: "A") -> int: + return 1 + +class B(A): + def __radd__(self, other: "A") -> str: # E: Class member `B.__radd__` returns `str` which is incompatible with the return type `int` of parent operator `A.__add__` when the reflected operator may be selected + return "B" + +assert_type(A() + B(), str) + "#, +); + +testcase!( + test_inplace_operator_incompatible_return, + r#" +class A: + def __add__(self, other: "A") -> int: + return 1 + +class B(A): + def __iadd__(self, other: "A") -> str: # E: Class member `B.__iadd__` returns `str` which is incompatible with the return type `int` of parent operator `A.__add__` used for augmented assignment + return "a" + +def f(a: B, b: A) -> None: + a += b # E: Augmented assignment produces a value of type `int | str`, which is not assignable to `B` + "#, +); + // We try __iadd__ and some fallback dunders. When all fail, the least confusing option is to use __iadd__. testcase!( test_iadd_error, @@ -581,6 +614,6 @@ class A: class B: def __radd__(self, other) -> Self: return self -assert_type(A() + B(), A) # E: `A.__add__` is deprecated +assert_type(A() + B(), A | B) # E: `A.__add__` is deprecated "#, );