Adding Chunked Prefill to NanoGPT
In the previous post, we built a continuous batching scheduler for NanoGPT. Requests can now arrive and leave the decode batch independently — the GPU stays busy, and no request waits for anyone else to finish. But there’s still a problem hiding in the prefill step.
When a new request arrives, we run its entire prompt through the model in one shot. For a 9-token Shakespeare prompt, this is fine. But what if the prompt is 2,000 tokens? Or 8,000? That single prefill forward pass monopolizes the GPU for the entire duration, and every request currently decoding has to wait. Their users stare at a frozen cursor. In production, this is called prefill starvation, and it’s one of the main reasons inference servers implement chunked prefill.
The fix is conceptually simple: instead of processing the whole prompt at once, break it into smaller chunks and interleave them with decode steps. Each scheduler iteration processes one chunk of the prefill and one decode step for everyone already active. The prompt still gets fully processed — it just takes a few more iterations, and nobody else gets starved in the meantime.
The token budget
The key mechanism is a token budget: a cap on how many tokens the scheduler can process per iteration. Two types of work compete for this budget:
- Decode tokens — 1 token per active request. These are cheap and high priority. A user is watching.
- Prefill tokens — whatever’s left of the budget goes to the next chunk of a partially-prefilled request.
If we have 3 active decode requests and a token budget of 10, that leaves 7 tokens of budget for prefill work. A 20-token prompt would get processed as three chunks: 7 + 7 + 6. The decode requests never skip a beat.
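The arithmetic above fits in a tiny helper. This is a sketch, not code from the post: `chunk_schedule` is a hypothetical name, and it assumes the number of decoding requests stays constant while the prompt is prefilled. In the real loop that count changes as requests finish or graduate.

```python
from typing import List


def chunk_schedule(prompt_len: int, token_budget: int, num_decoding: int) -> List[int]:
    """Chunk sizes for one prompt, assuming num_decoding stays fixed throughout."""
    per_step = token_budget - num_decoding  # one budget token reserved per decode request
    if per_step <= 0:
        raise ValueError("no budget left for prefill")
    chunks = []
    cursor = 0
    while cursor < prompt_len:
        size = min(per_step, prompt_len - cursor)
        chunks.append(size)
        cursor += size
    return chunks


print(chunk_schedule(20, token_budget=10, num_decoding=3))  # [7, 7, 6]
```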
Tracking partial prefill
The Request class from the continuous batching post needs one new field: prefill_cursor, an integer tracking how many prompt tokens have been processed so far. A request is “fully prefilled” when prefill_cursor == len(prompt_tokens).
The lifecycle also gains a new state: waiting → prefilling → active → done. The prefilling state means the prompt is partially processed — some chunks are in the KV cache, but more remain.
from dataclasses import dataclass, field
from typing import Dict, List, Tuple

import torch


@dataclass
class Request:
    """Each in-flight generation carries its own state and KV cache."""
    id: int
    prompt_tokens: List[int]  # the original encoded prompt
    max_new_tokens: int  # how many tokens this request wants
    generated_tokens: List[int] = field(default_factory=list)
    status: str = "waiting"  # "waiting" -> "prefilling" -> "active" -> "done"
    prefill_cursor: int = 0  # number of prompt tokens already in the KV cache
    # Per-request KV cache, keyed by (layer_idx, head_idx)
    # Each value is a (key_tensor, value_tensor) tuple of shape (1, T_i, head_size)
    # T_i grows by chunk_size per prefill chunk and by 1 per decode step,
    # so different requests have different T_i
    kv_cache: Dict[Tuple[int, int], Tuple[torch.Tensor, torch.Tensor]] = field(
        default_factory=dict
    )

    @property
    def tokens_so_far(self) -> List[int]:
        """Full sequence: prompt + everything generated."""
        return self.prompt_tokens + self.generated_tokens

    @property
    def num_generated(self) -> int:
        return len(self.generated_tokens)

    @property
    def is_done(self) -> bool:
        return self.num_generated >= self.max_new_tokens

    @property
    def is_fully_prefilled(self) -> bool:
        return self.prefill_cursor == len(self.prompt_tokens)

    def clear_cache(self):
        self.kv_cache.clear()
The scheduler advances prefill_cursor by the size of each chunk it processes, so the cursor always points at the first prompt token that is not yet in the KV cache.
The new is_fully_prefilled property wraps the prefill_cursor == len(prompt_tokens) check, so the scheduler can ask directly whether a request is ready to start decoding.
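The lifecycle waiting → prefilling → active → done can be made explicit as a transition table. This is a hypothetical helper (`ALLOWED` and `advance` are not in the post's code, which simply assigns status strings), walked through for a 20-token prompt prefilled in chunks of 8, 8 and 4.

```python
# Hypothetical transition table; the post's code assigns status strings directly.
ALLOWED = {
    "waiting": {"prefilling"},
    "prefilling": {"prefilling", "active"},  # stays prefilling until the last chunk is cached
    "active": {"active", "done"},
    "done": set(),
}


def advance(status: str, new_status: str) -> str:
    """Return new_status if the move is legal, otherwise raise ValueError."""
    if new_status not in ALLOWED[status]:
        raise ValueError(f"illegal transition {status!r} -> {new_status!r}")
    return new_status


prompt_len = 20
status = advance("waiting", "prefilling")
prefill_cursor = 0
history = [status]
for chunk_size in (8, 8, 4):
    prefill_cursor += chunk_size
    # Same test as is_fully_prefilled: prefill_cursor == len(prompt_tokens)
    status = advance(status, "active" if prefill_cursor == prompt_len else "prefilling")
    history.append(status)
status = advance(status, "done")
history.append(status)
print(history)  # ['prefilling', 'prefilling', 'prefilling', 'active', 'done']
```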
Continuous Batching Loop
This is where most of the changes are made.
Before, we had an inner while loop that prefilled each new request's entire prompt in one go. We get rid of it entirely. Instead, we introduce a new list called prefilling_requests that holds the requests whose prompts are still being prefilled.
On every iteration of the outer while loop, we first compute the remaining budget for prefill: the token budget minus the number of active requests, since each active request will consume exactly one decode token this iteration.
If there is budget left and a request is waiting to be prefilled, we take the first one in prefilling_requests and prefill one chunk of it.
The chunk is defined by two numbers: prefill_cursor, where the chunk starts, and chunk_size, the smaller of the remaining budget and the number of prompt tokens not yet processed. Together they give the slice prompt_tokens[chunk_start : chunk_start + chunk_size].
After slicing, we advance prefill_cursor by chunk_size so that the next iteration picks up exactly where this one left off.
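Here is the slice-and-advance step in isolation, as a sketch: `prompt_tokens` is a made-up stand-in for an encoded 20-token prompt, and `num_active` is held at 0 so the whole budget of 8 goes to prefill. Concatenating the chunks must reproduce the prompt exactly.

```python
prompt_tokens = list(range(100, 120))  # stand-in for a 20-token encoded prompt
token_budget, num_active = 8, 0

prefill_cursor = 0
chunks = []
while prefill_cursor < len(prompt_tokens):
    remaining_budget = token_budget - num_active
    chunk_size = min(remaining_budget, len(prompt_tokens) - prefill_cursor)
    chunk_start = prefill_cursor
    chunks.append(prompt_tokens[chunk_start : chunk_start + chunk_size])
    prefill_cursor += chunk_size  # without this, the same slice would be prefilled forever

print([len(c) for c in chunks])  # [8, 8, 4]
```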
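If there is no prefill chunk to run and no active (decode) requests either, there is nothing to compute this iteration, so we advance the step counter and continue to the next iteration.
Otherwise, if there is a chunk, we build its absolute position indices with torch.arange(chunk_start, chunk_start + chunk_size).
Next, if earlier chunks have already filled part of the KV cache, we gather it into a per-layer list of per-head (key, value) pairs (refer to my KV Cache implementation article for more details on this) and run the model forward pass on the chunk, which returns the new logits and the updated KV cache. We take the logits at the last position and sample the next token with torch.multinomial; that sample is only kept if this was the prompt's final chunk.
We store the new KV cache on the request object and then check whether the request is fully prefilled. If it is, we append the sampled token, mark the request as active, move it to active_requests, and remove it from prefilling_requests.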
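The gather and store steps convert between the request's flat dict and the nested per-layer, per-head list the model takes. Below is a small sketch with strings standing in for tensors. `n_layer = 2` and `n_head = 3` are made-up sizes, and the `past_kvs[layer][head] = (k, v)` layout is read off the loops in the code rather than from a documented API.

```python
# Hypothetical sizes; in the post these are the model's n_layer and n_head globals.
n_layer, n_head = 2, 3

# Per-request cache keyed by (layer_idx, head_idx); strings stand in for (k, v) tensors.
kv_cache = {(l, h): (f"k{l}{h}", f"v{l}{h}") for l in range(n_layer) for h in range(n_head)}

# Gather: dict -> past_kvs[layer][head] = (k, v), the nested layout passed to the model
past_kvs = [[kv_cache[(l, h)] for h in range(n_head)] for l in range(n_layer)]

# Store: the model returns the same nested layout, which goes back into the dict
new_cache = {}
for li, block_kv in enumerate(past_kvs):
    for hi, (k, v) in enumerate(block_kv):
        new_cache[(li, hi)] = (k, v)

print(past_kvs[1][2])  # ('k12', 'v12')
```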
The rest of the decode logic is exactly the same as in the continuous batching post, since chunked prefill only changes how prompts get into the KV cache.
Finally, at the end of every iteration, we increment step, which moves the simulated clock to the next decode step.
That’s it!
Here is the code for the entire continuous_batching_generate function:
import torch
import torch.nn.functional as F


def continuous_batching_generate(model, request_queue, max_batch_size=4, token_budget=None):
    """
    Continuous batching with chunked prefill.

    Decode requests have the highest priority: each iteration reserves one token of
    token_budget per active request, and whatever is left goes to the next prefill chunk.
    Relies on the globals device, n_layer, n_head and the helpers assemble_batch_cache /
    disassemble_batch_cache from the continuous batching post. max_batch_size is
    accepted but not enforced in this version.
    """
    model.eval()
    if token_budget is None:
        token_budget = float("inf")  # no budget: each prompt is prefilled in a single chunk
    active_requests = []
    completed_requests = []
    prefilling_requests = []
    step = 0  # simulating clock in a real time server
    queue_idx = 0  # index into the queue instead of popping from the front

    with torch.no_grad():
        # prefilling_requests must be part of the condition, otherwise a request that is
        # still mid-prefill when the queue runs out is silently dropped
        while active_requests or prefilling_requests or queue_idx < len(request_queue):
            # Admit the next request once the previous one has finished prefilling.
            # The arrival step is not checked in this version.
            if not prefilling_requests and queue_idx < len(request_queue):
                _arrival_step, new_req = request_queue[queue_idx]
                new_req.status = "prefilling"
                prefilling_requests.append(new_req)
                queue_idx += 1

            # Budget: one token per active decode request, the rest goes to prefill
            prefill_chunk_tokens = None
            remaining_budget = token_budget - len(active_requests)
            if remaining_budget > 0 and prefilling_requests:
                p_req = prefilling_requests[0]
                tokens_left = len(p_req.prompt_tokens) - p_req.prefill_cursor
                chunk_size = min(remaining_budget, tokens_left)
                chunk_start = p_req.prefill_cursor
                chunk_tokens = p_req.prompt_tokens[chunk_start : chunk_start + chunk_size]
                prefill_chunk_tokens = torch.tensor([chunk_tokens], dtype=torch.long, device=device)
                p_req.prefill_cursor += chunk_size

            # Nothing to prefill and nothing to decode: just advance the clock.
            # (Compare against None: calling `not` on a multi-element tensor raises.)
            if prefill_chunk_tokens is None and not active_requests:
                step += 1
                continue

            if prefill_chunk_tokens is not None:
                # Absolute positions of this chunk's tokens within the prompt
                pos = torch.arange(chunk_start, chunk_start + chunk_size, device=device).unsqueeze(0)
                if p_req.kv_cache:
                    # Later chunks attend over the cache written by earlier chunks
                    past_kvs = []
                    for layer_idx in range(n_layer):
                        block_kv = [p_req.kv_cache[(layer_idx, hi)] for hi in range(n_head)]
                        past_kvs.append(block_kv)
                    logits, _, new_kvs = model(prefill_chunk_tokens, past_kvs, pos=pos)
                else:
                    logits, _, new_kvs = model(prefill_chunk_tokens, pos=pos)
                for li, bkv in enumerate(new_kvs):
                    for hi, (k, v) in enumerate(bkv):
                        p_req.kv_cache[(li, hi)] = (k, v)
                logits = logits[:, -1, :]
                probs = F.softmax(logits, dim=-1)
                idx_next = torch.multinomial(probs, num_samples=1)
                # The sample is only used once the whole prompt is in the cache
                if p_req.is_fully_prefilled:
                    p_req.generated_tokens.append(idx_next.item())
                    p_req.status = "active"
                    p_req._last_token = idx_next
                    active_requests.append(p_req)
                    prefilling_requests.pop(0)

            # Decode: unchanged from the continuous batching post
            if active_requests:
                batch_tokens = torch.cat([req._last_token for req in active_requests])
                batch_positions = torch.tensor(
                    [[len(req.tokens_so_far) - 1] for req in active_requests], device=device
                )
                past_kvs, attn_mask, pad_lengths = assemble_batch_cache(active_requests)
                logits, _, new_kvs = model(
                    batch_tokens,
                    pos=batch_positions,
                    past_kvs=past_kvs,
                    attn_mask=attn_mask,
                )
                logits = logits[:, -1, :]
                probs = F.softmax(logits, dim=-1)
                idx_next = torch.multinomial(probs, num_samples=1)
                disassemble_batch_cache(active_requests, new_kvs, pad_lengths)
                for i, req in enumerate(active_requests):
                    req.generated_tokens.append(idx_next[i].item())
                    req._last_token = idx_next[i : i + 1]

            still_active = []
            for req in active_requests:
                if req.is_done:
                    req.status = "done"
                    completed_requests.append(req)
                    print(f" [step {step}] Completed request {req.id} "
                          f"({req.num_generated} tokens)")
                else:
                    still_active.append(req)
            active_requests = still_active
            step += 1

    return completed_requests
Here's an overview of the entire loop, step by step:
- Initialize: active_requests = [], prefilling_requests = [], completed_requests = [], step = 0, queue_idx = 0.
- Loop while active_requests is non-empty, prefilling_requests is non-empty, or queue_idx < len(queue). When all three are false, return completed_requests.
- Admit: if nothing is prefilling and the queue has more requests, dequeue the next one, set its status to prefilling, and add it to prefilling_requests.
- Budget: remaining = token_budget - len(active_requests).
- Chunk: if remaining > 0 and prefilling_requests is not empty, compute chunk_size = min(remaining, tokens_left), slice prompt[cursor : cursor + chunk_size], and advance prefill_cursor += chunk_size. If there is no chunk and no active request, do step += 1 and go back to the loop condition.
- Prefill forward pass: build positions with arange(chunk_start, chunk_start + chunk_size), attach the past KV cache if it exists, call model(chunk_tokens, past_kvs, pos) to get logits and new_kvs, store the new KV cache on the request object, and sample the next token from the last position's logits.
- Graduation: if is_fully_prefilled (cursor == len(prompt)), append the generated token, set status to active, move the request to active_requests, and remove it from prefilling_requests.
- Decode forward pass (only if active_requests is not empty): concatenate the last tokens of all active requests, build positions and attention masks with assemble_batch_cache(), call model(batch_tokens, past_kvs, pos, attn_mask), store the per-request KV slices back with disassemble_batch_cache(), and append each request's sampled token to generated_tokens.
- Completion check: every active request with req.is_done (generated ≥ max_new_tokens) gets status done and moves to completed_requests; the rest stay in active_requests.
- step += 1, then back to the loop condition.
Let’s walk through the key changes from the continuous batching version:
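Budget math. remaining_budget = token_budget - len(active_requests) reserves one token for each active decode request, then gives the rest to prefill. This is the core scheduling policy: decode always gets served first. One subtlety: a request that graduates in an iteration joins that same iteration's decode pass, so the number of tokens processed in that iteration can exceed token_budget by one.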
Chunking. chunk_size = min(remaining_budget, tokens_left) takes as much of the prompt as the budget allows. If the prompt has 20 tokens left and the budget allows 8, we process 8 now and come back for the rest next iteration. The prefill_cursor advances by chunk_size each time.
Position tracking. torch.arange(chunk_start, chunk_start + chunk_size) gives each token in the chunk its correct absolute position. For a 20-token prompt with a budget of 8 and no active decodes, the first chunk uses positions 0-7, the second uses 8-15, and the third uses 16-19. The model's position embedding table sees the right index regardless of which chunk the token arrived in.
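Here is a quick check of this bookkeeping in plain Python (`range(a, b)` yields the same integers as `torch.arange(a, b)`), using the 20-token, budget-8 example. It also shows how prefill positions line up with the decode loop, which feeds each request's last token at `len(tokens_so_far) - 1`.

```python
prompt_len, token_budget = 20, 8  # no active decodes, so the full budget goes to prefill

positions = []
chunk_starts = []
chunk_start = 0
while chunk_start < prompt_len:
    chunk_size = min(token_budget, prompt_len - chunk_start)
    chunk_starts.append(chunk_start)
    # range(a, b) yields the same integers as torch.arange(a, b)
    positions.extend(range(chunk_start, chunk_start + chunk_size))
    chunk_start += chunk_size

# After graduation, tokens_so_far holds the prompt plus one sampled token, and the
# decode loop feeds that token at position len(tokens_so_far) - 1.
first_decode_position = (prompt_len + 1) - 1

print(chunk_starts)           # [0, 8, 16]
print(first_decode_position)  # 20
```

The prefill positions cover 0-19 with no gaps or repeats, and the first decode token continues at 20.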
Cache accumulation. On the second and third chunks, the KV cache from earlier chunks already exists on the request. We pass it as past_kvs so the model’s Head concatenates the new keys/values onto the existing cache rather than starting fresh.
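Why does passing past_kvs make chunking exact? Each query in a later chunk sees the same keys it would have seen in a one-shot prefill. The following minimal single-head sketch in plain Python compares one-shot causal attention with chunked attention over a growing cache. Its assumptions: made-up dimensions, Q/K/V drawn at random, no learned projections. With real GPU kernels the two can differ slightly through floating-point rounding, since the reductions may be ordered differently.

```python
import math
import random


def attend(q, keys, values):
    """Scaled dot-product softmax attention of one query over lists of keys/values."""
    d = len(q)
    scores = [sum(qi * ki for qi, ki in zip(q, k)) / math.sqrt(d) for k in keys]
    m = max(scores)  # subtract the max for numerical stability
    weights = [math.exp(s - m) for s in scores]
    total = sum(weights)
    return [sum(w * v[j] for w, v in zip(weights, values)) / total for j in range(len(values[0]))]


random.seed(0)
T, d = 10, 4
Q = [[random.gauss(0, 1) for _ in range(d)] for _ in range(T)]
K = [[random.gauss(0, 1) for _ in range(d)] for _ in range(T)]
V = [[random.gauss(0, 1) for _ in range(d)] for _ in range(T)]

# One-shot causal prefill: position t attends to keys 0..t
full = [attend(Q[t], K[: t + 1], V[: t + 1]) for t in range(T)]

# Chunked prefill: append each chunk's keys/values to a growing cache
cache_k, cache_v, chunked = [], [], []
for start in range(0, T, 4):  # chunks of 4, 4 and 2 tokens
    end = min(start + 4, T)
    cache_k.extend(K[start:end])
    cache_v.extend(V[start:end])
    for t in range(start, end):
        # earlier chunks from the cache, plus a causal prefix of the current chunk
        chunked.append(attend(Q[t], cache_k[: t + 1], cache_v[: t + 1]))

max_diff = max(abs(a - b) for r1, r2 in zip(full, chunked) for a, b in zip(r1, r2))
print(f"max difference: {max_diff:.2e}")
```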
Graduation. When is_fully_prefilled becomes true, the request transitions from prefilling to active, sampling its first generated token and joining the decode batch. The decode logic below is unchanged from the continuous batching post.
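To check the scheduling policy without a model, here is a dry run of the same control flow. It is a sketch: `SimRequest` and `simulate` are hypothetical names, and the code tracks only lengths and counters (no tokens, logits or caches). It copies two properties of the code above: the arrival step in the queue is ignored, and a request that graduates joins the decode pass in the same iteration.

```python
from dataclasses import dataclass
from typing import Dict, List, Tuple


@dataclass
class SimRequest:
    """Hypothetical stand-in for Request: only lengths and counters."""
    id: int
    prompt_len: int
    max_new_tokens: int
    prefill_cursor: int = 0
    num_generated: int = 0


def simulate(requests: List[SimRequest], token_budget: int) -> Tuple[Dict[int, int], List[Tuple[int, int, int]]]:
    """Replay the scheduler loop with no model.

    Returns {request id: completion step} and a per-iteration log of
    (step, prefill chunk size, number of decode rows).
    """
    queue = list(requests)
    prefilling: List[SimRequest] = []
    active: List[SimRequest] = []
    completed: Dict[int, int] = {}
    log: List[Tuple[int, int, int]] = []
    step = 0
    while active or prefilling or queue:
        # Admit the next request once nothing is prefilling (arrival steps ignored)
        if not prefilling and queue:
            prefilling.append(queue.pop(0))
        chunk = 0
        remaining = token_budget - len(active)  # one token reserved per decode row
        if remaining > 0 and prefilling:
            req = prefilling[0]
            chunk = min(remaining, req.prompt_len - req.prefill_cursor)
            req.prefill_cursor += chunk
            if req.prefill_cursor == req.prompt_len:
                req.num_generated += 1  # first token sampled from the last chunk's logits
                active.append(prefilling.pop(0))
        if chunk == 0 and not active:
            step += 1  # idle iteration
            continue
        log.append((step, chunk, len(active)))
        for r in active:  # one batched decode pass: one token per active request
            r.num_generated += 1
        for r in active:
            if r.num_generated >= r.max_new_tokens:
                completed[r.id] = step
        active = [r for r in active if r.num_generated < r.max_new_tokens]
        step += 1
    return completed, log


# The Test 4 workload from this post: budget 10, prompts of 5, 17, 26 and 4 tokens
done, _ = simulate([SimRequest(0, 5, 10), SimRequest(1, 17, 8),
                    SimRequest(2, 26, 6), SimRequest(3, 4, 12)], token_budget=10)
print(done)  # {0: 8, 1: 8, 2: 10, 3: 17}
```

These are the same completion steps that the Test 4 log in this post reports, which is a cheap way to sanity-check scheduler changes before running the model.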
Testing
Test 1: regression. The first test verifies that the new chunked prefill code path doesn’t break anything when chunking isn’t actually needed. We feed three short prompts (9-12 tokens each) with a generous token budget of 32, which is larger than any prompt. Since no prompt exceeds the budget, every request gets prefilled in a single pass — identical to the old non-chunked behavior. The assertions check that each request completed, generated the correct number of tokens, that the KV cache has the expected sequence length, and that prefill_cursor advanced to the end of the prompt. If any of these fail, the chunked prefill refactoring broke the basic continuous batching logic.
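The expected cache length used in these assertions, prompt length + generated tokens - 1, follows from counting writes. This is a hypothetical helper, not part of the post's code, checked against Test 1's prompts (the model is character-level, so one character is one token).

```python
def expected_cache_len(prompt_len: int, num_generated: int) -> int:
    """KV entries per (layer, head) once a request finishes.

    Prefill writes one entry per prompt token. The first generated token comes
    from the prefill logits, so there are num_generated - 1 decode passes, each
    feeding the previously sampled token and writing one entry.
    """
    return prompt_len + num_generated - 1


# Test 1's prompts and max_new_tokens
test1 = [("O Romeo, ", 17), ("To be or ", 22), ("KING HENRY:\n", 15)]
cache_lens = [expected_cache_len(len(prompt), n) for prompt, n in test1]
print(cache_lens)  # [25, 30, 26]
```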
# ══════════════════════════════════════════════════════════════════════════════
# Test 1: Short prompts (no chunking needed) — regression test
# ══════════════════════════════════════════════════════════════════════════════
print("=" * 60)
print("Test 1: Short prompts (no chunking needed)")
print("=" * 60)
request_queue = [
(0, Request(id=0, prompt_tokens=encode("O Romeo, "), max_new_tokens=17)),
(0, Request(id=1, prompt_tokens=encode("To be or "), max_new_tokens=22)),
(3, Request(id=2, prompt_tokens=encode("KING HENRY:\n"), max_new_tokens=15)),
]
# token_budget=32 is larger than any prompt, so no chunking occurs
completed = continuous_batching_generate(model, request_queue, max_batch_size=4, token_budget=32)
for req in sorted(completed, key=lambda r: r.id):
    print(f"\n{'─'*40}")
    print(f"Request {req.id} | {req.num_generated} tokens | status: {req.status}")
    print(f"{'─'*40}")
    print(decode(req.tokens_so_far))
for req in completed:
    k, _ = req.kv_cache[(0, 0)]
    expected_T = len(req.prompt_tokens) + req.num_generated - 1
    assert k.shape[1] == expected_T, f"Req {req.id}: cache T={k.shape[1]}, expected {expected_T}"
    assert req.status == "done"
    assert req.num_generated == req.max_new_tokens
    assert req.prefill_cursor == len(req.prompt_tokens), (
        f"Req {req.id}: prefill_cursor={req.prefill_cursor}, expected {len(req.prompt_tokens)}")
print("\n✓ Test 1 passed — short prompts complete correctly (no chunking needed)!")
The result is:
============================================================
Test 1: Short prompts (no chunking needed)
============================================================
[step 15] Completed request 0 (17 tokens)
[step 15] Completed request 2 (15 tokens)
[step 21] Completed request 1 (22 tokens)
────────────────────────────────────────
Request 0 | 17 tokens | status: done
────────────────────────────────────────
O Romeo, there:
I would we
────────────────────────────────────────
Request 1 | 22 tokens | status: done
────────────────────────────────────────
To be or steal firiever still b
────────────────────────────────────────
Request 2 | 15 tokens | status: done
────────────────────────────────────────
KING HENRY:
What of the pet
✓ Test 1 passed — short prompts complete correctly (no chunking needed)!
All three requests completed with the correct token counts (17, 22, 15) and correct cache shapes, and the prefill_cursor advanced to the full prompt length for each — even though we never actually chunked anything. This matters because chunked prefill touches the same code paths as regular prefill (the prefill_cursor logic, the is_fully_prefilled check, the budget calculation). If we introduced an off-by-one in the cursor or broke the transition from "prefilling" → "active", this test would catch it immediately. It’s the safety net that lets us refactor with confidence.
Test 2: actual chunking. This test exercises the chunking itself. We feed a single 20-token prompt with a token budget of only 8, forcing the scheduler to split the prefill into three chunks: 8 + 8 + 4 tokens across three separate scheduler iterations. The assertions verify that the request completed with the right number of generated tokens, that prefill_cursor reached the full prompt length (confirming all three chunks were processed), and that the KV cache shape matches prompt_length + generated_tokens - 1 — meaning the chunks were stitched together correctly without any gaps or double-counted positions.
# ══════════════════════════════════════════════════════════════════════════════
# Test 2: Long prompt that requires chunking
# 20-token prompt, token_budget=8 → should take 3 prefill steps (8+8+4)
# ══════════════════════════════════════════════════════════════════════════════
print("=" * 60)
print("Test 2: Long prompt that requires chunking")
print("=" * 60)
# Build a prompt of exactly 20 tokens (char-level model: 1 char = 1 token)
long_text = "KING RICHARD THE THI" # 20 characters = 20 tokens
long_prompt = encode(long_text)
assert len(long_prompt) == 20, f"Expected 20 tokens, got {len(long_prompt)}"
print(f"Prompt length: {len(long_prompt)} tokens")
print(f"Prompt text: '{long_text}'")
print(f"Token budget: 8")
print(f"Expected prefill chunks: 8 + 8 + 4 = 3 steps")
print()
request_queue = [
(0, Request(id=0, prompt_tokens=long_prompt, max_new_tokens=5)),
]
completed = continuous_batching_generate(model, request_queue, max_batch_size=4, token_budget=8)
req = completed[0]
print(f"\nStatus: {req.status}")
print(f"Generated {req.num_generated} tokens")
print(f"Prefill cursor: {req.prefill_cursor} / {len(req.prompt_tokens)}")
print(f"Output: {decode(req.tokens_so_far)}")
# Verify cache shape
k, _ = req.kv_cache[(0, 0)]
expected_T = len(req.prompt_tokens) + req.num_generated - 1
assert k.shape[1] == expected_T, f"Cache T={k.shape[1]}, expected {expected_T}"
assert req.status == "done"
assert req.prefill_cursor == len(req.prompt_tokens), "Prefill cursor should equal prompt length"
assert req.num_generated == req.max_new_tokens
print("\n✓ Test 2 passed — 20-token prompt was chunked (8+8+4) and completed correctly!")
The result is:
============================================================
Test 2: Long prompt that requires chunking
============================================================
Prompt length: 20 tokens
Prompt text: 'KING RICHARD THE THI'
Token budget: 8
Expected prefill chunks: 8 + 8 + 4 = 3 steps
Status: done
Generated 5 tokens
Prefill cursor: 20 / 20
Output: KING RICHARD THE THI
✓ Test 2 passed — 20-token prompt was chunked (8+8+4) and completed correctly!
The key number is "Prefill cursor: 20 / 20". The cursor consumed all 20 prompt tokens across three separate forward passes, and the KV cache ended up at the right shape. This is the first proof that the chunking math actually works: the chunk_start and chunk_size slicing correctly picks up where the previous chunk left off, the position indices from torch.arange(chunk_start, chunk_start + chunk_size) give each token its correct absolute position, and the KV cache accumulates entries from each chunk without gaps or overlaps. If any chunk boundary were off by one, the cache shape assertion would fail.
Test 3: decode latency during chunked prefill. This test checks the whole point of chunked prefill: that decode requests don't stall while a long prompt is being prefilled. Request 0 has a short 5-token prompt and starts decoding immediately at step 0. Request 1 arrives at step 1 with a 20-token prompt that needs multiple prefill chunks at a budget of 8. If the scheduler is working correctly, request 0 should keep getting a decode forward pass every single step, even while request 1's prefill is still in progress. The assertions only check that both requests finish with the correct token counts and cache shapes; the evidence that decode was never skipped is the step at which request 0 completes.
# ══════════════════════════════════════════════════════════════════════════════
# Test 3: Verify decode latency isn't spiked
# An active decode request should get a forward pass EVERY step,
# even while a large prefill is in progress.
# ══════════════════════════════════════════════════════════════════════════════
print("=" * 60)
print("Test 3: Verify decode latency isn't spiked")
print("=" * 60)
# Short prompt starts decoding immediately; long prompt arrives at step 1
long_text_2 = "Now is the winter of" # 20 characters = 20 tokens
long_prompt_2 = encode(long_text_2)
print(f"Request 0: short prompt (5 tokens), decoding from step 0")
print(f"Request 1: long prompt ({len(long_prompt_2)} tokens), arrives step 1, budget=8")
print()
request_queue = [
(0, Request(id=0, prompt_tokens=encode("Hello"), max_new_tokens=15)),
(1, Request(id=1, prompt_tokens=long_prompt_2, max_new_tokens=5)),
]
import time  # for wall-clock timing

# We can't easily instrument inside the function, so we measure total time
# and verify correct interleaving via step logs + assertions.
t0 = time.perf_counter()
completed = continuous_batching_generate(model, request_queue, max_batch_size=4, token_budget=8)
total_time = time.perf_counter() - t0
print(f"\nTotal wall time: {total_time:.4f}s")
print(f"Total tokens generated: {sum(r.num_generated for r in completed)}")
for req in sorted(completed, key=lambda r: r.id):
    print(f" Request {req.id}: {req.num_generated} tokens, status={req.status}")
# Key assertion: both requests complete with correct token counts.
# If prefill was blocking decode, request 0 would stall.
req0 = [r for r in completed if r.id == 0][0]
req1 = [r for r in completed if r.id == 1][0]
assert req0.status == "done" and req0.num_generated == 15
assert req1.status == "done" and req1.num_generated == 5
# Verify cache shapes
for req in completed:
    k, _ = req.kv_cache[(0, 0)]
    expected_T = len(req.prompt_tokens) + req.num_generated - 1
    assert k.shape[1] == expected_T, f"Req {req.id}: cache T={k.shape[1]}, expected {expected_T}"
print("\n✓ Test 3 passed — decode requests got service every step during chunked prefill!")
The result is:
============================================================
Test 3: Verify decode latency isn't spiked
============================================================
Request 0: short prompt (5 tokens), decoding from step 0
Request 1: long prompt (20 tokens), arrives step 1, budget=8
[step 6] Completed request 1 (5 tokens)
[step 13] Completed request 0 (15 tokens)
Total wall time: 0.1418s
Total tokens generated: 20
Request 0: 15 tokens, status=done
Request 1: 5 tokens, status=done
✓ Test 3 passed — decode requests got service every step during chunked prefill!
Look at the completion steps: request 1 (the long prompt, 5 tokens to generate) finishes at step 6, while request 0 (the short prompt, 15 tokens to generate) finishes at step 13. Request 0 collects 2 tokens in its first iteration (the sample from its prefill plus one decode pass) and 1 token per step after that, so finishing at exactly step 13 means it was decoded on every step, including the steps in which request 1's prompt was being chunked in. Without chunked prefill, request 0 would have had to wait for request 1's entire 20-token prompt to go through a single prefill forward pass before its next decode. The 0.14s total wall time is not a per-step measurement, so on its own it says little about latency spikes; the step counts are the real evidence. In a production system, this is the difference between a user seeing the first few words of a response immediately versus staring at a blank screen while someone else's long prompt hogs the GPU.
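Here is the step arithmetic behind this argument as a hypothetical helper (`completion_step` is not in the post's code). It assumes the request is decoded on every step after graduation. Request 1's graduation at step 3 follows from the code: with one decode slot reserved for request 0, the budget of 8 leaves 7 tokens per step, so its 20-token prompt is prefilled as 7, 7 and 6 tokens at steps 1 to 3.

```python
def completion_step(graduation_step: int, max_new_tokens: int) -> int:
    """Step at which a request finishes if it is decoded on every step.

    The graduation iteration yields 2 tokens (the sample from the last prefill
    chunk plus that iteration's decode pass); every later step yields 1.
    """
    return max(graduation_step, graduation_step + max_new_tokens - 2)


print(completion_step(0, 15))  # request 0: 13
print(completion_step(3, 5))   # request 1: 6
```

Both values match the log, so neither request lost a decode step to the other's prefill.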
Test 4: mixed arrivals. This is the stress test. Four requests are queued with arrival steps 0, 2, 5, and 8 and varying prompt lengths: some short enough to prefill in one shot, others long enough to require multiple chunks. (Because this version of the scheduler does not check the arrival step, each request is actually admitted as soon as the previous one finishes prefilling.) The token budget is 10, meaning the scheduler has to juggle partial prefills alongside active decode requests at every step. The assertions verify that all 4 requests completed with the correct number of generated tokens, that every prefill_cursor reached the full prompt length (no chunks were lost), and that all KV cache shapes are consistent. If any of these fail, the scheduler is either dropping chunks, miscounting positions, or corrupting caches when multiple requests overlap.
# ══════════════════════════════════════════════════════════════════════════════
# Test 4: Mixed arrivals
# Multiple requests arriving at different times with different prompt lengths.
# Some need chunking, some don't. All should complete correctly.
# ══════════════════════════════════════════════════════════════════════════════
print("=" * 60)
print("Test 4: Mixed arrivals with varying prompt lengths")
print("=" * 60)
request_queue = [
# Short prompt — fits in one chunk (5 tokens < budget of 10)
(0, Request(id=0, prompt_tokens=encode("Hello"), max_new_tokens=10)),
# Medium prompt — needs 2 chunks (17 tokens, budget=10 minus 1 decode = 9 per step)
(2, Request(id=1, prompt_tokens=encode("O Romeo, O Romeo!"), max_new_tokens=8)),
# Long prompt — needs 3+ chunks (26 tokens)
(5, Request(id=2, prompt_tokens=encode("Now is the winter of our d"), max_new_tokens=6)),
# Another short one arriving late
(8, Request(id=3, prompt_tokens=encode("Why?"), max_new_tokens=12)),
]
print("Queue:")
for arrival, req in request_queue:
    print(f" step {arrival}: Request {req.id} — {len(req.prompt_tokens)} prompt tokens, "
          f"wants {req.max_new_tokens} new tokens")
print(f"Token budget: 10")
print()
completed = continuous_batching_generate(model, request_queue, max_batch_size=4, token_budget=10)
print(f"\n{'═'*60}")
print(f"Results: {len(completed)} requests completed")
print(f"{'═'*60}")
for req in sorted(completed, key=lambda r: r.id):
    print(f"\n{'─'*40}")
    print(f"Request {req.id} | {req.num_generated} tokens | status: {req.status}")
    print(f"Prompt: {len(req.prompt_tokens)} tokens | Prefill cursor: {req.prefill_cursor}")
    print(f"{'─'*40}")
    print(decode(req.tokens_so_far))
# Verify all 4 requests
assert len(completed) == 4, f"Expected 4 completed, got {len(completed)}"
for req in completed:
    # Status and token count
    assert req.status == "done", f"Req {req.id}: expected done, got {req.status}"
    assert req.num_generated == req.max_new_tokens, (
        f"Req {req.id}: expected {req.max_new_tokens} tokens, got {req.num_generated}")
    # Prefill fully consumed
    assert req.prefill_cursor == len(req.prompt_tokens), (
        f"Req {req.id}: prefill_cursor={req.prefill_cursor}, expected {len(req.prompt_tokens)}")
    # Cache shape
    k, _ = req.kv_cache[(0, 0)]
    expected_T = len(req.prompt_tokens) + req.num_generated - 1
    assert k.shape[1] == expected_T, (
        f"Req {req.id}: cache T={k.shape[1]}, expected {expected_T}")
print("\n✓ Test 4 passed — all mixed-arrival requests completed with correct cache shapes!")
The result is:
============================================================
Test 4: Mixed arrivals with varying prompt lengths
============================================================
Queue:
step 0: Request 0 — 5 prompt tokens, wants 10 new tokens
step 2: Request 1 — 17 prompt tokens, wants 8 new tokens
step 5: Request 2 — 26 prompt tokens, wants 6 new tokens
step 8: Request 3 — 4 prompt tokens, wants 12 new tokens
Token budget: 10
[step 8] Completed request 0 (10 tokens)
[step 8] Completed request 1 (8 tokens)
[step 10] Completed request 2 (6 tokens)
[step 17] Completed request 3 (12 tokens)
════════════════════════════════════════════════════════════
Results: 4 requests completed
════════════════════════════════════════════════════════════
────────────────────────────────────────
Request 0 | 10 tokens | status: done
Prompt: 5 tokens | Prefill cursor: 5
────────────────────────────────────────
Hello, to shall
────────────────────────────────────────
Request 1 | 8 tokens | status: done
Prompt: 17 tokens | Prefill cursor: 17
────────────────────────────────────────
O Romeo, O Romeo!
Steed a
────────────────────────────────────────
Request 2 | 6 tokens | status: done
Prompt: 26 tokens | Prefill cursor: 26
────────────────────────────────────────
Now is the winter of our deman i
────────────────────────────────────────
Request 3 | 12 tokens | status: done
Prompt: 4 tokens | Prefill cursor: 4
────────────────────────────────────────
Why?
CAMILLO:
Y
✓ Test 4 passed — all mixed-arrival requests completed with correct cache shapes!
Four requests, four different arrival times, four different prompt lengths — all completed with correct token counts, fully-consumed prefill cursors, and consistent cache shapes. The completion order tells the story: requests 0 and 1 finish together at step 8, request 2 (26-token prompt) finishes at step 10, and request 3 (arrived late) finishes last at step 17. The scheduler correctly interleaved chunked prefills with ongoing decode work across all of them. This is the closest our toy implementation gets to simulating real traffic.
What we left on the table
There's an important simplification worth flagging. A real chunked prefill system (like vLLM's) mixes the prefill chunk and decode tokens into a single forward pass: one batch where some rows are prefilling and others are decoding. This requires a more complex attention mask: decode rows attend over their full KV cache plus the current token, while prefill rows attend over the KV cache written by their earlier chunks plus the tokens in their current chunk, with causal masking within the chunk.
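To make the mask description concrete, here is a hypothetical layout for one fused pass. Each token gets one row, recording the position it is fed at and how many keys it may attend to. `pack_fused_batch` and its argument shapes are made up for illustration; this is not vLLM's API.

```python
from typing import List, Tuple


def pack_fused_batch(
    decodes: List[Tuple[int, int]], chunk: Tuple[int, int, int]
) -> List[Tuple[int, int, int]]:
    """Hypothetical row layout for one fused prefill + decode forward pass.

    decodes: (request_id, cache_len) for every decoding request
    chunk:   (request_id, past_len, chunk_len) for the request being prefilled
    Returns one (request_id, position, keys_visible) row per token in the pass.
    """
    rows = []
    for rid, cache_len in decodes:
        # The new token sits at position cache_len and sees its cache plus itself
        rows.append((rid, cache_len, cache_len + 1))
    rid, past_len, chunk_len = chunk
    for i in range(chunk_len):
        pos = past_len + i
        # Earlier chunks (past_len keys) plus a causal prefix of the current chunk
        rows.append((rid, pos, pos + 1))
    return rows


rows = pack_fused_batch(decodes=[(0, 12), (1, 30)], chunk=(2, 8, 8))
for row in rows[:3]:
    print(row)  # (0, 12, 13), then (1, 30, 31), then (2, 8, 9)
```

Two decode rows plus an 8-token chunk fill a budget of 10 in a single model call, where the approach in this post makes two calls per iteration.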
I kept the simpler approach: prefill and decode run as separate forward passes each iteration. This means two model calls per step instead of one, which wastes some GPU efficiency. But the scheduling is correct — decode never stalls, the chunks are the right size, and the KV caches accumulate properly. Fusing the two passes into one is a pure optimization that doesn’t change the scheduling logic.
A bug worth mentioning
During development, test 2 crashed with IndexError: list index out of range when trying to access completed[0]. The completed list was empty — the function returned before the request finished generating.
---------------------------------------------------------------------------
IndexError Traceback (most recent call last)
/tmp/ipykernel_854/2264159004.py in <cell line: 0>()
23 completed = continuous_batching_generate(model, request_queue, max_batch_size=4, token_budget=8)
24
---> 25 req = completed[0]
26 print(f"\nStatus: {req.status}")
27 print(f"Generated {req.num_generated} tokens")
IndexError: list index out of range
The bug was in the while loop condition: while active_requests or queue_idx < len(request_queue). When a request was mid-prefill and the queue was empty, neither condition was true — active_requests was empty (the request was in prefilling_requests, not active_requests) and queue_idx had already reached the end. The loop exited with the request half-processed. The fix was adding or prefilling_requests to the loop condition.
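The failure is easy to reproduce without a model. This hypothetical replay (`run_loop` is not from the post) keeps only the loop condition and the bookkeeping for one 20-token request at a budget of 8, with and without the extra clause.

```python
from typing import List


def run_loop(fixed: bool, prompt_len: int = 20, token_budget: int = 8, max_new: int = 5) -> List[int]:
    """Hypothetical replay of the scheduler loop for a single request, with no model."""
    queue = [0]  # one request, id 0
    queue_idx = 0
    prefilling: List[int] = []
    active: List[int] = []
    completed: List[int] = []
    cursor = generated = 0

    def keep_looping() -> bool:
        cond = bool(active) or queue_idx < len(queue)  # original condition
        if fixed:
            cond = cond or bool(prefilling)  # the fix: keep looping while a prefill is in flight
        return cond

    while keep_looping():
        if not prefilling and queue_idx < len(queue):
            prefilling.append(queue[queue_idx])
            queue_idx += 1
        if prefilling:
            cursor += min(token_budget - len(active), prompt_len - cursor)
            if cursor == prompt_len:
                generated += 1  # first token sampled from the last chunk
                active.append(prefilling.pop(0))
        if active:
            generated += 1  # one decode step
            if generated >= max_new:
                completed.append(active.pop(0))
    return completed


print(run_loop(fixed=False), run_loop(fixed=True))  # [] [0]
```

With the original condition the loop exits after the first 8-token chunk, exactly as in the traceback above.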
You can find the full source code on GitHub.
CZ