Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
192 changes: 192 additions & 0 deletions pyrefly/lib/alt/class/class_field.rs
Original file line number Diff line number Diff line change
Expand Up @@ -63,6 +63,7 @@ use crate::types::callable::Param;
use crate::types::callable::Required;
use crate::types::class::Class;
use crate::types::class::ClassType;
use crate::types::display::TypeDisplayContext;
use crate::types::keywords::DataclassFieldKeywords;
use crate::types::literal::Lit;
use crate::types::quantified::Quantified;
Expand Down Expand Up @@ -158,6 +159,51 @@ impl ClassAttribute {
}
}

#[derive(Clone, Copy)]
enum OperatorCompatibilityKind {
Reflected,
InPlace,
}

fn reflected_operator_forward_name(name: &Name) -> Option<&'static str> {
match name.as_str() {
"__radd__" => Some("__add__"),
"__rsub__" => Some("__sub__"),
"__rmul__" => Some("__mul__"),
"__rmatmul__" => Some("__matmul__"),
"__rtruediv__" => Some("__truediv__"),
"__rfloordiv__" => Some("__floordiv__"),
"__rmod__" => Some("__mod__"),
"__rdivmod__" => Some("__divmod__"),
"__rpow__" => Some("__pow__"),
"__rlshift__" => Some("__lshift__"),
"__rrshift__" => Some("__rshift__"),
"__rand__" => Some("__and__"),
"__rxor__" => Some("__xor__"),
"__ror__" => Some("__or__"),
_ => None,
}
}

fn inplace_operator_forward_name(name: &Name) -> Option<&'static str> {
match name.as_str() {
"__iadd__" => Some("__add__"),
"__isub__" => Some("__sub__"),
"__imul__" => Some("__mul__"),
"__imatmul__" => Some("__matmul__"),
"__itruediv__" => Some("__truediv__"),
"__ifloordiv__" => Some("__floordiv__"),
"__imod__" => Some("__mod__"),
"__ipow__" => Some("__pow__"),
"__ilshift__" => Some("__lshift__"),
"__irshift__" => Some("__rshift__"),
"__iand__" => Some("__and__"),
"__ixor__" => Some("__xor__"),
"__ior__" => Some("__or__"),
_ => None,
}
}

#[derive(Debug, Clone, TypeEq, PartialEq, Eq, VisitMut)]
pub struct Descriptor {
/// The location of the property where the descriptor is bound, where we should raise
Expand Down Expand Up @@ -1829,6 +1875,150 @@ impl<'a, Ans: LookupAnswer> AnswersSolver<'a, Ans> {
true
}

fn check_operator_compatibility_for_field(
&self,
cls: &Class,
field_name: &Name,
bases: &ClassBases,
range: TextRange,
errors: &ErrorCollector,
) {
if let Some(forward) = reflected_operator_forward_name(field_name) {
self.check_operator_compatibility_with_forward(
OperatorCompatibilityKind::Reflected,
forward,
cls,
field_name,
bases,
range,
errors,
);
}
if let Some(forward) = inplace_operator_forward_name(field_name) {
self.check_operator_compatibility_with_forward(
OperatorCompatibilityKind::InPlace,
forward,
cls,
field_name,
bases,
range,
errors,
);
}
}

fn check_operator_compatibility_with_forward(
&self,
kind: OperatorCompatibilityKind,
forward_str: &'static str,
cls: &Class,
field_name: &Name,
bases: &ClassBases,
range: TextRange,
errors: &ErrorCollector,
) {
if bases.is_empty() {
return;
}

let forward_name = Name::new_static(forward_str);
let subclass_class_type = self.as_class_type_unchecked(cls);
let subclass_type = Type::ClassType(subclass_class_type.clone());

for parent in bases.iter() {
let parent_cls = parent.class_object();
if self.get_class_member(parent_cls, field_name).is_some() {
continue;
}

let Some(_) = self.get_class_member(parent_cls, &forward_name) else {
continue;
};

let lhs_type = parent.clone().to_type();
let mut forward_arg_ty = subclass_type.clone();
let forward_args = [CallArg::ty(&forward_arg_ty, range)];
let forward_errors = self.error_collector();
let forward_ret = self.call_magic_dunder_method(
&lhs_type,
&forward_name,
range,
&forward_args,
&[],
&forward_errors,
None,
);
if !forward_errors.is_empty() {
continue;
}
let Some(forward_ret) = forward_ret else {
continue;
};
if forward_ret.is_error() {
continue;
}

let method_receiver_ty = subclass_type.clone();
let mut method_arg_ty = lhs_type.clone();
let method_args = [CallArg::ty(&method_arg_ty, range)];
let method_errors = self.error_collector();
let method_ret = self.call_magic_dunder_method(
&method_receiver_ty,
field_name,
range,
&method_args,
&[],
&method_errors,
None,
);
if !method_errors.is_empty() {
continue;
}
let Some(method_ret) = method_ret else {
continue;
};
if method_ret.is_error() {
continue;
}

if self.is_subset_eq(&method_ret, &forward_ret) {
continue;
}

let method_ret_display = self.for_display(method_ret.clone());
let forward_ret_display = self.for_display(forward_ret.clone());
let ctx = TypeDisplayContext::new(&[&method_ret_display, &forward_ret_display]);

let message = match kind {
OperatorCompatibilityKind::Reflected => format!(
"Class member `{}.{}` returns `{}` which is incompatible with the return type `{}` of parent operator `{}.{}` when the reflected operator may be selected",
cls.name().as_str(),
field_name.as_str(),
ctx.display(&method_ret_display),
ctx.display(&forward_ret_display),
parent.name().as_str(),
forward_name.as_str(),
),
OperatorCompatibilityKind::InPlace => format!(
"Class member `{}.{}` returns `{}` which is incompatible with the return type `{}` of parent operator `{}.{}` used for augmented assignment",
cls.name().as_str(),
field_name.as_str(),
ctx.display(&method_ret_display),
ctx.display(&forward_ret_display),
parent.name().as_str(),
forward_name.as_str(),
),
};

self.error(
errors,
range,
ErrorInfo::Kind(ErrorKind::BadOverride),
message,
);
}
}

pub fn check_consistent_override_for_field(
&self,
cls: &Class,
Expand Down Expand Up @@ -2001,6 +2191,8 @@ impl<'a, Ans: LookupAnswer> AnswersSolver<'a, Ans> {
),
);
}

