diff --git a/src/maxtext/kernels/gather_reduce_pallas.py b/src/maxtext/kernels/gather_reduce_pallas.py index 56fa07cac5..c1ded104c9 100644 --- a/src/maxtext/kernels/gather_reduce_pallas.py +++ b/src/maxtext/kernels/gather_reduce_pallas.py @@ -93,9 +93,9 @@ def sc_gather_reduce( mesh=plsc.VectorSubcoreMesh( core_axis_name="core", subcore_axis_name="subcore", - num_cores=1 if single_sc else 2, + num_cores=1 if single_sc else sc_info.num_cores, ), - compiler_params=pltpu.CompilerParams(needs_layout_passes=True), + compiler_params=pltpu.CompilerParams(needs_layout_passes=True, use_tc_tiling_on_sc=True), ) def kernel(in_hbm_ref, idx_hbm_ref, weights_hbm_ref, out_hbm_ref): row_wave_size = row_chunk_size * lax.axis_size(("core", "subcore")) @@ -118,7 +118,7 @@ def kernel(in_hbm_ref, idx_hbm_ref, weights_hbm_ref, out_hbm_ref): def idx_pipeline(idx_ref, weights_ref=None): row_chunk_idx = subcore_first_row_chunk + pl.program_id(0) - row_subchunk_size = 16 + row_subchunk_size = sc_info.num_lanes out_rows_per_step = row_subchunk_size // reduce_group_size assert reduce_group_size * out_rows_per_step == sc_info.num_lanes num_row_subchunks = row_chunk_size // row_subchunk_size