Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
25 changes: 25 additions & 0 deletions examples/llm_server/cpp/worker_loop.h
Original file line number Diff line number Diff line change
Expand Up @@ -45,6 +45,8 @@
// "reused_prompt_tokens": int, "prefilled_prompt_tokens": int,
// "session_reset_reason": str
// (new|exact_prefix|mismatch|dirty|equal),
// "prefill_ms": float, "decode_ms": float, "total_ms": float,
// "prefill_tok_s": float, "decode_tok_s": float,
// "generated_token_ids"?: [int,...]} // omitted if stop-trimmed
// open/close/reset: {"opened"|"closed"|"reset": true, "session_id": str}
// error: {"error": str, "code"?: str} // capacity_exhausted |
Expand All @@ -59,6 +61,7 @@
#include <pytorch/tokenizers/tokenizer.h>

#include <algorithm>
#include <chrono>
#include <cstdint>
#include <iostream>
#include <iterator>
Expand Down Expand Up @@ -108,6 +111,7 @@ inline void worker_handle_request(
const std::unordered_map<std::string, int64_t>& metadata,
const nlohmann::json& req,
const std::vector<uint64_t>& prompt_prefix_ids = {}) {
const auto request_start = std::chrono::steady_clock::now();
LLMSession& session = *st.session;
int64_t max_new = req.value("max_new_tokens", static_cast<int64_t>(-1));
const float temperature = req.value("temperature", 0.0f);
Expand Down Expand Up @@ -203,6 +207,7 @@ inline void worker_handle_request(

SamplingConfig sampling;
sampling.temperature = temperature;
const auto prefill_start = std::chrono::steady_clock::now();
if (session.prefill_tokens(to_prefill, &sampling) !=
::executorch::runtime::Error::Ok) {
st.dirty = true; // state may be partially mutated; force a reset next time
Expand All @@ -212,6 +217,7 @@ inline void worker_handle_request(
// suffix, or the whole prompt). Keep the invariant
// resident.size()==position().
st.resident_token_ids = ids;
const auto decode_start = std::chrono::steady_clock::now();

std::string buf; // bytes not yet forming a complete UTF-8 prefix
std::string pending; // complete-UTF-8 text held back for stop-string matching
Expand Down Expand Up @@ -299,6 +305,25 @@ inline void worker_handle_request(
st.resident_token_ids.end() - num_generated,
st.resident_token_ids.end());
}
const auto request_end = std::chrono::steady_clock::now();
const double prefill_ms =
std::chrono::duration<double, std::milli>(decode_start - prefill_start)
.count();
Comment on lines +309 to +311
const double decode_ms =
std::chrono::duration<double, std::milli>(request_end - decode_start)
.count();
const double total_ms =
std::chrono::duration<double, std::milli>(request_end - request_start)
.count();
done["prefill_ms"] = prefill_ms;
done["decode_ms"] = decode_ms;
done["total_ms"] = total_ms;
done["prefill_tok_s"] = prefill_ms > 0.0
? (static_cast<double>(prefilled) * 1000.0 / prefill_ms)
: 0.0;
done["decode_tok_s"] = decode_ms > 0.0
? (static_cast<double>(num_generated) * 1000.0 / decode_ms)
: 0.0;
worker_emit(done);
}

Expand Down
25 changes: 25 additions & 0 deletions examples/llm_server/python/serving_chat.py
Original file line number Diff line number Diff line change
Expand Up @@ -184,6 +184,29 @@ def _extract_tools(self, req: ChatCompletionRequest, text: str):
text = parsed.normal_text
return None, self._visible_content(text)

@staticmethod
def _log_generation_stats(
session_id: Optional[str], stats: GenStats, finish: str
) -> None:
logger.info(
"llm_turn_stats session_id=%s reason=%s prompt_tokens=%d "
"reused_prompt_tokens=%d prefilled_prompt_tokens=%d "
"completion_tokens=%d prefill_ms=%.1f prefill_tok_s=%.1f "
"decode_ms=%.1f decode_tok_s=%.1f total_ms=%.1f finish=%s",
session_id or "<scratch>",
stats.session_reset_reason,
stats.prompt_tokens,
stats.reused_prompt_tokens,
stats.prefilled_prompt_tokens,
stats.completion_tokens,
stats.prefill_ms,
stats.prefill_tok_s,
stats.decode_ms,
stats.decode_tok_s,
stats.total_ms,
finish,
)

async def _clean(
self, stream: AsyncIterator[str], stops: list[str], on_stop=None
) -> AsyncIterator[str]:
Expand Down Expand Up @@ -480,6 +503,7 @@ async def _complete(
finish = self._finish_reason(
req, stats.completion_tokens, tool_calls, stopped, stats.finish_reason
)
self._log_generation_stats(req.session_id, stats, finish)
return ChatCompletionResponse(
model=self._model_id,
choices=[
Expand Down Expand Up @@ -623,6 +647,7 @@ def chunk(delta: DeltaMessage, finish=None) -> str:
stopped=stop_hit[0],
worker_finish=stats.finish_reason,
)
self._log_generation_stats(req.session_id, stats, finish)
yield chunk(DeltaMessage(), finish=finish)
if req.stream_options and req.stream_options.include_usage:
usage_chunk = ChatCompletionChunk(
Expand Down
10 changes: 10 additions & 0 deletions examples/llm_server/python/session_runtime.py
Original file line number Diff line number Diff line change
Expand Up @@ -71,6 +71,11 @@ class GenStats:
reused_prompt_tokens: int = 0
prefilled_prompt_tokens: int = 0
session_reset_reason: Optional[str] = None
prefill_ms: float = 0.0
decode_ms: float = 0.0
total_ms: float = 0.0
prefill_tok_s: float = 0.0
decode_tok_s: float = 0.0
# Exact token ids generated this turn, for an adapter's transcript
# store. Empty when the worker doesn't report them (e.g. a stop-trimmed turn).
generated_token_ids: list = field(default_factory=list)
Expand Down Expand Up @@ -115,6 +120,11 @@ def stats_cb(self, s) -> None:
self._stats.reused_prompt_tokens = getattr(s, "reused_prompt_tokens", 0)
self._stats.prefilled_prompt_tokens = getattr(s, "prefilled_prompt_tokens", 0)
self._stats.session_reset_reason = getattr(s, "session_reset_reason", None)
self._stats.prefill_ms = getattr(s, "prefill_ms", 0.0)
self._stats.decode_ms = getattr(s, "decode_ms", 0.0)
self._stats.total_ms = getattr(s, "total_ms", 0.0)
self._stats.prefill_tok_s = getattr(s, "prefill_tok_s", 0.0)
self._stats.decode_tok_s = getattr(s, "decode_tok_s", 0.0)
self._stats.generated_token_ids = getattr(s, "generated_token_ids", [])

def run(self) -> None:
Expand Down
10 changes: 10 additions & 0 deletions examples/llm_server/python/tests/test_session_runtime.py
Original file line number Diff line number Diff line change
Expand Up @@ -96,6 +96,11 @@ class S:
num_prompt_tokens = 3
num_generated_tokens = 2
finish_reason = "stop"
prefill_ms = 4.0
decode_ms = 5.0
total_ms = 10.0
prefill_tok_s = 750.0
decode_tok_s = 400.0
generated_token_ids = [10, 11]

stats_callback(S())
Expand All @@ -110,6 +115,11 @@ async def scenario():
assert "".join(out) == "Hello world"
assert stats.completion_tokens == 2
assert stats.finish_reason == "stop"
assert stats.prefill_ms == 4.0
assert stats.decode_ms == 5.0
assert stats.total_ms == 10.0
assert stats.prefill_tok_s == 750.0
assert stats.decode_tok_s == 400.0
assert stats.generated_token_ids == [10, 11]


Expand Down
10 changes: 10 additions & 0 deletions examples/llm_server/python/tests/test_worker_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -211,6 +211,11 @@ def test_generate_parses_warm_resume_metrics():
"reused_prompt_tokens": 90,
"prefilled_prompt_tokens": 10,
"session_reset_reason": "exact_prefix",
"prefill_ms": 12.5,
"decode_ms": 25.0,
"total_ms": 40.0,
"prefill_tok_s": 800.0,
"decode_tok_s": 40.0,
},
)
)
Expand All @@ -222,6 +227,11 @@ def test_generate_parses_warm_resume_metrics():
assert st.reused_prompt_tokens == 90
assert st.prefilled_prompt_tokens == 10
assert st.session_reset_reason == "exact_prefix"
assert st.prefill_ms == 12.5
assert st.decode_ms == 25.0
assert st.total_ms == 40.0
assert st.prefill_tok_s == 800.0
assert st.decode_tok_s == 40.0


def test_spawn_worker_waits_for_ready():
Expand Down
10 changes: 10 additions & 0 deletions examples/llm_server/python/worker_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -53,6 +53,11 @@ class WorkerStats:
reused_prompt_tokens: int = 0
prefilled_prompt_tokens: int = 0
session_reset_reason: Optional[str] = None
prefill_ms: float = 0.0
decode_ms: float = 0.0
total_ms: float = 0.0
prefill_tok_s: float = 0.0
decode_tok_s: float = 0.0
# The exact (non-terminal) token ids generated this turn. The control plane
# stores these per session and splices them back as an `ids` prompt segment
# next turn, so a prior assistant span is an exact token extension instead of
Expand Down Expand Up @@ -167,6 +172,11 @@ def _on_done(msg: dict, stats_callback) -> None:
reused_prompt_tokens=msg.get("reused_prompt_tokens", 0),
prefilled_prompt_tokens=msg.get("prefilled_prompt_tokens", 0),
session_reset_reason=reason,
prefill_ms=msg.get("prefill_ms", 0.0),
decode_ms=msg.get("decode_ms", 0.0),
total_ms=msg.get("total_ms", 0.0),
prefill_tok_s=msg.get("prefill_tok_s", 0.0),
decode_tok_s=msg.get("decode_tok_s", 0.0),
generated_token_ids=msg.get("generated_token_ids", []),
)
)
Expand Down
Loading