Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
35 changes: 35 additions & 0 deletions backends/mlx/ops.py
Original file line number Diff line number Diff line change
Expand Up @@ -1687,6 +1687,41 @@ def _repeat_handler(P: MLXProgramBuilder, n: Node) -> Slot:
return out


@REGISTRY.register(target=[torch.ops.aten.flip.default])
def _flip_handler(P: MLXProgramBuilder, n: Node) -> Slot:
args = P.args(n)
require_args(args, 2, 2, "aten.flip")
require_kwargs(P.kwargs(n), set(), "aten.flip")
x, dims_arg = args

dims: List[int] = [dims_arg] if isinstance(dims_arg, int) else list(dims_arg)
require_static_ints(dims, "dims", "aten.flip")

x_meta = n.args[0].meta.get("val")
if x_meta is None:
raise RuntimeError("aten.flip: missing tensor metadata")
ndim = len(x_meta.shape)
if len(set(d % ndim for d in dims)) != len(dims):
raise ValueError(f"aten.flip: dims must be unique, got {dims}")

out = P.make_or_get_slot(n)
current = x
for i, dim in enumerate(dims):
reverse_out = out if i == len(dims) - 1 else P.make_tmp_slot()[1]
P.emit(
SliceNode(
x=P.slot_to_tid(current),
out=P.slot_to_tid(reverse_out),
axis=P.to_int_or_vid(dim),
start=P.to_int_or_vid(-1),
stop=P.to_int_or_vid(-(x_meta.shape[dim % ndim] + 1)),
step=-1,
)
)
current = reverse_out
return out


@REGISTRY.register(target=[torch.ops.aten.roll.default])
def _roll_handler(P: MLXProgramBuilder, n: Node) -> Slot:
args = P.args(n)
Expand Down
49 changes: 49 additions & 0 deletions backends/mlx/test/test_ops.py
Original file line number Diff line number Diff line change
Expand Up @@ -855,6 +855,55 @@ def create_inputs(self) -> Tuple[torch.Tensor, ...]:
return (x,)


class FlipModel(nn.Module):
"""Model that flips a tensor along specified dimensions."""

def __init__(self, dims: Tuple[int, ...]):
super().__init__()
self.dims = dims

def forward(self, x: torch.Tensor) -> torch.Tensor:
return torch.flip(x, dims=self.dims)


@register_test
class FlipTest(OpTestCase):
"""Test case for torch.flip()."""

name = "flip"
rtol = 1e-5
atol = 1e-5

def __init__(
self,
input_shape: Tuple[int, ...] = (4, 5),
dims: Tuple[int, ...] = (0,),
):
self.input_shape = input_shape
self.dims = dims
dim_str = ",".join(str(d) for d in dims)
self.name = f"flip_dim({dim_str})"
Comment on lines +884 to +885

@classmethod
def get_test_configs(cls) -> List["FlipTest"]:
return [
cls(input_shape=(8,), dims=(0,)),
cls(input_shape=(4, 5), dims=(0,)),
cls(input_shape=(4, 5), dims=(1,)),
cls(input_shape=(3, 4, 5), dims=(2,)),
cls(input_shape=(3, 4, 5), dims=(0, 2)),
cls(input_shape=(3, 4, 5), dims=(0, 1, 2)),
cls(input_shape=(3, 4, 5), dims=(-1,)),
]

def create_model(self) -> nn.Module:
return FlipModel(self.dims)

def create_inputs(self) -> Tuple[torch.Tensor, ...]:
x = torch.randn(self.input_shape)
return (x,)


class RollModel(nn.Module):
"""Model that rolls a tensor along specified dimensions."""

Expand Down
Loading