Skip to content
Open
Show file tree
Hide file tree
Changes from 9 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions docs/api-reference/expr.md
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
- alias
- all
- any
- any_value
- cast
- ceil
- clip
Expand Down
1 change: 1 addition & 0 deletions docs/api-reference/series.md
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@
- alias
- all
- any
- any_value
- arg_max
- arg_min
- arg_true
Expand Down
17 changes: 13 additions & 4 deletions narwhals/_arrow/group_by.py
Original file line number Diff line number Diff line change
Expand Up @@ -41,6 +41,7 @@ class ArrowGroupBy(EagerGroupBy["ArrowDataFrame", "ArrowExpr", "Aggregation"]):
"any": "any",
"first": "first",
"last": "last",
"any_value": "first",
}
_REMAP_UNIQUE: ClassVar[Mapping[UniqueKeepStrategy, Aggregation]] = {
"any": "min",
Expand All @@ -52,7 +53,7 @@ class ArrowGroupBy(EagerGroupBy["ArrowDataFrame", "ArrowExpr", "Aggregation"]):
)
_OPTION_COUNT_VALID: ClassVar[frozenset[NarwhalsAggregation]] = frozenset(("count",))
_OPTION_ORDERED: ClassVar[frozenset[NarwhalsAggregation]] = frozenset(
("first", "last")
("first", "last", "any_value")
)
_OPTION_VARIANCE: ClassVar[frozenset[NarwhalsAggregation]] = frozenset(("std", "var"))
_OPTION_SCALAR: ClassVar[frozenset[NarwhalsAggregation]] = frozenset(
Expand Down Expand Up @@ -89,11 +90,19 @@ def _configure_agg(
elif function_name in self._OPTION_SCALAR:
option = pc.ScalarAggregateOptions(min_count=0)
elif function_name in self._OPTION_ORDERED:
grouped, option = self._ordered_agg(grouped, function_name)
ignore_nulls = kwargs.get("ignore_nulls", False)
grouped, option = self._ordered_agg(
grouped, function_name, ignore_nulls=ignore_nulls
)
return grouped, self._remap_expr_name(function_name), option

def _ordered_agg(
self, grouped: pa.TableGroupBy, name: NarwhalsAggregation, /
self,
grouped: pa.TableGroupBy,
name: NarwhalsAggregation,
/,
*,
ignore_nulls: bool,
) -> tuple[pa.TableGroupBy, AggregateOptions]:
"""The default behavior of `pyarrow` raises when `first` or `last` are used.

Expand All @@ -117,7 +126,7 @@ def _ordered_agg(
f"See https://github.com/apache/arrow/issues/36709"
)
raise NotImplementedError(msg)
return grouped, pc.ScalarAggregateOptions(skip_nulls=False)
return grouped, pc.ScalarAggregateOptions(skip_nulls=ignore_nulls)

def agg(self, *exprs: ArrowExpr) -> ArrowDataFrame:
self._ensure_all_simple(exprs)
Expand Down
5 changes: 5 additions & 0 deletions narwhals/_arrow/series.py
Original file line number Diff line number Diff line change
Expand Up @@ -1104,6 +1104,11 @@ def exp(self) -> Self:
def sqrt(self) -> Self:
return self._with_native(pc.sqrt(self.native))

def any_value(
    self, *, ignore_nulls: bool, _return_py_scalar: bool = True
) -> PythonLiteral:
    """Return one value from the series, optionally skipping nulls.

    Arguments:
        ignore_nulls: When true, nulls are dropped before the first element
            is taken; otherwise the first element is taken as-is.
        _return_py_scalar: Accepted but not read by this method.

    Returns:
        Whatever `first()` yields on the (possibly null-filtered) series.
    """
    # NOTE(review): `_return_py_scalar` is not forwarded to `first()` —
    # confirm whether that is intentional.
    candidates = self.drop_nulls() if ignore_nulls else self
    return candidates.first()

@property
def dt(self) -> ArrowSeriesDateTimeNamespace:
return ArrowSeriesDateTimeNamespace(self)
Expand Down
6 changes: 6 additions & 0 deletions narwhals/_compliant/expr.py
Original file line number Diff line number Diff line change
Expand Up @@ -131,6 +131,7 @@ def broadcast(self) -> Self: ...
def alias(self, name: str) -> Self: ...
def all(self) -> Self: ...
def any(self) -> Self: ...
def any_value(self, *, ignore_nulls: bool) -> Self: ...
def count(self) -> Self: ...
def min(self) -> Self: ...
def max(self) -> Self: ...
Expand Down Expand Up @@ -824,6 +825,11 @@ def first(self) -> Self:
def last(self) -> Self:
return self._reuse_series("last", returns_scalar=True)

def any_value(self, *, ignore_nulls: bool) -> Self:
    """Delegate `any_value` to the series-level method of the same name.

    Arguments:
        ignore_nulls: Forwarded unchanged to the series implementation.

    Returns:
        The result of `_reuse_series`, requested with `returns_scalar=True`.
    """
    reduced = self._reuse_series(
        "any_value", returns_scalar=True, ignore_nulls=ignore_nulls
    )
    return reduced

@property
def cat(self) -> EagerExprCatNamespace[Self]:
return EagerExprCatNamespace(self)
Expand Down
1 change: 1 addition & 0 deletions narwhals/_compliant/series.py
Original file line number Diff line number Diff line change
Expand Up @@ -128,6 +128,7 @@ def __rmul__(self, other: Any) -> Self: ...
def __ror__(self, other: Any) -> Self: ...
def all(self) -> bool: ...
def any(self) -> bool: ...
def any_value(self, *, ignore_nulls: bool) -> PythonLiteral: ...
def arg_max(self) -> int: ...
def arg_min(self) -> int: ...
def arg_true(self) -> Self: ...
Expand Down
2 changes: 1 addition & 1 deletion narwhals/_compliant/typing.py
Original file line number Diff line number Diff line change
Expand Up @@ -194,11 +194,11 @@ class ScalarKwargs(TypedDict, total=False):
"any",
"first",
"last",
"any_value",
]
"""`Expr` methods we aim to support in `DepthTrackingGroupBy`.

Be sure to update me if you're working on one of these:
- https://github.com/narwhals-dev/narwhals/issues/981
- https://github.com/narwhals-dev/narwhals/issues/2385
- https://github.com/narwhals-dev/narwhals/issues/2484
- https://github.com/narwhals-dev/narwhals/issues/2526
Expand Down
1 change: 1 addition & 0 deletions narwhals/_dask/expr.py
Original file line number Diff line number Diff line change
Expand Up @@ -693,6 +693,7 @@ def str(self) -> DaskExprStringNamespace:
def dt(self) -> DaskExprDateTimeNamespace:
return DaskExprDateTimeNamespace(self)

any_value = not_implemented()
filter = not_implemented()
first = not_implemented()
rank = not_implemented()
Expand Down
9 changes: 9 additions & 0 deletions narwhals/_duckdb/expr.py
Original file line number Diff line number Diff line change
Expand Up @@ -105,6 +105,15 @@ def _first(self, expr: Expression, *order_by: str) -> Expression:
def _last(self, expr: Expression, *order_by: str) -> Expression:
return self._first_last("last", expr, order_by)

def _any_value(self, expr: Expression, *, ignore_nulls: bool) -> Expression:
# !NOTE: DuckDB `any_value` returns the first non-null value
# See: https://duckdb.org/docs/stable/sql/functions/aggregates#any_valuearg
return (
self._function("any_value", expr)
if ignore_nulls
else self._function("first", expr)
)

def __narwhals_namespace__(self) -> DuckDBNamespace: # pragma: no cover
from narwhals._duckdb.namespace import DuckDBNamespace

Expand Down
6 changes: 6 additions & 0 deletions narwhals/_ibis/expr.py
Original file line number Diff line number Diff line change
Expand Up @@ -127,6 +127,12 @@ def _last(self, expr: ir.Value, *order_by: str) -> ir.Value:
order_by=self._sort(*order_by), include_null=True
)

def _any_value(self, expr: ir.Value, *, ignore_nulls: bool) -> ir.Value:
# !NOTE: ibis arbitrary returns a random non-null value
# See: https://ibis-project.org/reference/expression-generic.html#ibis.expr.types.generic.Column.arbitrary
expr = cast("ir.Column", expr)
return expr.arbitrary() if ignore_nulls else expr.first(include_null=True)
Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

expr.first(include_null=not ignore_nulls) is not behaving as expected 🫠


def __narwhals_namespace__(self) -> IbisNamespace: # pragma: no cover
from narwhals._ibis.namespace import IbisNamespace

Expand Down
19 changes: 12 additions & 7 deletions narwhals/_pandas_like/expr.py
Original file line number Diff line number Diff line change
Expand Up @@ -109,7 +109,13 @@ def window_kwargs_to_pandas_equivalent( # noqa: C901
"min_periods": kwargs["min_samples"],
"ignore_na": kwargs["ignore_nulls"],
}
elif function_name in {"first", "last"}:
elif function_name in {"first", "last", "any_value"}:
if kwargs.get("ignore_nulls"):
msg = (
"`Expr.any_value(ignore_nulls=True)` is not supported in a `over` "
"context for pandas-like backend."
)
raise NotImplementedError(msg)
Comment on lines +113 to +118
Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This is an issue if a group has all null values

pandas_kwargs = {
"n": _REMAP_ORDERED_INDEX[cast("NarwhalsAggregation", function_name)]
}
Expand Down Expand Up @@ -357,17 +363,16 @@ def func(df: PandasLikeDataFrame) -> Sequence[PandasLikeSeries]: # noqa: C901,
msg = "Safety check failed, please report a bug."
raise AssertionError(msg)
res_native = grouped.transform("size").to_frame(aliases[0])
elif function_name in {"first", "last"}:
elif function_name in {"first", "last", "any_value"}:
with warnings.catch_warnings():
# Ignore settingwithcopy warnings/errors, they're false-positives here.
warnings.filterwarnings("ignore", message="\n.*copy of a slice")
_nth = getattr(
_agg = getattr(
grouped[[*partition_by, *aliases]], pandas_function_name
)(**pandas_kwargs)
_nth.reset_index(drop=True, inplace=True)
res_native = df.native[list(partition_by)].merge(
_nth, on=list(partition_by)
)[list(aliases)]
_agg.reset_index(drop=True, inplace=True)
keys = list(partition_by)
res_native = df.native[keys].merge(_agg, on=keys)[list(aliases)]
else:
res_native = grouped[list(aliases)].transform(
pandas_function_name, **pandas_kwargs
Expand Down
13 changes: 12 additions & 1 deletion narwhals/_pandas_like/group_by.py
Original file line number Diff line number Diff line change
Expand Up @@ -62,6 +62,7 @@
_REMAP_ORDERED_INDEX: Mapping[NarwhalsAggregation, Literal[0, -1]] = {
"first": 0,
"last": -1,
"any_value": 0,
}


Expand Down Expand Up @@ -151,7 +152,7 @@ def _getitem_aggs(
for col in cols
]
)
elif self.is_last() or self.is_first():
elif self.is_last() or self.is_first() or self.is_any_value():
result = self.native_agg()(group_by._grouped[[*group_by._keys, *names]])
result.set_index(group_by._keys, inplace=True) # noqa: PD002
else:
Expand All @@ -175,6 +176,9 @@ def is_first(self) -> bool:
def is_mode(self) -> bool:
return self.leaf_name == "mode"

def is_any_value(self) -> bool:
    """Report whether `leaf_name` equals `"any_value"`."""
    leaf = self.leaf_name
    return leaf == "any_value"

def is_top_level_function(self) -> bool:
# e.g. `nw.len()`.
return len(list(self.expr._metadata.op_nodes_reversed())) == 1
Expand All @@ -191,6 +195,12 @@ def native_agg(self) -> _NativeAgg:
native_name = PandasLikeGroupBy._remap_expr_name(self.leaf_name)
last_node = next(self.expr._metadata.op_nodes_reversed())
if self.leaf_name in _REMAP_ORDERED_INDEX:
if last_node.kwargs.get("ignore_nulls"):
msg = (
"`Expr.any_value(ignore_nulls=True)` is not supported in a `group_by` "
"context for pandas-like backend"
)
raise NotImplementedError(msg)
Comment on lines +198 to +203
Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This is an issue if a group has all null values

return methodcaller("nth", n=_REMAP_ORDERED_INDEX[self.leaf_name])
return _native_agg(native_name, **last_node.kwargs)

Expand All @@ -215,6 +225,7 @@ class PandasLikeGroupBy(
"any": "any",
"first": "nth",
"last": "nth",
"any_value": "nth",
}
_original_columns: tuple[str, ...]
"""Column names *prior* to any aliasing in `ParseKeysGroupBy`."""
Expand Down
3 changes: 3 additions & 0 deletions narwhals/_pandas_like/series.py
Original file line number Diff line number Diff line change
Expand Up @@ -1152,6 +1152,9 @@ def exp(self) -> Self:
def sqrt(self) -> Self:
return self._with_native(self.native.pow(0.5))

def any_value(self, *, ignore_nulls: bool) -> PythonLiteral:
    """Return one value from the series, optionally skipping nulls.

    Arguments:
        ignore_nulls: When true, nulls are dropped before taking the first
            element; otherwise the first element is returned as-is.
    """
    if ignore_nulls:
        return self.drop_nulls().first()
    return self.first()

@property
def str(self) -> PandasLikeSeriesStringNamespace:
return PandasLikeSeriesStringNamespace(self)
Expand Down
3 changes: 3 additions & 0 deletions narwhals/_polars/expr.py
Original file line number Diff line number Diff line change
Expand Up @@ -299,6 +299,9 @@ def mode(self, *, keep: ModeKeepStrategy) -> Self:
result = self.native.mode()
return self._with_native(result.first() if keep == "any" else result)

def any_value(self, *, ignore_nulls: bool) -> Self:
    """Express `any_value` as `first()`, optionally after `drop_nulls()`.

    Arguments:
        ignore_nulls: When true, `drop_nulls()` is applied before `first()`.
    """
    source = self.drop_nulls() if ignore_nulls else self
    return source.first()

@property
def dt(self) -> PolarsExprDateTimeNamespace:
return PolarsExprDateTimeNamespace(self)
Expand Down
3 changes: 3 additions & 0 deletions narwhals/_polars/series.py
Original file line number Diff line number Diff line change
Expand Up @@ -664,6 +664,9 @@ def last(self) -> PythonLiteral:
return self.native.item(-1) if len(self) else None
return self.native.last() # type: ignore[return-value]

def any_value(self, *, ignore_nulls: bool) -> PythonLiteral:
    """Return the first element, skipping nulls when requested.

    Arguments:
        ignore_nulls: When true, the series is passed through `drop_nulls()`
            before `first()` is called.
    """
    if not ignore_nulls:
        return self.first()
    without_nulls = self.drop_nulls()
    return without_nulls.first()

@property
def dt(self) -> PolarsSeriesDateTimeNamespace:
return PolarsSeriesDateTimeNamespace(self)
Expand Down
3 changes: 3 additions & 0 deletions narwhals/_spark_like/expr.py
Original file line number Diff line number Diff line change
Expand Up @@ -112,6 +112,9 @@ def _last(self, expr: Column, *order_by: str) -> Column: # pragma: no cover
msg = "`last` is not supported for PySpark."
raise NotImplementedError(msg)

def _any_value(self, expr: Column, *, ignore_nulls: bool) -> Column:
return self._F.any_value(expr, ignoreNulls=ignore_nulls)

def broadcast(self) -> Self:
return self.over([self._F.lit(1)], [])

Expand Down
6 changes: 6 additions & 0 deletions narwhals/_sql/expr.py
Original file line number Diff line number Diff line change
Expand Up @@ -188,6 +188,7 @@ def _coalesce(self, *expr: NativeExprT) -> NativeExprT:
def _count_star(self) -> NativeExprT: ...
def _first(self, expr: NativeExprT, *order_by: str) -> NativeExprT: ...
def _last(self, expr: NativeExprT, *order_by: str) -> NativeExprT: ...
def _any_value(self, expr: NativeExprT, *, ignore_nulls: bool) -> NativeExprT: ...

def _when(
self,
Expand Down Expand Up @@ -759,6 +760,11 @@ def func(

return self._with_window_function(func)

def any_value(self, *, ignore_nulls: bool) -> Self:
    """Wrap the backend-specific `_any_value` via `_with_callable`.

    Arguments:
        ignore_nulls: Captured and forwarded to `_any_value` when the
            callable is eventually invoked.
    """

    def _pick(expr):
        return self._any_value(expr, ignore_nulls=ignore_nulls)

    return self._with_callable(_pick)

def rank(self, method: RankMethod, *, descending: bool) -> Self:
if method in {"min", "max", "average"}:
func = self._function("rank")
Expand Down
34 changes: 34 additions & 0 deletions narwhals/expr.py
Original file line number Diff line number Diff line change
Expand Up @@ -2368,6 +2368,40 @@ def is_close( # noqa: PLR0914

return result

def any_value(self, *, ignore_nulls: bool = False) -> Self:
    """Get an arbitrary value from the column.

    Arguments:
        ignore_nulls: Whether to ignore null values or not.
            If `True` and there are no non-null elements, then `None` is returned.

    Examples:
        >>> import pandas as pd
        >>> import narwhals as nw
        >>> data = {"a": [1, 1, 2, 2], "b": [None, "foo", "baz", None]}
        >>> df_native = pd.DataFrame(data)
        >>> df = nw.from_native(df_native)
        >>> df.select(nw.all().any_value(ignore_nulls=False))
        ┌──────────────────┐
        |Narwhals DataFrame|
        |------------------|
        |       a     b    |
        |    0  1  None    |
        └──────────────────┘

        >>> df.group_by("a").agg(nw.col("b").any_value(ignore_nulls=True))
        ┌──────────────────┐
        |Narwhals DataFrame|
        |------------------|
        |       a    b     |
        |    0  1  foo     |
        |    1  2  baz     |
        └──────────────────┘
    """
    # Recorded as an aggregation node; `ignore_nulls` is stored on the node.
    node = ExprNode(ExprKind.AGGREGATION, "any_value", ignore_nulls=ignore_nulls)
    return self._append_node(node)

@property
def str(self) -> ExprStringNamespace[Self]:
return ExprStringNamespace(self)
Expand Down
17 changes: 17 additions & 0 deletions narwhals/series.py
Original file line number Diff line number Diff line change
Expand Up @@ -2844,6 +2844,23 @@ def is_close(
result = result.rename(orig_name) if name_is_none else result
return cast("Self", result)

def any_value(self, *, ignore_nulls: bool = False) -> PythonLiteral:
    """Get an arbitrary value from the column.

    Arguments:
        ignore_nulls: Whether to ignore null values or not.
            If `True` and there are no non-null elements, then `None` is returned.

    Examples:
        >>> import pyarrow as pa
        >>> import narwhals as nw
        >>> s_native = pa.chunked_array([[1, 2, None]])
        >>> s = nw.from_native(s_native, series_only=True)
        >>> s.any_value(ignore_nulls=True)
        1
    """
    # The backend-specific series does the actual work.
    compliant = self._compliant_series
    return compliant.any_value(ignore_nulls=ignore_nulls)

@property
def str(self) -> SeriesStringNamespace[Self]:
return SeriesStringNamespace(self)
Expand Down
Loading
Loading