Skip to content

Commit d396a70

Browse files
committed
feat: add native TTFT, TPOT, ITL, and E2E latency tracking to framework
-Implements critical inference metrics directly within the IGW framework, removing the dependency on the SLO predictor plugin for observability. -Framework now natively tracks Time to First Token (TTFT), Time Per Output Token (TPOT), Inter-Token Latency (ITL), and End-to-End (E2E) latency for all inference requests. -Added tests to validate metrics tracking Signed-off-by: Sathvik <Sathvik.S@ibm.com>
1 parent 5df0acd commit d396a70

File tree

5 files changed

+185
-3
lines changed

5 files changed

+185
-3
lines changed

pkg/epp/handlers/response.go

Lines changed: 33 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -21,6 +21,7 @@ import (
2121
"encoding/json"
2222
"fmt"
2323
"strings"
24+
"time"
2425

2526
configPb "github.com/envoyproxy/go-control-plane/envoy/config/core/v3"
2627
extProcPb "github.com/envoyproxy/go-control-plane/envoy/service/ext_proc/v3"
@@ -84,6 +85,22 @@ func (s *StreamingServer) HandleResponseBodyModelStreaming(ctx context.Context,
8485
logger.Error(err, "error in HandleResponseBodyStreaming")
8586
}
8687

88+
// Record TTFT on the first token chunk.
89+
// We check for the presence of "data: " to identify a data chunk, and exclude the "[DONE]" message.
90+
if reqCtx.GeneratedTokenCount == 0 && strings.Contains(responseText, streamingRespPrefix) && !strings.Contains(responseText, streamingEndMsg) {
91+
ttft := time.Since(reqCtx.RequestReceivedTimestamp).Seconds()
92+
reqCtx.TTFT = ttft
93+
metrics.RecordRequestTTFT(ctx, reqCtx.IncomingModelName, reqCtx.TargetModelName, ttft)
94+
reqCtx.GeneratedTokenCount = 1
95+
reqCtx.LastTokenTimestamp = time.Now()
96+
} else if reqCtx.GeneratedTokenCount > 0 && strings.Contains(responseText, streamingRespPrefix) && !strings.Contains(responseText, streamingEndMsg) {
97+
// Record ITL for subsequent tokens
98+
itl := time.Since(reqCtx.LastTokenTimestamp).Seconds()
99+
metrics.RecordRequestITL(ctx, reqCtx.IncomingModelName, reqCtx.TargetModelName, itl)
100+
reqCtx.LastTokenTimestamp = time.Now()
101+
reqCtx.GeneratedTokenCount++
102+
}
103+
87104
// Parse usage on EVERY chunk to catch split streams (where usage and [DONE] are in different chunks).
88105
if resp := parseRespForUsage(ctx, responseText); resp.Usage.TotalTokens > 0 {
89106
reqCtx.Usage = resp.Usage
@@ -98,6 +115,22 @@ func (s *StreamingServer) HandleResponseBodyModelStreaming(ctx context.Context,
98115
cachedToken = reqCtx.Usage.PromptTokenDetails.CachedTokens
99116
}
100117
metrics.RecordPromptCachedTokens(reqCtx.IncomingModelName, reqCtx.TargetModelName, cachedToken)
118+
119+
// Record TPOT
120+
// TPOT = (Total Duration - TTFT) / (OutputTokens - 1)
121+
if reqCtx.Usage.CompletionTokens > 1 && reqCtx.TTFT > 0 {
122+
totalDuration := time.Since(reqCtx.RequestReceivedTimestamp).Seconds()
123+
generationDuration := totalDuration - reqCtx.TTFT
124+
metrics.RecordRequestDecodeDuration(ctx, reqCtx.IncomingModelName, reqCtx.TargetModelName, generationDuration)
125+
metrics.RecordRequestE2ELatency(ctx, reqCtx.IncomingModelName, reqCtx.TargetModelName, totalDuration)
126+
127+
// Avoid division by zero just in case
128+
if count := float64(reqCtx.Usage.CompletionTokens - 1); count > 0 {
129+
avgTPOT := generationDuration / count
130+
reqCtx.TPOT = avgTPOT
131+
metrics.RecordRequestTPOT(ctx, reqCtx.IncomingModelName, reqCtx.TargetModelName, avgTPOT)
132+
}
133+
}
101134
}
102135
}
103136

pkg/epp/handlers/response_test.go

