From 97a4fe10378bb4be61c68be1dda9a0cc01f8e185 Mon Sep 17 00:00:00 2001 From: Amanda Liang Date: Wed, 24 Jun 2026 17:07:32 +0000 Subject: [PATCH] Enable TC tiling on SparseCore; use sc_info.num_cores for num_cores and sc_info.num_lanes for row_subchunk_size. --- src/maxtext/kernels/gather_reduce_pallas.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/src/maxtext/kernels/gather_reduce_pallas.py b/src/maxtext/kernels/gather_reduce_pallas.py index 56fa07cac5..c1ded104c9 100644 --- a/src/maxtext/kernels/gather_reduce_pallas.py +++ b/src/maxtext/kernels/gather_reduce_pallas.py @@ -93,9 +93,9 @@ def sc_gather_reduce( mesh=plsc.VectorSubcoreMesh( core_axis_name="core", subcore_axis_name="subcore", - num_cores=1 if single_sc else 2, + num_cores=1 if single_sc else sc_info.num_cores, ), - compiler_params=pltpu.CompilerParams(needs_layout_passes=True), + compiler_params=pltpu.CompilerParams(needs_layout_passes=True, use_tc_tiling_on_sc=True), ) def kernel(in_hbm_ref, idx_hbm_ref, weights_hbm_ref, out_hbm_ref): row_wave_size = row_chunk_size * lax.axis_size(("core", "subcore")) @@ -118,7 +118,7 @@ def kernel(in_hbm_ref, idx_hbm_ref, weights_hbm_ref, out_hbm_ref): def idx_pipeline(idx_ref, weights_ref=None): row_chunk_idx = subcore_first_row_chunk + pl.program_id(0) - row_subchunk_size = 16 + row_subchunk_size = sc_info.num_lanes out_rows_per_step = row_subchunk_size // reduce_group_size assert reduce_group_size * out_rows_per_step == sc_info.num_lanes num_row_subchunks = row_chunk_size // row_subchunk_size