Skip to content

feat(compute): native bf16 GPU transpose kernels (capture-safe)#160

Merged
dndungu merged 2 commits into
mainfrom
feat/bf16-gpu-transpose
Jun 17, 2026
Merged

feat(compute): native bf16 GPU transpose kernels (capture-safe)#160
dndungu merged 2 commits into
mainfrom
feat/bf16-gpu-transpose

Conversation

@dndungu

@dndungu dndungu commented Jun 17, 2026

Copy link
Copy Markdown
Contributor

What

Native bf16 (16-bit) GPU transpose kernels so bf16 transposes stay on-device and are CUDA-graph-capturable.

Why

GPUEngine.Transpose routed every non-float32 type to the CPU engine, whose host memcpy breaks CUDA-graph capture: any bf16 transpose under capture failed with operation would make the legacy stream depend on a capturing blocking stream (e.g. QKL2Norm's Transpose). This forced the bf16 CrossAsset GPU bench to run with capture disabled (~190 s/epoch), an unrepresentative speed number.

What changed

Transpose is pure data movement, so the bf16 kernels operate on unsigned short — a bitwise element copy independent of the bf16 numeric interpretation (no bf16 math, no new headers):

  • transpose.cu: kernel_transpose_2d_bf16 + kernel_transpose_nd_bf16 (+ launchers)
  • internal/cuda/kernels: Transpose2DBF16 / TransposeNDBF16 (cuda + purego builds) + dlopen symbol registration
  • gpuapi: optional BFloat16Transposer extension + CUDAKernels impl + assertion
  • GPUEngine.Transpose: for bf16 with GPU-resident input and a backend implementing BFloat16Transposer, transpose on-device via the bf16 kernels; otherwise the CPU fallback as before. The f32 path is byte-for-byte unchanged (element size threaded only through byteSize + kernel selection).

Verification

  • Local: build + vet + non-CUDA compute/gpuapi tests green.
  • GB10 (sm_121): full bf16 compute suite GREEN incl. new TestGPUBF16_TransposeParity (2D + 3D [0,2,1], the QKL2Norm shape), exact-match (ok compute 3.521s).

Lets the bf16 CrossAsset GPU bench run with CUDA-graph capture ON (representative s/epoch). Final piece of the bf16 GPU backward chain (ztensor v1.16.0 NT/TN + zerfoo v1.53.1 grad-accum). ADR-075 lever L4.

dndungu added 2 commits June 16, 2026 23:05
Extends the ADR-091 PyTorch-oracle / gradcheck harness to the GroupNorm op
class (zerfoo E127/T127.1.0a, first of six new diffusion-DiT op classes).

GroupNorm composes entirely from existing engine reduce/elementwise ops:
reshape [N,C] -> [N*groups, C/groups], normalize the last axis exactly like
the LayerNorm node, reshape back, apply a per-channel affine. No new engine
kernel. Adds the node (gradcheck/ops.go), the registry entry + dispatch
(registry.go, dim=4 groups=2), and the torch replay + tolerance (torchmap.go,
torch.nn.functional.group_norm).

Verified: TestRegistry/GroupNorm gradcheck passes (analytic backward vs
finite-difference); full gradcheck + oracle registry<->torchmap lockstep green.
Unlocks the convolutional-VAE/UNet GroupNorm primitive for the diffusion class.
GPUEngine.Transpose routed every non-float32 type to the CPU engine, whose host
memcpy breaks CUDA-graph capture -- so any bf16 transpose under capture failed
("operation would make the legacy stream depend on a capturing blocking stream",
e.g. node QKL2Norm's Transpose). This forced the bf16 CrossAsset GPU bench to run
with capture DISABLED (~190 s/epoch).

Add native bf16 (16-bit) transpose kernels. Transpose is pure data movement, so
the kernels operate on `unsigned short` -- a bitwise element copy independent of
the bf16 numeric interpretation (no bf16 math, no new headers):

  - transpose.cu: kernel_transpose_2d_bf16 + kernel_transpose_nd_bf16 (+ launchers)
  - cuda/kernels: Transpose2DBF16 / TransposeNDBF16 (cuda + purego builds) and
    dlopen symbol registration
  - gpuapi: optional BFloat16Transposer extension + CUDAKernels impl
  - GPUEngine.Transpose: for bf16 with GPU-resident input and a backend that
    implements BFloat16Transposer, transpose on-device via the bf16 kernels;
    otherwise fall back to the CPU engine as before. The f32 path is byte-for-byte
    unchanged (element size threaded only through the byteSize/kernel-select).

CUDA-gated parity tests: 2D + 3D[0,2,1] (the QKL2Norm shape), exact-match.
Lets the bf16 CrossAsset GPU bench run with CUDA-graph capture ON. ADR-075 L4.
@dndungu dndungu merged commit 25f5981 into main Jun 17, 2026
1 check passed
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

1 participant