diff --git a/examples/llm_server/cpp/worker_loop.h b/examples/llm_server/cpp/worker_loop.h index b3a48a5f003..0f6536ab9ac 100644 --- a/examples/llm_server/cpp/worker_loop.h +++ b/examples/llm_server/cpp/worker_loop.h @@ -45,6 +45,8 @@ // "reused_prompt_tokens": int, "prefilled_prompt_tokens": int, // "session_reset_reason": str // (new|exact_prefix|mismatch|dirty|equal), +// "prefill_ms": float, "decode_ms": float, "total_ms": float, +// "prefill_tok_s": float, "decode_tok_s": float, // "generated_token_ids"?: [int,...]} // omitted if stop-trimmed // open/close/reset: {"opened"|"closed"|"reset": true, "session_id": str} // error: {"error": str, "code"?: str} // capacity_exhausted | @@ -59,6 +61,7 @@ #include #include +#include #include #include #include @@ -108,6 +111,7 @@ inline void worker_handle_request( const std::unordered_map& metadata, const nlohmann::json& req, const std::vector& prompt_prefix_ids = {}) { + const auto request_start = std::chrono::steady_clock::now(); LLMSession& session = *st.session; int64_t max_new = req.value("max_new_tokens", static_cast(-1)); const float temperature = req.value("temperature", 0.0f); @@ -203,6 +207,7 @@ inline void worker_handle_request( SamplingConfig sampling; sampling.temperature = temperature; + const auto prefill_start = std::chrono::steady_clock::now(); if (session.prefill_tokens(to_prefill, &sampling) != ::executorch::runtime::Error::Ok) { st.dirty = true; // state may be partially mutated; force a reset next time @@ -212,6 +217,7 @@ inline void worker_handle_request( // suffix, or the whole prompt). Keep the invariant // resident.size()==position(). st.resident_token_ids = ids; + const auto decode_start = std::chrono::steady_clock::now(); std::string buf; // bytes not yet forming a complete UTF-8 prefix std::string pending; // complete-UTF-8 text held back for stop-string matching @@ -299,6 +305,25 @@ inline void worker_handle_request( st.resident_token_ids.end() - num_generated, st.resident_token_ids.end()); } + const auto request_end = std::chrono::steady_clock::now(); + const double prefill_ms = + std::chrono::duration(decode_start - prefill_start) + .count(); + const double decode_ms = + std::chrono::duration(request_end - decode_start) + .count(); + const double total_ms = + std::chrono::duration(request_end - request_start) + .count(); + done["prefill_ms"] = prefill_ms; + done["decode_ms"] = decode_ms; + done["total_ms"] = total_ms; + done["prefill_tok_s"] = prefill_ms > 0.0 + ? (static_cast(prefilled) * 1000.0 / prefill_ms) + : 0.0; + done["decode_tok_s"] = decode_ms > 0.0 + ? (static_cast(num_generated) * 1000.0 / decode_ms) + : 0.0; worker_emit(done); } diff --git a/examples/llm_server/python/serving_chat.py b/examples/llm_server/python/serving_chat.py index 8172fc36ff7..28551f6bff3 100644 --- a/examples/llm_server/python/serving_chat.py +++ b/examples/llm_server/python/serving_chat.py @@ -184,6 +184,29 @@ def _extract_tools(self, req: ChatCompletionRequest, text: str): text = parsed.normal_text return None, self._visible_content(text) + @staticmethod + def _log_generation_stats( + session_id: Optional[str], stats: GenStats, finish: str + ) -> None: + logger.info( + "llm_turn_stats session_id=%s reason=%s prompt_tokens=%d " + "reused_prompt_tokens=%d prefilled_prompt_tokens=%d " + "completion_tokens=%d prefill_ms=%.1f prefill_tok_s=%.1f " + "decode_ms=%.1f decode_tok_s=%.1f total_ms=%.1f finish=%s", + session_id or "", + stats.session_reset_reason, + stats.prompt_tokens, + stats.reused_prompt_tokens, + stats.prefilled_prompt_tokens, + stats.completion_tokens, + stats.prefill_ms, + stats.prefill_tok_s, + stats.decode_ms, + stats.decode_tok_s, + stats.total_ms, + finish, + ) + async def _clean( self, stream: AsyncIterator[str], stops: list[str], on_stop=None ) -> AsyncIterator[str]: @@ -480,6 +503,7 @@ async def _complete( finish = self._finish_reason( req, stats.completion_tokens, tool_calls, stopped, stats.finish_reason ) + self._log_generation_stats(req.session_id, stats, finish) return ChatCompletionResponse( model=self._model_id, choices=[ @@ -623,6 +647,7 @@ def chunk(delta: DeltaMessage, finish=None) -> str: stopped=stop_hit[0], worker_finish=stats.finish_reason, ) + self._log_generation_stats(req.session_id, stats, finish) yield chunk(DeltaMessage(), finish=finish) if req.stream_options and req.stream_options.include_usage: usage_chunk = ChatCompletionChunk( diff --git a/examples/llm_server/python/session_runtime.py b/examples/llm_server/python/session_runtime.py index 1750b9fe777..b59e211a228 100644 --- a/examples/llm_server/python/session_runtime.py +++ b/examples/llm_server/python/session_runtime.py @@ -71,6 +71,11 @@ class GenStats: reused_prompt_tokens: int = 0 prefilled_prompt_tokens: int = 0 session_reset_reason: Optional[str] = None + prefill_ms: float = 0.0 + decode_ms: float = 0.0 + total_ms: float = 0.0 + prefill_tok_s: float = 0.0 + decode_tok_s: float = 0.0 # Exact token ids generated this turn, for an adapter's transcript # store. Empty when the worker doesn't report them (e.g. a stop-trimmed turn). generated_token_ids: list = field(default_factory=list) @@ -115,6 +120,11 @@ def stats_cb(self, s) -> None: self._stats.reused_prompt_tokens = getattr(s, "reused_prompt_tokens", 0) self._stats.prefilled_prompt_tokens = getattr(s, "prefilled_prompt_tokens", 0) self._stats.session_reset_reason = getattr(s, "session_reset_reason", None) + self._stats.prefill_ms = getattr(s, "prefill_ms", 0.0) + self._stats.decode_ms = getattr(s, "decode_ms", 0.0) + self._stats.total_ms = getattr(s, "total_ms", 0.0) + self._stats.prefill_tok_s = getattr(s, "prefill_tok_s", 0.0) + self._stats.decode_tok_s = getattr(s, "decode_tok_s", 0.0) self._stats.generated_token_ids = getattr(s, "generated_token_ids", []) def run(self) -> None: diff --git a/examples/llm_server/python/tests/test_session_runtime.py b/examples/llm_server/python/tests/test_session_runtime.py index f7f854d2f1f..a6fd4a74e84 100644 --- a/examples/llm_server/python/tests/test_session_runtime.py +++ b/examples/llm_server/python/tests/test_session_runtime.py @@ -96,6 +96,11 @@ class S: num_prompt_tokens = 3 num_generated_tokens = 2 finish_reason = "stop" + prefill_ms = 4.0 + decode_ms = 5.0 + total_ms = 10.0 + prefill_tok_s = 750.0 + decode_tok_s = 400.0 generated_token_ids = [10, 11] stats_callback(S()) @@ -110,6 +115,11 @@ async def scenario(): assert "".join(out) == "Hello world" assert stats.completion_tokens == 2 assert stats.finish_reason == "stop" + assert stats.prefill_ms == 4.0 + assert stats.decode_ms == 5.0 + assert stats.total_ms == 10.0 + assert stats.prefill_tok_s == 750.0 + assert stats.decode_tok_s == 400.0 assert stats.generated_token_ids == [10, 11] diff --git a/examples/llm_server/python/tests/test_worker_client.py b/examples/llm_server/python/tests/test_worker_client.py index 9591a3f9a23..1e4e9907311 100644 --- a/examples/llm_server/python/tests/test_worker_client.py +++ b/examples/llm_server/python/tests/test_worker_client.py @@ -211,6 +211,11 @@ def test_generate_parses_warm_resume_metrics(): "reused_prompt_tokens": 90, "prefilled_prompt_tokens": 10, "session_reset_reason": "exact_prefix", + "prefill_ms": 12.5, + "decode_ms": 25.0, + "total_ms": 40.0, + "prefill_tok_s": 800.0, + "decode_tok_s": 40.0, }, ) ) @@ -222,6 +227,11 @@ def test_generate_parses_warm_resume_metrics(): assert st.reused_prompt_tokens == 90 assert st.prefilled_prompt_tokens == 10 assert st.session_reset_reason == "exact_prefix" + assert st.prefill_ms == 12.5 + assert st.decode_ms == 25.0 + assert st.total_ms == 40.0 + assert st.prefill_tok_s == 800.0 + assert st.decode_tok_s == 40.0 def test_spawn_worker_waits_for_ready(): diff --git a/examples/llm_server/python/worker_client.py b/examples/llm_server/python/worker_client.py index 00e09b58ea4..341789afbf2 100644 --- a/examples/llm_server/python/worker_client.py +++ b/examples/llm_server/python/worker_client.py @@ -53,6 +53,11 @@ class WorkerStats: reused_prompt_tokens: int = 0 prefilled_prompt_tokens: int = 0 session_reset_reason: Optional[str] = None + prefill_ms: float = 0.0 + decode_ms: float = 0.0 + total_ms: float = 0.0 + prefill_tok_s: float = 0.0 + decode_tok_s: float = 0.0 # The exact (non-terminal) token ids generated this turn. The control plane # stores these per session and splices them back as an `ids` prompt segment # next turn, so a prior assistant span is an exact token extension instead of @@ -167,6 +172,11 @@ def _on_done(msg: dict, stats_callback) -> None: reused_prompt_tokens=msg.get("reused_prompt_tokens", 0), prefilled_prompt_tokens=msg.get("prefilled_prompt_tokens", 0), session_reset_reason=reason, + prefill_ms=msg.get("prefill_ms", 0.0), + decode_ms=msg.get("decode_ms", 0.0), + total_ms=msg.get("total_ms", 0.0), + prefill_tok_s=msg.get("prefill_tok_s", 0.0), + decode_tok_s=msg.get("decode_tok_s", 0.0), generated_token_ids=msg.get("generated_token_ids", []), ) )