self.check_operator_compatibility_for_field(cls, field_name, bases, range, errors);
}

/// For classes with multiple inheritance, check that fields inherited from multiple base classes are consistent.
Expand Down
147 changes: 133 additions & 14 deletions pyrefly/lib/alt/operators.rs
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,7 @@ use crate::error::context::ErrorInfo;
use crate::error::context::TypeCheckContext;
use crate::error::context::TypeCheckKind;
use crate::graph::index::Idx;
use crate::types::class::ClassType;
use crate::types::literal::Lit;
use crate::types::tuple::Tuple;
use crate::types::types::Type;
Expand Down Expand Up @@ -210,6 +211,32 @@ impl<'a, Ans: LookupAnswer> AnswersSolver<'a, Ans> {
}
}

fn binop_call_type(
&self,
target: &Type,
method_name: &Name,
arg: &Type,
range: TextRange,
) -> Option<Type> {
let errors = self.error_collector();
let mut arg_ty = arg.clone();
let call_args = [CallArg::ty(&arg_ty, range)];
let ret = self.call_magic_dunder_method(
target,
method_name,
range,
&call_args,
&[],
&errors,
None,
);
if errors.is_empty() {
ret.filter(|ty| !ty.is_error())
} else {
None
}
}

pub fn binop_infer(
&self,
x: &ExprBinOp,
Expand All @@ -224,14 +251,62 @@ impl<'a, Ans: LookupAnswer> AnswersSolver<'a, Ans> {
self.for_display(rhs.clone()),
)
};
// Reflected operator implementation: This deviates from the runtime semantics by calling the reflected dunder if the regular dunder call errors.
// At runtime, the reflected dunder is called only if the regular dunder method doesn't exist or if it returns NotImplemented.
// This deviation is necessary, given that the typeshed stubs don't record when NotImplemented is returned
let calls_to_try = [
(&Name::new_static(op.dunder()), lhs, rhs),
(&Name::new_static(op.reflected_dunder()), rhs, lhs),
];
self.try_binop_calls(&calls_to_try, range, errors, &context)
let class_type_of = |ty: &Type| -> Option<ClassType> {
match ty {
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

mapping over type like this is almost never what we want: it's hard to maintain and prone to errors

Type::ClassType(cls) => Some(cls.clone()),
Type::SelfType(cls) => Some(cls.clone()),
Type::Literal(Lit::Enum(lit_enum)) => Some(lit_enum.class.clone()),
_ => None,
}
};
let lhs_cls = class_type_of(lhs);
let rhs_cls = class_type_of(rhs);
let strict_rhs_subclass = matches!((&rhs_cls, &lhs_cls), (Some(rhs_cls), Some(lhs_cls))
if rhs_cls.class_object() != lhs_cls.class_object()
&& self.has_superclass(rhs_cls.class_object(), lhs_cls.class_object()));

