diff --git a/src/maxdiffusion/models/attention_flax.py b/src/maxdiffusion/models/attention_flax.py index f6edc8309..76cd95c50 100644 --- a/src/maxdiffusion/models/attention_flax.py +++ b/src/maxdiffusion/models/attention_flax.py @@ -362,19 +362,19 @@ def _extract_custom_block_sizes(flash_block_sizes): if flash_block_sizes is not None: if isinstance(flash_block_sizes, dict): get = flash_block_sizes.get - bq = get("block_q", bq) - bkv = get("block_kv", bkv) - bkv_compute = get("block_kv_compute", bkv_compute) - bkv_compute_in = get("block_kv_compute_in", bkv_compute_in) - heads_per_tile = get("heads_per_tile", heads_per_tile) - vmem_limit_bytes = get("vmem_limit_bytes", vmem_limit_bytes) + bq = get("block_q", None) or bq + bkv = get("block_kv", None) or bkv + bkv_compute = get("block_kv_compute", None) or bkv_compute + bkv_compute_in = get("block_kv_compute_in", None) or bkv_compute_in + heads_per_tile = get("heads_per_tile", None) or heads_per_tile + vmem_limit_bytes = get("vmem_limit_bytes", None) or vmem_limit_bytes else: - bq = getattr(flash_block_sizes, "block_q", bq) - bkv = getattr(flash_block_sizes, "block_kv", bkv) - bkv_compute = getattr(flash_block_sizes, "block_kv_compute", bkv_compute) - bkv_compute_in = getattr(flash_block_sizes, "block_kv_compute_in", bkv_compute_in) - heads_per_tile = getattr(flash_block_sizes, "heads_per_tile", heads_per_tile) - vmem_limit_bytes = getattr(flash_block_sizes, "vmem_limit_bytes", vmem_limit_bytes) + bq = getattr(flash_block_sizes, "block_q", None) or bq + bkv = getattr(flash_block_sizes, "block_kv", None) or bkv + bkv_compute = getattr(flash_block_sizes, "block_kv_compute", None) or bkv_compute + bkv_compute_in = getattr(flash_block_sizes, "block_kv_compute_in", None) or bkv_compute_in + heads_per_tile = getattr(flash_block_sizes, "heads_per_tile", None) or heads_per_tile + vmem_limit_bytes = getattr(flash_block_sizes, "vmem_limit_bytes", None) or vmem_limit_bytes return bq, bkv, bkv_compute, bkv_compute_in, heads_per_tile, vmem_limit_bytes