fix: avoid CUDA crash from repetition_penalty in Fun-ASR-Nano vLLM prompt-embeds mode (#2948) #2974
…ompt-embeds mode (modelscope#2948)

The Fun-ASR-Nano vLLM serving paths run with enable_prompt_embeds=True, so requests carry audio/text embeddings and have no prompt token IDs. vLLM applies repetition_penalty by scattering over the prompt token IDs, so any value other than 1.0 indexes an empty tensor and aborts the engine with a CUDA "scatter gather index out of bounds" assertion (issue modelscope#2948).

The batch generate() entry already defaulted to 1.0, but the OpenAI/REST server (_server_app), the pipeline path and the streaming path still hardcoded repetition_penalty=1.3 and crashed.

Centralize the rule in a dependency-free helper (resolve_repetition_penalty) that forces the neutral value in prompt-embeds mode and warns once, and route every Fun-ASR-Nano SamplingParams through it. The pipeline and streaming paths now also honor a caller-supplied repetition_penalty (sanitized) instead of ignoring it.

Adds tests/test_fun_asr_nano_repetition_penalty.py (no GPU or vLLM required).
Code Review
This pull request addresses issue #2948 by introducing a helper function to safely resolve the repetition penalty in vLLM prompt-embeds mode, preventing CUDA kernel crashes. The helper is integrated across various inference paths, and corresponding unit tests are added. Feedback on these changes includes standardizing the logger initialization using __name__, moving inline imports to the top of the files to comply with PEP 8, and explicitly managing the logger level in tests to ensure robustness.
import logging

logger = logging.getLogger("funasr.fun_asr_nano.vllm")
Using a hardcoded string for the logger name ("funasr.fun_asr_nano.vllm") bypasses Python's standard module-based logging hierarchy. It is highly recommended to use __name__ instead, which automatically aligns the logger with the module's package path (funasr.models.fun_asr_nano.vllm_utils) and makes logging configuration much more manageable and standard.
Suggested change:
- logger = logging.getLogger("funasr.fun_asr_nano.vllm")
+ logger = logging.getLogger(__name__)
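To make the hierarchy argument concrete, here is a small sketch using only the standard logging module. The two logger names are the ones quoted in the review; the levels and the "funasr.models" subtree configuration are illustrative assumptions, not settings taken from FunASR.

```python
import logging

# Configure only the package-level logger; nothing is set on the children.
package_logger = logging.getLogger("funasr")
package_logger.setLevel(logging.ERROR)

# What logging.getLogger(__name__) resolves to inside the module, using the
# package path quoted in the review comment.
module_logger = logging.getLogger("funasr.models.fun_asr_nano.vllm_utils")

# The hardcoded name from the diff: also under "funasr", but on a branch that
# does not mirror the package layout.
hardcoded_logger = logging.getLogger("funasr.fun_asr_nano.vllm")

# Both inherit their effective level from "funasr".
effective_levels = (
    module_logger.getEffectiveLevel(),
    hardcoded_logger.getEffectiveLevel(),
)

# Configuration aimed at the "funasr.models" subtree (an assumed, illustrative
# setting) reaches only the logger whose name follows the module path.
logging.getLogger("funasr.models").setLevel(logging.DEBUG)
module_level_after = module_logger.getEffectiveLevel()
hardcoded_level_after = hardcoded_logger.getEffectiveLevel()
```

With the hardcoded name, anyone tuning log levels per subpackage would have to know about the extra, non-package branch; with __name__ the logger simply follows the file's location.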
handler = _Collect(level=logging.WARNING)
vllm_utils.logger.addHandler(handler)
try:
    resolve_repetition_penalty(1.0, prompt_embeds=True)
    resolve_repetition_penalty(1.3, prompt_embeds=False)
finally:
    vllm_utils.logger.removeHandler(handler)
To ensure the test is robust and independent of any external logging configuration (which might set the logger level to ERROR or higher and cause false positives), it is safer to explicitly set the logger's level to logging.WARNING (or lower) during the test execution and restore it afterward.
Suggested change:
- handler = _Collect(level=logging.WARNING)
- vllm_utils.logger.addHandler(handler)
- try:
-     resolve_repetition_penalty(1.0, prompt_embeds=True)
-     resolve_repetition_penalty(1.3, prompt_embeds=False)
- finally:
-     vllm_utils.logger.removeHandler(handler)
+ handler = _Collect(level=logging.WARNING)
+ vllm_utils.logger.addHandler(handler)
+ original_level = vllm_utils.logger.level
+ vllm_utils.logger.setLevel(logging.WARNING)
+ try:
+     resolve_repetition_penalty(1.0, prompt_embeds=True)
+     resolve_repetition_penalty(1.3, prompt_embeds=False)
+ finally:
+     vllm_utils.logger.setLevel(original_level)
+     vllm_utils.logger.removeHandler(handler)
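The same save, set, and restore pattern can be packaged once so each test does not repeat it. The following is only a sketch: _ListHandler is a hypothetical stand-in for the PR's _Collect handler (whose definition is not shown here), and capture_warnings and the "demo.capture" logger name are invented for illustration.

```python
import logging
from contextlib import contextmanager


class _ListHandler(logging.Handler):
    """Collects emitted records in a list (hypothetical stand-in for _Collect)."""

    def __init__(self, level=logging.NOTSET):
        super().__init__(level=level)
        self.records = []

    def emit(self, record):
        self.records.append(record)


@contextmanager
def capture_warnings(logger):
    """Attach a collecting handler and pin the logger to WARNING for the block."""
    handler = _ListHandler(level=logging.WARNING)
    original_level = logger.level
    logger.addHandler(handler)
    logger.setLevel(logging.WARNING)
    try:
        yield handler.records
    finally:
        # Undo both changes so the logger is left exactly as it was found.
        logger.setLevel(original_level)
        logger.removeHandler(handler)


# Usage: an outer configuration that silences warnings no longer hides them.
demo_logger = logging.getLogger("demo.capture")
demo_logger.setLevel(logging.ERROR)
with capture_warnings(demo_logger) as records:
    demo_logger.warning("penalty ignored")
captured = [record.getMessage() for record in records]
```

On Python 3.10 and later, unittest.TestCase.assertNoLogs (alongside the older assertLogs) provides comparable isolation without a custom handler, which may be an alternative worth considering for the "no warning emitted" case.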
except ImportError:
    from vllm.inputs.data import EmbedsPrompt

from funasr.models.fun_asr_nano.vllm_utils import resolve_repetition_penalty
According to PEP 8, imports should always be placed at the top of the file. Since vllm_utils is a lightweight, dependency-free module, there is no risk of circular dependencies or heavy import overhead. Moving this import to the top level of the file improves code readability and avoids the minor overhead of re-importing it on every call to generate.
References
- PEP 8: Imports should always be placed at the top of the file, just after any module comments and docstrings, and before module globals and constants.
except ImportError:
    from vllm.inputs.data import EmbedsPrompt

from funasr.models.fun_asr_nano.vllm_utils import resolve_repetition_penalty
According to PEP 8, imports should always be placed at the top of the file. Since vllm_utils is a lightweight, dependency-free module, there is no risk of circular dependencies or heavy import overhead. Moving this import to the top level of the file improves code readability and avoids the minor overhead of re-importing it on every call to _process_one.
References
- PEP 8: Imports should always be placed at the top of the file, just after any module comments and docstrings, and before module globals and constants.
params = SamplingParams(max_tokens=max_new_tokens, temperature=temperature,
                        repetition_penalty=1.3, skip_special_tokens=True)
from funasr.models.fun_asr_nano.vllm_utils import resolve_repetition_penalty
According to PEP 8, imports should always be placed at the top of the file. Since vllm_utils is a lightweight, dependency-free module, there is no risk of circular dependencies or heavy import overhead. Moving this import to the top level of the file improves code readability and avoids the minor overhead of re-importing it on every call to streaming_generate.
References
- PEP 8: Imports should always be placed at the top of the file, just after any module comments and docstrings, and before module globals and constants.
Summary
The Fun-ASR-Nano vLLM serving paths run the engine with enable_prompt_embeds=True, so each request carries precomputed audio/text embeddings and has no prompt token IDs. vLLM applies repetition_penalty by scattering over the prompt's token IDs, so any value other than 1.0 indexes an empty token-id tensor and aborts the engine with a CUDA "scatter gather index out of bounds" assertion.

This is the crash reported in #2948. The batch generate() entry already defaulted to repetition_penalty=1.0, but three other prompt-embeds paths still forwarded 1.3 and crashed:
- funasr/bin/_server_app.py: the OpenAI-compatible / REST server passes repetition_penalty=1.3 straight into the vLLM engine.
- funasr/models/fun_asr_nano/inference_vllm_pipeline.py: SamplingParams(..., repetition_penalty=1.3).
- funasr/models/fun_asr_nano/inference_vllm_streaming.py: SamplingParams(..., repetition_penalty=1.3).

Fix
Centralize the rule in a small, dependency-free helper, fun_asr_nano/vllm_utils.resolve_repetition_penalty(), which forces the neutral value (1.0) whenever the request runs in prompt-embeds mode and warns once so the behavior is discoverable. Every Fun-ASR-Nano SamplingParams construction now routes through it, so the failure mode is impossible regardless of the call site. The pipeline and streaming paths additionally now honor a caller-supplied repetition_penalty (sanitized) instead of silently ignoring it.

This keeps all serving paths consistent with the already-safe batch generate() entry.

Type of change
Validation
- python -m compileall funasr tests (the new/edited files compile; the only error is the pre-existing Triton-DSL file funasr/models/sense_voice/whisper_lib/triton_ops.py, unrelated to this change)
- python -m unittest tests.test_fun_asr_nano_repetition_penalty: new tests cover neutral pass-through, None handling, clamping 1.3 → 1.0 in prompt-embeds mode, pass-through for regular token prompts, and warn-once behavior. The helper is dependency-free, so the tests run without a GPU or vLLM.

Failure mode / reproduction
Before this change, starting the server and transcribing audio that reaches the vLLM engine aborts the engine with the CUDA scatter assertion described in the summary (most visible on longer audio, which produces more decode steps). After this change the penalty is neutralized for prompt-embeds requests and inference completes.
User impact
Anyone deploying Fun-ASR-Nano via the OpenAI-compatible/REST server, the pipeline path, or the streaming path (including agent integrations that call the transcription endpoint) no longer hits the CUDA crash. The supported batch path is unaffected (it already used 1.0).

Notes for reviewers
- repetition_penalty=1.0 (the default).
- resolve_repetition_penalty(value, prompt_embeds=False) passes the value through unchanged.