Lines changed: 66 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -20,6 +20,7 @@ import (
2020
"context"
2121
"encoding/json"
2222
"testing"
23+
"time"
2324

2425
"github.com/google/go-cmp/cmp"
2526
"github.com/stretchr/testify/assert"
@@ -327,3 +328,68 @@ func TestGenerateResponseHeaders_Sanitization(t *testing.T) {
327328
assert.NotContains(t, gotHeaders, metadata.DestinationEndpointKey)
328329
assert.NotContains(t, gotHeaders, "content-length")
329330
}
331+
332+
func TestHandleResponseBodyModelStreaming_Metrics(t *testing.T) {
333+
t.Parallel()
334+
ctx := context.Background()
335+
336+
t.Run("TTFT Recording", func(t *testing.T) {
337+
server := &StreamingServer{director: &mockDirector{}}
338+
reqCtx := &RequestContext{
339+
RequestReceivedTimestamp: time.Now().Add(-100 * time.Millisecond),
340+
IncomingModelName: "model-a",
341+
TargetModelName: "model-b",
342+
}
343+
344+
chunk := `data: {"choices":[{"text":"First token"}]}`
345+
server.HandleResponseBodyModelStreaming(ctx, reqCtx, chunk)
346+
347+
assert.Greater(t, reqCtx.TTFT, 0.0, "TTFT should be recorded and greater than 0")
348+
assert.Equal(t, 1, reqCtx.GeneratedTokenCount, "GeneratedTokenCount should be 1")
349+
assert.False(t, reqCtx.LastTokenTimestamp.IsZero(), "LastTokenTimestamp should be set")
350+
})
351+
352+
t.Run("ITL Recording", func(t *testing.T) {
353+
server := &StreamingServer{director: &mockDirector{}}
354+
reqCtx := &RequestContext{
355+
RequestReceivedTimestamp: time.Now().Add(-1 * time.Second),
356+
IncomingModelName: "model-a",
357+
TargetModelName: "model-b",
358+
// Simulate first token already received
359+
GeneratedTokenCount: 1,
360+
LastTokenTimestamp: time.Now().Add(-50 * time.Millisecond),
361+
TTFT: 0.1,
362+
}
363+
364+
chunk := `data: {"choices":[{"text":"Second token"}]}`
365+
server.HandleResponseBodyModelStreaming(ctx, reqCtx, chunk)
366+
367+
// ITL is not stored in ReqCtx, but we can verify state updates
368+
assert.Equal(t, 2, reqCtx.GeneratedTokenCount, "GeneratedTokenCount should increment")
369+
assert.True(t, time.Since(reqCtx.LastTokenTimestamp) < 10*time.Millisecond, "LastTokenTimestamp should be updated to Now")
370+
})
371+
372+
t.Run("TPOT and E2E Recording", func(t *testing.T) {
373+
server := &StreamingServer{director: &mockDirector{}}
374+
reqCtx := &RequestContext{
375+
RequestReceivedTimestamp: time.Now().Add(-1 * time.Second),
376+
IncomingModelName: "model-a",
377+
TargetModelName: "model-b",
378+
TTFT: 0.1,
379+
GeneratedTokenCount: 10,
380+
}
381+
382+
// Usage that triggers TPOT calc
383+
chunk := `data: {"usage":{"prompt_tokens":5,"completion_tokens":11,"total_tokens":16}}` + "\n" + `data: [DONE]`
384+
server.HandleResponseBodyModelStreaming(ctx, reqCtx, chunk)
385+
386+
assert.True(t, reqCtx.ResponseComplete, "Response should be complete")
387+
assert.Greater(t, reqCtx.TPOT, 0.0, "TPOT should be calculated")
388+
389+
// Expected TPOT calc: (TotalDuration - TTFT) / (CompletionTokens - 1)
390+
// TotalDuration ~ 1.0s, TTFT = 0.1s -> GenDuration ~ 0.9s
391+
// Tokens - 1 = 10
392+
// TPOT ~ 0.09
393+
assert.InDelta(t, 0.09, reqCtx.TPOT, 0.05, "TPOT should be approximately correct")
394+
})
395+
}

pkg/epp/handlers/server.go

Lines changed: 8 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -99,6 +99,12 @@ type RequestContext struct {
9999

100100
Response *Response
101101

102+
// Metrics
103+
TTFT float64
104+
TPOT float64
105+
LastTokenTimestamp time.Time
106+
GeneratedTokenCount int
107+
102108
reqHeaderResp *extProcPb.ProcessingResponse
103109
reqBodyResp []*extProcPb.ProcessingResponse
104110
reqTrailerResp *extProcPb.ProcessingResponse
@@ -145,7 +151,8 @@ func (s *StreamingServer) Process(srv extProcPb.ExternalProcessor_ProcessServer)
145151
// Create request context to share states during life time of an HTTP request.
146152
// See https://github.com/envoyproxy/envoy/issues/17540.
147153
reqCtx := &RequestContext{
148-
RequestState: RequestReceived,
154+
RequestState: RequestReceived,
155+
RequestReceivedTimestamp: time.Now(),
149156
Request: &Request{
150157
Headers: make(map[string]string),
151158
Body: make(map[string]any),

pkg/epp/metrics/metrics.go

Lines changed: 76 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -57,6 +57,10 @@ const (
5757
TypeTTFTPredictionDuration = "ttft_prediction_duration"
5858
TypeTTFTSLOViolation = "ttft_slo_violation"
5959
TypeTTFTSLOThreshold = "ttft_slo_threshold"
60+
61+
TypeITL = "itl"
62+
TypeDecodeDuration = "decode_duration"
63+
TypeE2ELatency = "e2e_latency"
6064
)
6165

6266
var (
@@ -176,6 +180,36 @@ var (
176180
ModelLabels,
177181
)
178182

183+
requestITL = prometheus.NewHistogramVec(
184+
prometheus.HistogramOpts{
185+
Subsystem: InferenceObjectiveComponent,
186+
Name: "request_itl_seconds",
187+
Help: metricsutil.HelpMsgWithStability("Inference model Inter-Token Latency distribution in seconds for each model and target model.", compbasemetrics.ALPHA),
188+
Buckets: TPOTBuckets,
189+
},
190+
ModelLabels,
191+
)
192+
193+
requestDecodeDuration = prometheus.NewHistogramVec(
194+
prometheus.HistogramOpts{
195+
Subsystem: InferenceObjectiveComponent,
196+
Name: "request_decode_duration_seconds",
197+
Help: metricsutil.HelpMsgWithStability("Inference model Decode Duration distribution in seconds for each model and target model.", compbasemetrics.ALPHA),
198+
Buckets: GeneralLatencyBuckets,
199+
},
200+
ModelLabels,
201+
)
202+
203+
requestE2ELatency = prometheus.NewHistogramVec(
204+
prometheus.HistogramOpts{
205+
Subsystem: InferenceObjectiveComponent,
206+
Name: "request_e2e_latency_seconds",
207+
Help: metricsutil.HelpMsgWithStability("Inference model E2E Latency (TTFT + Decode) distribution in seconds for each model and target model.", compbasemetrics.ALPHA),
208+
Buckets: GeneralLatencyBuckets,
209+
},
210+
ModelLabels,
211+
)
212+
179213
sloViolationCounter = prometheus.NewCounterVec(
180214
prometheus.CounterOpts{
181215
Subsystem: InferenceObjectiveComponent,
@@ -443,6 +477,9 @@ func Register(customCollectors ...prometheus.Collector) {
443477
metrics.Registry.MustRegister(requestPredictedTTFT)
444478
metrics.Registry.MustRegister(requestTPOTPredictionDuration)
445479
metrics.Registry.MustRegister(requestTTFTPredictionDuration)
480+
metrics.Registry.MustRegister(requestITL)
481+
metrics.Registry.MustRegister(requestDecodeDuration)
482+
metrics.Registry.MustRegister(requestE2ELatency)
446483

447484
// Register SLO violation counters
448485
metrics.Registry.MustRegister(sloViolationCounter)
@@ -490,6 +527,9 @@ func Reset() {
490527
requestPredictedTTFT.Reset()
491528
requestTPOTPredictionDuration.Reset()
492529
requestTTFTPredictionDuration.Reset()
530+
requestITL.Reset()
531+
requestDecodeDuration.Reset()
532+
requestE2ELatency.Reset()
493533

494534
// Reset SLO violation counter
495535
sloViolationCounter.Reset()
@@ -667,6 +707,42 @@ func RecordRequestTTFTPredictionDuration(ctx context.Context, modelName, targetM
667707
return true
668708
}
669709

710+
// RecordRequestITL records the Inter-Token Latency.
711+
func RecordRequestITL(ctx context.Context, modelName, targetModelName string, itl float64) bool {
712+
if itl < 0 {
713+
log.FromContext(ctx).V(logutil.DEFAULT).Error(nil, "ITL value must be non-negative",
714+
"modelName", modelName, "targetModelName", targetModelName, "itl", itl)
715+
return false
716+
}
717+
requestITL.WithLabelValues(modelName, targetModelName).Observe(itl)
718+
inferenceGauges.WithLabelValues(modelName, targetModelName, TypeITL).Set(itl)
719+
return true
720+
}
721+
722+
// RecordRequestDecodeDuration records the Decode Duration.
723+
func RecordRequestDecodeDuration(ctx context.Context, modelName, targetModelName string, duration float64) bool {
724+
if duration < 0 {
725+
log.FromContext(ctx).V(logutil.DEFAULT).Error(nil, "Decode duration value must be non-negative",
726+
"modelName", modelName, "targetModelName", targetModelName, "duration", duration)
727+
return false
728+
}
729+
requestDecodeDuration.WithLabelValues(modelName, targetModelName).Observe(duration)
730+
inferenceGauges.WithLabelValues(modelName, targetModelName, TypeDecodeDuration).Set(duration)
731+
return true
732+
}
733+
734+
// RecordRequestE2ELatency records the E2E Latency (TTFT + Decode).
735+
func RecordRequestE2ELatency(ctx context.Context, modelName, targetModelName string, duration float64) bool {
736+
if duration < 0 {
737+
log.FromContext(ctx).V(logutil.DEFAULT).Error(nil, "E2E latency value must be non-negative",
738+
"modelName", modelName, "targetModelName", targetModelName, "duration", duration)
739+
return false
740+
}
741+
requestE2ELatency.WithLabelValues(modelName, targetModelName).Observe(duration)
742+
inferenceGauges.WithLabelValues(modelName, targetModelName, TypeE2ELatency).Set(duration)
743+
return true
744+
}
745+
670746
// RecordResponseSizes records the response sizes.
671747
func RecordResponseSizes(modelName, targetModelName string, size int) {
672748
responseSizes.WithLabelValues(modelName, targetModelName).Observe(float64(size))

pkg/epp/scheduling/framework/plugins/multi/predicted_latency/requestcontrol_hooks.go

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -219,7 +219,7 @@ func (t *PredictedLatency) ResponseComplete(ctx context.Context, request *schedu
219219

220220
if predictedLatencyCtx.ttft > 0 {
221221
logger.V(logutil.TRACE).Info("Averages calculated", "avgActualTTFT", predictedLatencyCtx.ttft, "avgPredictedTTFT", predictedLatencyCtx.predictedTTFT)
222-
metrics.RecordRequestTTFT(ctx, predictedLatencyCtx.incomingModelName, request.TargetModel, predictedLatencyCtx.ttft/1000)
222+
// metrics.RecordRequestTTFT(ctx, predictedLatencyCtx.incomingModelName, request.TargetModel, predictedLatencyCtx.ttft/1000)
223223
metrics.RecordRequestPredictedTTFT(ctx, predictedLatencyCtx.incomingModelName, request.TargetModel, predictedLatencyCtx.predictedTTFT/1000)
224224
if predictedLatencyCtx.ttftSLO > 0 {
225225
metrics.RecordRequestTTFTWithSLO(ctx, predictedLatencyCtx.incomingModelName, request.TargetModel, predictedLatencyCtx.ttft, predictedLatencyCtx.ttftSLO)
@@ -228,7 +228,7 @@ func (t *PredictedLatency) ResponseComplete(ctx context.Context, request *schedu
228228

229229
if predictedLatencyCtx.avgTPOT > 0 {
230230
logger.V(logutil.TRACE).Info("Averages calculated", "avgActualTPOT", predictedLatencyCtx.avgTPOT, "avgPredictedTPOT", predictedLatencyCtx.avgPredictedTPOT)
231-
metrics.RecordRequestTPOT(ctx, predictedLatencyCtx.incomingModelName, request.TargetModel, predictedLatencyCtx.avgTPOT/1000)
231+
// metrics.RecordRequestTPOT(ctx, predictedLatencyCtx.incomingModelName, request.TargetModel, predictedLatencyCtx.avgTPOT/1000)
232232
metrics.RecordRequestPredictedTPOT(ctx, predictedLatencyCtx.incomingModelName, request.TargetModel, predictedLatencyCtx.avgPredictedTPOT/1000)
233233
if predictedLatencyCtx.avgTPOTSLO > 0 {
234234
metrics.RecordRequestTPOTWithSLO(ctx, predictedLatencyCtx.incomingModelName, request.TargetModel, predictedLatencyCtx.avgTPOT, predictedLatencyCtx.avgTPOTSLO)

0 commit comments

Comments
 (0)