let forward_name = Name::new_static(op.dunder());
let reflected_name = Name::new_static(op.reflected_dunder());
let calls_to_try = if strict_rhs_subclass {
[(&reflected_name, rhs, lhs), (&forward_name, lhs, rhs)]
} else {
[(&forward_name, lhs, rhs), (&reflected_name, rhs, lhs)]
};
let mut result = self.try_binop_calls(&calls_to_try, range, errors, &context);

if !strict_rhs_subclass {
let forward_ret = self.binop_call_type(lhs, &forward_name, rhs, range);
let reflected_ret = self.binop_call_type(rhs, &reflected_name, lhs, range);
let extra = match (forward_ret, reflected_ret) {
(Some(f), Some(r)) => {
if self.is_equal(&f, &r) && self.is_equal(&result, &f) {
None
} else {
Some(self.union(f, r))
}
}
(Some(f), None) => {
if self.is_equal(&result, &f) {
None
} else {
Some(f)
}
}
(None, Some(r)) => {
if self.is_equal(&result, &r) {
None
} else {
Some(r)
}
}
(None, None) => None,
};
if let Some(extra) = extra {
result = self.union(result, extra);
}
}

result
};
// If the expression is of the form [X] * Y where Y is a number, pass down the contextual
// type hint when evaluating [X]
Expand Down Expand Up @@ -304,12 +379,56 @@ impl<'a, Ans: LookupAnswer> AnswersSolver<'a, Ans> {
self.for_display(rhs.clone()),
)
};
let calls_to_try = [
(&Name::new_static(op.in_place_dunder()), lhs, rhs),
(&Name::new_static(op.dunder()), lhs, rhs),
(&Name::new_static(op.reflected_dunder()), rhs, lhs),
];
self.try_binop_calls(&calls_to_try, range, errors, &context)
let class_type_of = |ty: &Type| -> Option<ClassType> {
match ty {
Type::ClassType(cls) => Some(cls.clone()),
Type::SelfType(cls) => Some(cls.clone()),
Type::Literal(Lit::Enum(lit_enum)) => Some(lit_enum.class.clone()),
_ => None,
}
};
let lhs_cls = class_type_of(lhs);
let rhs_cls = class_type_of(rhs);
let strict_rhs_subclass = matches!((&rhs_cls, &lhs_cls), (Some(rhs_cls), Some(lhs_cls))
if rhs_cls.class_object() != lhs_cls.class_object()
&& self.has_superclass(rhs_cls.class_object(), lhs_cls.class_object()));

let inplace_name = Name::new_static(op.in_place_dunder());
let forward_name = Name::new_static(op.dunder());
let reflected_name = Name::new_static(op.reflected_dunder());

let calls_to_try = if strict_rhs_subclass {
[
(&inplace_name, lhs, rhs),
(&reflected_name, rhs, lhs),
(&forward_name, lhs, rhs),
]
} else {
[
(&inplace_name, lhs, rhs),
(&forward_name, lhs, rhs),
(&reflected_name, rhs, lhs),
]
};
let mut result = self.try_binop_calls(&calls_to_try, range, errors, &context);

let mut union_parts = Vec::new();
if let Some(ret) = self.binop_call_type(lhs, &inplace_name, rhs, range) {
union_parts.push(ret);
}
if let Some(ret) = self.binop_call_type(lhs, &forward_name, rhs, range) {
union_parts.push(ret);
}
if let Some(ret) = self.binop_call_type(rhs, &reflected_name, lhs, range) {
union_parts.push(ret);
}
if !union_parts.is_empty() {
let extra = self.unions(union_parts);
if !self.is_equal(&result, &extra) {
result = self.union(result, extra);
}
}
result
};
let base = self.expr_infer(&x.target, errors);
let rhs = self.expr_infer(&x.value, errors);
Expand Down
Loading