general product: handle fused-broadcast-on-the-right (rank-0 right operand) #568
Merged
evaleev merged 1 commit on Jun 17, 2026
A general product whose RIGHT operand is entirely fused with no contraction
(a fused broadcast on the right, e.g. C("b,k") = A("b,k") * B("b")) folds to a
rank-0 right operand, which the batched GEMM cannot host (TA rank-0 ranges are
null/volume-0, and gemm asserts operand.rank()==gemm_helper.rank()). The
synthetic-unit machinery covered only the rank-0 RESULT (no-external product)
and rank-0 LEFT (fused-broadcast-left) cases via a unit LEFT-external mode; the
rank-0 RIGHT case was unhandled, so folding the right operand aborted in
BatchedContractReduce::contract_pair (or silently mis-shaped with asserts off).
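To make the degeneracy concrete, here is a minimal sketch in plain C++ of how the modes of a general product can be classified and how the rank-0 right operand arises. The names ProductModes, classify, folded_right_rank and needs_unit_right_external are hypothetical and exist only for this illustration; they are not TiledArray's API. The last predicate mimics the spirit of the outer_size(right_indices_) == n_fused_modes_ check described in this message.

```cpp
#include <algorithm>
#include <cassert>
#include <cstddef>
#include <string>
#include <vector>

// Hypothetical helper types for illustration only; not TiledArray's API.
using Indices = std::vector<std::string>;

struct ProductModes {
  Indices fused;           // present in left, right and result (batch modes)
  Indices contracted;      // present in left and right, absent from result
  Indices left_external;   // present in left only (and carried to the result)
  Indices right_external;  // present in right only (and carried to the result)
};

static bool contains(const Indices& v, const std::string& s) {
  return std::find(v.begin(), v.end(), s) != v.end();
}

// Classify the modes of result = left * right.
// Assumes every index appears in at least two of the three lists.
ProductModes classify(const Indices& left, const Indices& right,
                      const Indices& result) {
  ProductModes m;
  for (const auto& i : left) {
    const bool in_right = contains(right, i);
    const bool in_result = contains(result, i);
    if (in_right && in_result) m.fused.push_back(i);
    else if (in_right) m.contracted.push_back(i);
    else m.left_external.push_back(i);
  }
  for (const auto& i : right)
    if (!contains(left, i)) m.right_external.push_back(i);
  return m;
}

// Rank of the right operand once the fused (batch) modes are folded away.
std::size_t folded_right_rank(const ProductModes& m) {
  return m.contracted.size() + m.right_external.size();
}

// True when every mode of the right operand is fused, i.e. folding leaves
// a rank-0 right operand that needs a synthetic unit right-external mode.
bool needs_unit_right_external(const Indices& right, const ProductModes& m) {
  return right.size() == m.fused.size();
}
```

For C("b,k") = A("b,k") * B("b"), classify reports b as fused and k as left-external, so the folded right rank is 0 and the predicate is true. For an ordinary batched outer product such as C("b,k,l") = A("b,k") * B("b,l"), the right operand keeps the external mode l and no synthetic mode is needed.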
Add a symmetric synthetic unit RIGHT-external mode, mirroring the left one:
- ContEngine::synthetic_unit_right_external() detects outer_size(right_indices_)
== n_fused_modes_; u_right is threaded through init_struct_general (op_ ctors,
right_op NoTranspose), make_trange_general and init_distribution_general
(neB -= u_right).
- SparseShape::gemm_batched detects the right phantom by the one-rank mismatch
and guards the right-outer loops / result rank accordingly (shape-level
analog); its fold_range lambda gains an append_unit option.
- BatchedContractReduce::contract_pair detects unit_right_external, excludes it
from neB, and pads the folded right (and accumulating result) views with a
trailing unit; the member fold_range gains an append_unit option.
The synthetic mode lives only in the GemmHelper; tranges, shapes and tiles
carry the true ranks. The result gains a trailing unit mode (squeezed out of
the actual result) exactly as the left case prepends one.
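The padding idea can be sketched outside TiledArray with plain arrays. In the sketch below, fold_extents is a hypothetical stand-in for a fold_range-style helper with an append_unit option, gemm is a naive row-major reference kernel, and fused_broadcast_right evaluates C(b,k) = A(b,k) * B(b) as one small GEMM per batch index. None of these are the library's actual functions or signatures; the sketch only assumes that a trailing unit extent turns the rank-0 right view into a 1x1 matrix and that the trailing unit of the result can be dropped without touching the data.

```cpp
#include <cassert>
#include <cstddef>
#include <vector>

// Hypothetical stand-in for a fold_range-style helper: drop the leading
// n_fused batch extents and optionally append a trailing unit extent.
std::vector<std::size_t> fold_extents(const std::vector<std::size_t>& extents,
                                      std::size_t n_fused, bool append_unit) {
  assert(n_fused <= extents.size());
  std::vector<std::size_t> folded(
      extents.begin() + static_cast<std::ptrdiff_t>(n_fused), extents.end());
  if (append_unit) folded.push_back(1);  // synthetic unit mode
  return folded;
}

// Naive row-major GEMM: C(m,n) += sum_k A(m,k) * B(k,n).
void gemm(std::size_t M, std::size_t N, std::size_t K, const double* A,
          const double* B, double* C) {
  for (std::size_t m = 0; m < M; ++m)
    for (std::size_t n = 0; n < N; ++n)
      for (std::size_t k = 0; k < K; ++k)
        C[m * N + n] += A[m * K + k] * B[k * N + n];
}

// C(b,k) = A(b,k) * B(b), evaluated as one GEMM per batch index b.
// A is row-major with shape (nb, nk); B has shape (nb).
std::vector<double> fused_broadcast_right(const std::vector<double>& A,
                                          const std::vector<double>& B,
                                          std::size_t nb, std::size_t nk) {
  assert(A.size() == nb * nk && B.size() == nb);
  // Folding away the batch mode leaves rank 1 on the left and rank 0 on the
  // right; the appended unit gives the right view the extents {1}.
  const auto left = fold_extents({nb, nk}, 1, false);  // {nk}
  const auto right = fold_extents({nb}, 1, true);      // {1}
  // No contracted modes, so the inner (K) dimension has size 1.
  const std::size_t M = left[0], N = right[0], K = 1;
  std::vector<double> C(nb * M * N, 0.0);
  for (std::size_t b = 0; b < nb; ++b)
    gemm(M, N, K, A.data() + b * nk, B.data() + b, C.data() + b * M * N);
  // The padded result has shape (nb, nk, 1); dropping the trailing unit
  // leaves the same storage viewed as (nb, nk).
  return C;
}
```

Without append_unit, the folded right extents would be empty, which is the rank-0 view a GEMM cannot accept. With it, each batch step is an nk x 1 by 1 x 1 product, and the result needs no data movement when the trailing unit is squeezed out.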
Tests: general_product gains dense, ToT->ToT and block-sparse fused-broadcast-
right cases through the expression layer (which respects operand order, unlike
einsum's reordering); dot_inner gains a ->T denest fused-broadcast-right case.
This was surfaced by MPQC PNO/CSV-CCk denest products evaluated via the native
ToT*ToT->T dispatch.
Problem
A general product whose right operand is entirely fused with no contraction (a fused broadcast on the right, e.g. C("b,k") = A("b,k") * B("b")) folds to a rank-0 right operand. TiledArray cannot host that: rank-0 Ranges are null (volume 0) by design, and gemm/GemmHelper assert operand.rank() == gemm_helper.rank().

The general-product engine already avoids rank-0 folded operands with a synthetic unit mode carried only in the GemmHelper, but ContEngine::synthetic_unit_left_external() covered only two of the three degeneracies:
- rank-0 result (no-external product)
- rank-0 left operand (fused broadcast on the left)

The rank-0 right case was unhandled: a left-external unit can't fix it (right_rank stays 0). So folding the right operand aborted in BatchedContractReduce::contract_pair (Tensor::reshape volume assert), or silently mis-shaped the tile with asserts off.

This surfaced via MPQC PNO/CSV-CCSD denest products (A(b,k;…)·B(b;…) with b a per-pair batch index) evaluated through the native ToT*ToT->T dispatch added in #567.

Fix
Add a symmetric synthetic unit right-external mode, mirroring the existing left one:
- cont_engine.h: new synthetic_unit_right_external() (detects outer_size(right_indices_) == n_fused_modes_); u_right threaded through init_struct_general (both op_ ctors; right_op is NoTranspose for the synthetic mode), make_trange_general and init_distribution_general (neB -= u_right).
- sparse_shape.h: SparseShape::gemm_batched (shape-level analog) detects the right phantom by the one-rank mismatch and guards the right-outer loops / result rank; its fold_range lambda gains an append_unit option.
- batched_contract_reduce.h: contract_pair detects unit_right_external, excludes it from neB, and pads the folded right (and accumulating result) views with a trailing unit; the member fold_range gains an append_unit option.

The synthetic mode lives only in the GemmHelper; tranges, shapes and tiles carry the true ranks. The result gains a trailing unit mode (squeezed out of the actual result) exactly as the left case prepends one.

Tests
- general_product: new dense, ToT*ToT->ToT, and block-sparse fused-broadcast-right cases, all through the expression layer (which respects operand order, unlike einsum, whose operand reordering hides the shape).
- dot_inner: new ->T denest fused-broadcast-right case.

All general_product (61) and dot_inner suites pass, plus einsum_tiledarray / einsum_tot / einsum_tot_t / arena_einsum_unit_suite / expressions_suite (112) with no regressions. Also verified end-to-end against the originating MPQC CSV-CCSD case (correct to ~1e-11; previously crashed under asserts / silently wrong under Apple Accelerate).