From 5e3aaa084ab855394e933bf9faa2a4248808a5aa Mon Sep 17 00:00:00 2001 From: Nina Shvetsova Date: Thu, 2 Jul 2026 14:55:08 +0000 Subject: [PATCH] Follow up fix custom flash block sizes fallback in Ulysses attention Ensure that block sizes and heads_per_tile fall back to default values when resolved as None from CustomFlashBlockSizes dataclass. This fixes a TypeError in ulysses_custom attention when heads_per_tile is not specified. --- src/maxdiffusion/models/attention_flax.py | 24 +++++++++++------------ 1 file changed, 12 insertions(+), 12 deletions(-) diff --git a/src/maxdiffusion/models/attention_flax.py b/src/maxdiffusion/models/attention_flax.py index f6edc8309..76cd95c50 100644 --- a/src/maxdiffusion/models/attention_flax.py +++ b/src/maxdiffusion/models/attention_flax.py @@ -362,19 +362,19 @@ def _extract_custom_block_sizes(flash_block_sizes): if flash_block_sizes is not None: if isinstance(flash_block_sizes, dict): get = flash_block_sizes.get - bq = get("block_q", bq) - bkv = get("block_kv", bkv) - bkv_compute = get("block_kv_compute", bkv_compute) - bkv_compute_in = get("block_kv_compute_in", bkv_compute_in) - heads_per_tile = get("heads_per_tile", heads_per_tile) - vmem_limit_bytes = get("vmem_limit_bytes", vmem_limit_bytes) + bq = get("block_q", None) or bq + bkv = get("block_kv", None) or bkv + bkv_compute = get("block_kv_compute", None) or bkv_compute + bkv_compute_in = get("block_kv_compute_in", None) or bkv_compute_in + heads_per_tile = get("heads_per_tile", None) or heads_per_tile + vmem_limit_bytes = get("vmem_limit_bytes", None) or vmem_limit_bytes else: - bq = getattr(flash_block_sizes, "block_q", bq) - bkv = getattr(flash_block_sizes, "block_kv", bkv) - bkv_compute = getattr(flash_block_sizes, "block_kv_compute", bkv_compute) - bkv_compute_in = getattr(flash_block_sizes, "block_kv_compute_in", bkv_compute_in) - heads_per_tile = getattr(flash_block_sizes, "heads_per_tile", heads_per_tile) - vmem_limit_bytes = getattr(flash_block_sizes, "vmem_limit_bytes", vmem_limit_bytes) + bq = getattr(flash_block_sizes, "block_q", None) or bq + bkv = getattr(flash_block_sizes, "block_kv", None) or bkv + bkv_compute = getattr(flash_block_sizes, "block_kv_compute", None) or bkv_compute + bkv_compute_in = getattr(flash_block_sizes, "block_kv_compute_in", None) or bkv_compute_in + heads_per_tile = getattr(flash_block_sizes, "heads_per_tile", None) or heads_per_tile + vmem_limit_bytes = getattr(flash_block_sizes, "vmem_limit_bytes", None) or vmem_limit_bytes return bq, bkv, bkv_compute, bkv_compute_in, heads_per_tile, vmem_limit_bytes