fix: avoid CUDA crash from repetition_penalty in Fun-ASR-Nano vLLM prompt-embeds mode (#2948) #2974

Open
SuperMarioYL wants to merge 1 commit into modelscope:main from SuperMarioYL:fix/vllm-repetition-penalty-prompt-embeds

@SuperMarioYL

Summary

The Fun-ASR-Nano vLLM serving paths run the engine with enable_prompt_embeds=True, so each request carries precomputed audio/text embeddings and has no prompt token IDs. vLLM applies repetition_penalty by scattering over the prompt's token IDs, so any value other than 1.0 indexes an empty token-id tensor and aborts the engine with:

ScatterGatherKernel.cu:163: operator(): Assertion `idx_dim >= 0 && idx_dim < index_size && "scatter gather kernel index out of bounds"` failed.

This is the crash reported in #2948. The batch generate() entry already defaulted to repetition_penalty=1.0, but three other prompt-embeds paths still forwarded 1.3 and crashed:

  • funasr/bin/_server_app.py: the OpenAI-compatible / REST server passes repetition_penalty=1.3 straight into the vLLM engine.
  • funasr/models/fun_asr_nano/inference_vllm_pipeline.py: SamplingParams(..., repetition_penalty=1.3).
  • funasr/models/fun_asr_nano/inference_vllm_streaming.py: SamplingParams(..., repetition_penalty=1.3).

Fix

Centralize the rule in a small, dependency-free helper, fun_asr_nano/vllm_utils.resolve_repetition_penalty(), which forces the neutral value (1.0) whenever the request runs in prompt-embeds mode and warns once so the behavior is discoverable. Every Fun-ASR-Nano SamplingParams construction now routes through it, so the failure mode is impossible regardless of the call site. The pipeline and streaming paths additionally now honor a caller-supplied repetition_penalty (sanitized) instead of silently ignoring it.

This keeps all serving paths consistent with the already-safe batch generate() entry.

Type of change

  • Bug fix
  • Documentation
  • Example or demo
  • Runtime or deployment

Validation

  • python -m compileall funasr tests (the new/edited files compile; the only error is the pre-existing Triton-DSL file funasr/models/sense_voice/whisper_lib/triton_ops.py, unrelated to this change)
  • python -m unittest tests.test_fun_asr_nano_repetition_penalty — new tests cover: neutral pass-through, None handling, clamping 1.3 → 1.0 in prompt-embeds mode, pass-through for regular token prompts, and warn-once behavior. The helper is dependency-free, so the tests run without a GPU or vLLM.

Failure mode / reproduction

Before this change, starting the server and transcribing audio that reaches the vLLM engine:

python serve_vllm.py --model FunAudioLLM/Fun-ASR-Nano-2512 --gpu-memory-utilization 0.5 --port 8899

aborts the engine with the CUDA scatter assertion above (most visible on longer audio, which produces more decode steps). After this change the penalty is neutralized for prompt-embeds requests and inference completes.

User impact

Anyone deploying Fun-ASR-Nano via the OpenAI-compatible/REST server, the pipeline path, or the streaming path — including agent integrations that call the transcription endpoint — no longer hits the CUDA crash. The supported batch path is unaffected (it already used 1.0).

Notes for reviewers

  • No behavior change for callers that already used repetition_penalty=1.0 (the default).
  • The only callers affected are those that previously crashed, so there is no working path that regresses.
  • The helper is intentionally standard-library only so it imports and tests cleanly without a CUDA device.
  • If you prefer the penalty to remain configurable for a future non-embeds decode path, resolve_repetition_penalty(value, prompt_embeds=False) passes the value through unchanged.

…ompt-embeds mode (modelscope#2948)

The Fun-ASR-Nano vLLM serving paths run with enable_prompt_embeds=True, so
requests carry audio/text embeddings and have no prompt token IDs. vLLM applies
repetition_penalty by scattering over the prompt token IDs, so any value other
than 1.0 indexes an empty tensor and aborts the engine with a CUDA
"scatter gather index out of bounds" assertion (issue modelscope#2948).

The batch generate() entry already defaulted to 1.0, but the OpenAI/REST server
(_server_app), the pipeline path and the streaming path still hardcoded
repetition_penalty=1.3 and crashed. Centralize the rule in a dependency-free
helper (resolve_repetition_penalty) that forces the neutral value in
prompt-embeds mode and warns once, and route every Fun-ASR-Nano SamplingParams
through it. The pipeline and streaming paths now also honor a caller-supplied
repetition_penalty (sanitized) instead of ignoring it.

Adds tests/test_fun_asr_nano_repetition_penalty.py (no GPU or vLLM required).

@gemini-code-assist gemini-code-assist Bot left a comment

Code Review

This pull request addresses issue #2948 by introducing a helper function to safely resolve the repetition penalty in vLLM prompt-embeds mode, preventing CUDA kernel crashes. The helper is integrated across various inference paths, and corresponding unit tests are added. Feedback on these changes includes standardizing the logger initialization using __name__, moving inline imports to the top of the files to comply with PEP 8, and explicitly managing the logger level in tests to ensure robustness.


import logging

logger = logging.getLogger("funasr.fun_asr_nano.vllm")

medium

Using a hardcoded string for the logger name ("funasr.fun_asr_nano.vllm") bypasses Python's standard module-based logging hierarchy. It is highly recommended to use __name__ instead, which automatically aligns the logger with the module's package path (funasr.models.fun_asr_nano.vllm_utils) and makes logging configuration much more manageable and standard.

Suggested change
- logger = logging.getLogger("funasr.fun_asr_nano.vllm")
+ logger = logging.getLogger(__name__)
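To make the hierarchy argument concrete, here is a small sketch of how dotted logger names decide which configuration a logger inherits. It assumes an application that configures logging on the funasr.models.fun_asr_nano package; that configuration step is hypothetical and is not part of the PR. The two logger names are the ones quoted in the review comment.

```python
import logging

# Hypothetical application-side configuration on the package logger.
package_logger = logging.getLogger("funasr.models.fun_asr_nano")
package_logger.setLevel(logging.DEBUG)

# What logging.getLogger(__name__) resolves to inside vllm_utils.py.
module_logger = logging.getLogger("funasr.models.fun_asr_nano.vllm_utils")

# The hardcoded name from the diff.
hardcoded_logger = logging.getLogger("funasr.fun_asr_nano.vllm")

# The module-named logger is a child of the package logger and inherits its level.
print(module_logger.parent is package_logger)
print(module_logger.getEffectiveLevel() == logging.DEBUG)

# The hardcoded name sits outside that package branch, so the setting does not reach it.
print(hardcoded_logger.parent is package_logger)
```

Both names still share the top-level funasr ancestor, so configuration applied there reaches either logger. The difference only shows up for configuration scoped to a subpackage.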

Comment on lines +63 to +69
handler = _Collect(level=logging.WARNING)
vllm_utils.logger.addHandler(handler)
try:
    resolve_repetition_penalty(1.0, prompt_embeds=True)
    resolve_repetition_penalty(1.3, prompt_embeds=False)
finally:
    vllm_utils.logger.removeHandler(handler)

medium

To ensure the test is robust and independent of any external logging configuration (which might set the logger level to ERROR or higher and cause false positives), it is safer to explicitly set the logger's level to logging.WARNING (or lower) during the test execution and restore it afterward.

Suggested change
- handler = _Collect(level=logging.WARNING)
- vllm_utils.logger.addHandler(handler)
- try:
-     resolve_repetition_penalty(1.0, prompt_embeds=True)
-     resolve_repetition_penalty(1.3, prompt_embeds=False)
- finally:
-     vllm_utils.logger.removeHandler(handler)
+ handler = _Collect(level=logging.WARNING)
+ vllm_utils.logger.addHandler(handler)
+ original_level = vllm_utils.logger.level
+ vllm_utils.logger.setLevel(logging.WARNING)
+ try:
+     resolve_repetition_penalty(1.0, prompt_embeds=True)
+     resolve_repetition_penalty(1.3, prompt_embeds=False)
+ finally:
+     vllm_utils.logger.setLevel(original_level)
+     vllm_utils.logger.removeHandler(handler)
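The concern here is that a logger's own level is checked before any handler runs, so a test that asserts "no warning was collected" can pass simply because the logger was configured to drop warnings. The sketch below illustrates that gating with a stand-in handler and a hypothetical logger name (demo.level_gate). Neither is taken from the PR's test file.

```python
import logging


class Collect(logging.Handler):
    """Minimal stand-in handler that stores the records it receives."""

    def __init__(self, level=logging.NOTSET):
        super().__init__(level)
        self.records = []

    def emit(self, record):
        self.records.append(record)


def count_warnings(logger_level):
    """Emit one warning on a logger set to logger_level; return how many reach the handler."""
    log = logging.getLogger("demo.level_gate")  # hypothetical logger name
    handler = Collect(level=logging.WARNING)
    log.addHandler(handler)
    original_level = log.level
    log.setLevel(logger_level)
    try:
        log.warning("clamped repetition_penalty")
    finally:
        # Restore the logger so the demo leaves no state behind.
        log.setLevel(original_level)
        log.removeHandler(handler)
    return len(handler.records)


print(count_warnings(logging.ERROR))    # 0: the logger drops the record before any handler runs
print(count_warnings(logging.WARNING))  # 1: the record reaches the handler
```

Pinning the level inside the test, as the suggestion does, removes the dependence on whatever logging configuration happens to be active when the suite runs.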

except ImportError:
    from vllm.inputs.data import EmbedsPrompt

from funasr.models.fun_asr_nano.vllm_utils import resolve_repetition_penalty

medium

According to PEP 8, imports should always be placed at the top of the file. Since vllm_utils is a lightweight, dependency-free module, there is no risk of circular dependencies or heavy import overhead. Moving this import to the top level of the file improves code readability and avoids the minor overhead of re-importing it on every call to generate.

References
  1. PEP 8: Imports should always be placed at the top of the file, just after any module comments and docstrings, and before module globals and constants. (link)

except ImportError:
    from vllm.inputs.data import EmbedsPrompt

from funasr.models.fun_asr_nano.vllm_utils import resolve_repetition_penalty

medium

According to PEP 8, imports should always be placed at the top of the file. Since vllm_utils is a lightweight, dependency-free module, there is no risk of circular dependencies or heavy import overhead. Moving this import to the top level of the file improves code readability and avoids the minor overhead of re-importing it on every call to _process_one.

References
  1. PEP 8: Imports should always be placed at the top of the file, just after any module comments and docstrings, and before module globals and constants. (link)


- params = SamplingParams(max_tokens=max_new_tokens, temperature=temperature,
-     repetition_penalty=1.3, skip_special_tokens=True)
+ from funasr.models.fun_asr_nano.vllm_utils import resolve_repetition_penalty

medium

According to PEP 8, imports should always be placed at the top of the file. Since vllm_utils is a lightweight, dependency-free module, there is no risk of circular dependencies or heavy import overhead. Moving this import to the top level of the file improves code readability and avoids the minor overhead of re-importing it on every call to streaming_generate.

References
  1. PEP 8: Imports should always be placed at the top of the file, just after any module comments and docstrings, and before module globals and constants. (link)
