diff --git a/internal/api/chat/create_conversation_message_stream_v2.go b/internal/api/chat/create_conversation_message_stream_v2.go index 3537ec79..46f92610 100644 --- a/internal/api/chat/create_conversation_message_stream_v2.go +++ b/internal/api/chat/create_conversation_message_stream_v2.go @@ -321,7 +321,7 @@ func (s *ChatServerV2) CreateConversationMessageStream( } } - openaiChatHistory, inappChatHistory, err := s.aiClientV2.ChatCompletionStreamV2(ctx, stream, conversation.ID.Hex(), modelSlug, conversation.OpenaiChatHistoryCompletion, llmProvider, customModel) + openaiChatHistory, inappChatHistory, _, err := s.aiClientV2.ChatCompletionStreamV2(ctx, stream, conversation.UserID, conversation.ProjectID, conversation.ID.Hex(), modelSlug, conversation.OpenaiChatHistoryCompletion, llmProvider, customModel) if err != nil { return s.sendStreamError(stream, err) } @@ -347,7 +347,7 @@ func (s *ChatServerV2) CreateConversationMessageStream( for i, bsonMsg := range conversation.InappChatHistory { protoMessages[i] = mapper.BSONToChatMessageV2(bsonMsg) } - title, err := s.aiClientV2.GetConversationTitleV2(ctx, protoMessages, llmProvider, modelSlug, customModel) + title, err := s.aiClientV2.GetConversationTitleV2(ctx, conversation.UserID, conversation.ProjectID, protoMessages, llmProvider, modelSlug, customModel) if err != nil { s.logger.Error("Failed to get conversation title", "error", err, "conversationID", conversation.ID.Hex()) return diff --git a/internal/models/usage.go b/internal/models/usage.go new file mode 100644 index 00000000..0e6af50d --- /dev/null +++ b/internal/models/usage.go @@ -0,0 +1,71 @@ +package models + +import ( + "time" + + "go.mongodb.org/mongo-driver/v2/bson" +) + +// HourlyUsage tracks cost per user, per project, per hour. +// Each document represents one hour bucket of usage. +type HourlyUsage struct { + ID bson.ObjectID `bson:"_id"` + UserID bson.ObjectID `bson:"user_id"` + ProjectID string `bson:"project_id"` + HourBucket bson.DateTime `bson:"hour_bucket"` // Timestamp truncated to the hour + SuccessCost float64 `bson:"success_cost"` // Cost in USD for successful requests + FailedCost float64 `bson:"failed_cost"` // Cost in USD for failed requests + UpdatedAt bson.DateTime `bson:"updated_at"` +} + +func (u HourlyUsage) CollectionName() string { + return "hourly_usages" +} + +// WeeklyUsage tracks cost per user, per project, per week. +// Each document represents one week bucket of usage. +type WeeklyUsage struct { + ID bson.ObjectID `bson:"_id"` + UserID bson.ObjectID `bson:"user_id"` + ProjectID string `bson:"project_id"` + WeekBucket bson.DateTime `bson:"week_bucket"` // Timestamp truncated to the week (Monday) + SuccessCost float64 `bson:"success_cost"` // Cost in USD for successful requests + FailedCost float64 `bson:"failed_cost"` // Cost in USD for failed requests + UpdatedAt bson.DateTime `bson:"updated_at"` +} + +func (u WeeklyUsage) CollectionName() string { + return "weekly_usages" +} + +// LifetimeUsage tracks total cost per user, per project, across all time. +// Each document represents the cumulative usage for a user-project pair. +type LifetimeUsage struct { + ID bson.ObjectID `bson:"_id"` + UserID bson.ObjectID `bson:"user_id"` + ProjectID string `bson:"project_id"` + SuccessCost float64 `bson:"success_cost"` // Total cost in USD for successful requests + FailedCost float64 `bson:"failed_cost"` // Total cost in USD for failed requests + UpdatedAt bson.DateTime `bson:"updated_at"` +} + +func (u LifetimeUsage) CollectionName() string { + return "lifetime_usages" +} + +// TruncateToHour truncates a time to the start of its hour. +func TruncateToHour(t time.Time) time.Time { + return t.Truncate(time.Hour) +} + +// TruncateToWeek truncates a time to the start of its week (Monday 00:00:00 UTC). +func TruncateToWeek(t time.Time) time.Time { + t = t.UTC() + weekday := int(t.Weekday()) + if weekday == 0 { + weekday = 7 // Sunday becomes 7 + } + // Subtract days to get to Monday + monday := t.AddDate(0, 0, -(weekday - 1)) + return time.Date(monday.Year(), monday.Month(), monday.Day(), 0, 0, 0, 0, time.UTC) +} diff --git a/internal/services/toolkit/client/client_v2.go b/internal/services/toolkit/client/client_v2.go index d32e01f1..f7e4cdf5 100644 --- a/internal/services/toolkit/client/client_v2.go +++ b/internal/services/toolkit/client/client_v2.go @@ -20,6 +20,7 @@ type AIClientV2 struct { reverseCommentService *services.ReverseCommentService projectService *services.ProjectService + usageService *services.UsageService cfg *cfg.Cfg logger *logger.Logger } @@ -62,6 +63,7 @@ func NewAIClientV2( reverseCommentService *services.ReverseCommentService, projectService *services.ProjectService, + usageService *services.UsageService, cfg *cfg.Cfg, logger *logger.Logger, ) *AIClientV2 { @@ -109,6 +111,7 @@ func NewAIClientV2( reverseCommentService: reverseCommentService, projectService: projectService, + usageService: usageService, cfg: cfg, logger: logger, } diff --git a/internal/services/toolkit/client/completion_v2.go b/internal/services/toolkit/client/completion_v2.go index 7266d669..ee4ad765 100644 --- a/internal/services/toolkit/client/completion_v2.go +++ b/internal/services/toolkit/client/completion_v2.go @@ -6,11 +6,19 @@ import ( "paperdebugger/internal/models" "paperdebugger/internal/services/toolkit/handler" chatv2 "paperdebugger/pkg/gen/api/chat/v2" + "strconv" "strings" + "time" "github.com/openai/openai-go/v3" + "go.mongodb.org/mongo-driver/v2/bson" ) +// UsageCost holds cost information from a completion. +type UsageCost struct { + Cost float64 +} + // define []openai.ChatCompletionMessageParamUnion as OpenAIChatHistory // ChatCompletion orchestrates a chat completion process with a language model (e.g., GPT), handling tool calls and message history management. @@ -24,13 +32,14 @@ import ( // Returns: // 1. The full chat history sent to the language model (including any tool call results). // 2. The incremental chat history visible to the user (including tool call results and assistant responses). -// 3. An error, if any occurred during the process. -func (a *AIClientV2) ChatCompletionV2(ctx context.Context, modelSlug string, messages OpenAIChatHistory, llmProvider *models.LLMProviderConfig, customModel *models.CustomModel) (OpenAIChatHistory, AppChatHistory, error) { - openaiChatHistory, inappChatHistory, err := a.ChatCompletionStreamV2(ctx, nil, "", modelSlug, messages, llmProvider, customModel) +// 3. Cost information (in USD). +// 4. An error, if any occurred during the process. +func (a *AIClientV2) ChatCompletionV2(ctx context.Context, userID bson.ObjectID, projectID string, modelSlug string, messages OpenAIChatHistory, llmProvider *models.LLMProviderConfig, customModel *models.CustomModel) (OpenAIChatHistory, AppChatHistory, UsageCost, error) { + openaiChatHistory, inappChatHistory, usage, err := a.ChatCompletionStreamV2(ctx, nil, userID, projectID, "", modelSlug, messages, llmProvider, customModel) if err != nil { - return nil, nil, err + return nil, nil, usage, err } - return openaiChatHistory, inappChatHistory, nil + return openaiChatHistory, inappChatHistory, usage, nil } // ChatCompletionStream orchestrates a streaming chat completion process with a language model (e.g., GPT), handling tool calls, message history management, and real-time streaming of responses to the client. @@ -46,17 +55,20 @@ func (a *AIClientV2) ChatCompletionV2(ctx context.Context, modelSlug string, mes // Returns: (same as ChatCompletion) // 1. The full chat history sent to the language model (including any tool call results). // 2. The incremental chat history visible to the user (including tool call results and assistant responses). -// 3. An error, if any occurred during the process. (However, in the streaming mode, the error is not returned, but sending by callbackStream) +// 3. Cost information (in USD, accumulated across all calls). +// 4. An error, if any occurred during the process. (However, in the streaming mode, the error is not returned, but sending by callbackStream) // // This function works as follows: (same as ChatCompletion) // - It initializes the chat history for the language model and the user, and sets up a stream handler for real-time updates. // - It repeatedly sends the current chat history to the language model, receives streaming responses, and forwards them to the client as they arrive. // - If tool calls are required, it handles them and appends the results to the chat history, then continues the loop. // - If no tool calls are needed, it appends the assistant's response and exits the loop. -// - Finally, it returns the updated chat histories and any error encountered. -func (a *AIClientV2) ChatCompletionStreamV2(ctx context.Context, callbackStream chatv2.ChatService_CreateConversationMessageStreamServer, conversationId string, modelSlug string, messages OpenAIChatHistory, llmProvider *models.LLMProviderConfig, customModel *models.CustomModel) (OpenAIChatHistory, AppChatHistory, error) { +// - Finally, it returns the updated chat histories, accumulated cost, and any error encountered. +func (a *AIClientV2) ChatCompletionStreamV2(ctx context.Context, callbackStream chatv2.ChatService_CreateConversationMessageStreamServer, userID bson.ObjectID, projectID string, conversationId string, modelSlug string, messages OpenAIChatHistory, llmProvider *models.LLMProviderConfig, customModel *models.CustomModel) (OpenAIChatHistory, AppChatHistory, UsageCost, error) { openaiChatHistory := messages inappChatHistory := AppChatHistory{} + usage := UsageCost{} + success := false // Track whether the request completed successfully streamHandler := handler.NewStreamHandlerV2(callbackStream, conversationId, modelSlug) @@ -65,6 +77,19 @@ func (a *AIClientV2) ChatCompletionStreamV2(ctx context.Context, callbackStream streamHandler.SendFinalization() }() + // Track usage on all exit paths (success or error) to prevent abuse + // Only track if userID is provided and user is not using their own API key (BYOK) + defer func() { + if !userID.IsZero() && !llmProvider.IsCustomModel && usage.Cost > 0 { + // Use a detached context since the request context may be canceled + trackCtx, cancel := context.WithTimeout(context.Background(), 5*time.Second) + defer cancel() + if err := a.usageService.TrackUsage(trackCtx, userID, projectID, usage.Cost, success); err != nil { + a.logger.Error("Error while tracking usage", "error", err) + } + } + }() + oaiClient := a.GetOpenAIClient(llmProvider) params := getDefaultParamsV2(modelSlug, a.toolCallHandler.Registry, customModel) @@ -77,6 +102,7 @@ func (a *AIClientV2) ChatCompletionStreamV2(ctx context.Context, callbackStream answer_content := "" answer_content_id := "" has_sent_part_begin := false + has_finished := false tool_info := map[int]map[string]string{} toolCalls := []openai.FinishedChatCompletionToolCall{} handleReasoning := func(raw string) (string, bool) { @@ -92,12 +118,18 @@ func (a *AIClientV2) ChatCompletionStreamV2(ctx context.Context, callbackStream } for stream.Next() { - // time.Sleep(5000 * time.Millisecond) // DEBUG POINT: change this to test in a slow mode chunk := stream.Current() + // Capture cost from any chunk that has usage data (OpenRouter sends usage in a separate chunk after FinishReason) + if chunk.Usage.PromptTokens > 0 || chunk.Usage.CompletionTokens > 0 { + if costField, ok := chunk.Usage.JSON.ExtraFields["cost"]; ok { + if cost, err := strconv.ParseFloat(costField.Raw(), 64); err == nil { + usage.Cost += cost + } + } + } + if len(chunk.Choices) == 0 { - // Handle usage information - // fmt.Printf("Usage: %+v\n", chunk.Usage) continue } @@ -180,17 +212,15 @@ func (a *AIClientV2) ChatCompletionStreamV2(ctx context.Context, callbackStream } } - if chunk.Choices[0].FinishReason != "" { - // fmt.Printf("FinishReason: %s\n", chunk.Choices[0].FinishReason) - // answer_content += chunk.Choices[0].Delta.Content - // fmt.Printf("answer_content: %s\n", answer_content) + if chunk.Choices[0].FinishReason != "" && !has_finished { streamHandler.HandleTextDoneItem(chunk, answer_content, reasoning_content) - break + has_finished = true + // Don't break - continue reading to capture the usage chunk that comes after } } if err := stream.Err(); err != nil { - return nil, nil, err + return nil, nil, usage, err } if answer_content != "" { @@ -200,7 +230,7 @@ func (a *AIClientV2) ChatCompletionStreamV2(ctx context.Context, callbackStream // Execute the calls (if any), return incremental data openaiToolHistory, inappToolHistory, err := a.toolCallHandler.HandleToolCallsV2(ctx, toolCalls, streamHandler) if err != nil { - return nil, nil, err + return nil, nil, usage, err } // // Record the tool call results @@ -213,5 +243,6 @@ func (a *AIClientV2) ChatCompletionStreamV2(ctx context.Context, callbackStream } } - return openaiChatHistory, inappChatHistory, nil + success = true + return openaiChatHistory, inappChatHistory, usage, nil } diff --git a/internal/services/toolkit/client/get_citation_keys.go b/internal/services/toolkit/client/get_citation_keys.go index 2344d49d..a7063b21 100644 --- a/internal/services/toolkit/client/get_citation_keys.go +++ b/internal/services/toolkit/client/get_citation_keys.go @@ -241,7 +241,7 @@ func (a *AIClientV2) GetCitationKeys(ctx context.Context, sentence string, userI // Bibliography is placed at the start of the prompt to leverage prompt caching message := fmt.Sprintf("Bibliography: %s\nSentence: %s\nBased on the sentence and bibliography, suggest only the most relevant citation keys separated by commas with no spaces (e.g. key1,key2). Be selective and only include citations that are directly relevant. Avoid suggesting more than 3 citations. If no relevant citations are found, return '%s'.", bibliography, sentence, emptyCitation) - _, resp, err := a.ChatCompletionV2(ctx, "gpt-5.2", OpenAIChatHistory{ + _, resp, _, err := a.ChatCompletionV2(ctx, userId, projectId, "gpt-5.2", OpenAIChatHistory{ openai.SystemMessage("You are a helpful assistant that suggests relevant citation keys."), openai.UserMessage(message), }, llmProvider, nil) diff --git a/internal/services/toolkit/client/get_citation_keys_test.go b/internal/services/toolkit/client/get_citation_keys_test.go index 4d2a857d..802e6bbf 100644 --- a/internal/services/toolkit/client/get_citation_keys_test.go +++ b/internal/services/toolkit/client/get_citation_keys_test.go @@ -25,10 +25,12 @@ func setupTestClient(t *testing.T) (*client.AIClientV2, *services.ProjectService } projectService := services.NewProjectService(dbInstance, cfg.GetCfg(), logger.GetLogger()) + usageService := services.NewUsageService(dbInstance, cfg.GetCfg(), logger.GetLogger()) aiClient := client.NewAIClientV2( dbInstance, &services.ReverseCommentService{}, projectService, + usageService, cfg.GetCfg(), logger.GetLogger(), ) diff --git a/internal/services/toolkit/client/get_conversation_title_v2.go b/internal/services/toolkit/client/get_conversation_title_v2.go index 27840c7c..0e9129dd 100644 --- a/internal/services/toolkit/client/get_conversation_title_v2.go +++ b/internal/services/toolkit/client/get_conversation_title_v2.go @@ -11,9 +11,10 @@ import ( "github.com/openai/openai-go/v3" "github.com/samber/lo" + "go.mongodb.org/mongo-driver/v2/bson" ) -func (a *AIClientV2) GetConversationTitleV2(ctx context.Context, inappChatHistory []*chatv2.Message, llmProvider *models.LLMProviderConfig, modelSlug string, customModel *models.CustomModel) (string, error) { +func (a *AIClientV2) GetConversationTitleV2(ctx context.Context, userID bson.ObjectID, projectID string, inappChatHistory []*chatv2.Message, llmProvider *models.LLMProviderConfig, modelSlug string, customModel *models.CustomModel) (string, error) { messages := lo.Map(inappChatHistory, func(message *chatv2.Message, _ int) string { if _, ok := message.Payload.MessageType.(*chatv2.MessagePayload_Assistant); ok { return fmt.Sprintf("Assistant: %s", message.Payload.GetAssistant().GetContent()) @@ -35,7 +36,7 @@ func (a *AIClientV2) GetConversationTitleV2(ctx context.Context, inappChatHistor modelToUse = modelSlug } - _, resp, err := a.ChatCompletionV2(ctx, modelToUse, OpenAIChatHistory{ + _, resp, _, err := a.ChatCompletionV2(ctx, userID, projectID, modelToUse, OpenAIChatHistory{ openai.SystemMessage("You are a helpful assistant that generates a title for a conversation."), openai.UserMessage(message), }, llmProvider, customModel) diff --git a/internal/services/toolkit/client/utils_v2.go b/internal/services/toolkit/client/utils_v2.go index 884b91eb..3e6752b3 100644 --- a/internal/services/toolkit/client/utils_v2.go +++ b/internal/services/toolkit/client/utils_v2.go @@ -94,6 +94,9 @@ func getDefaultParamsV2(modelSlug string, toolRegistry *registry.ToolRegistryV2, Tools: toolRegistry.GetTools(), ParallelToolCalls: openaiv3.Bool(true), Store: openaiv3.Bool(false), + StreamOptions: openaiv3.ChatCompletionStreamOptionsParam{ + IncludeUsage: openaiv3.Bool(true), + }, } } } @@ -105,6 +108,9 @@ func getDefaultParamsV2(modelSlug string, toolRegistry *registry.ToolRegistryV2, Tools: toolRegistry.GetTools(), // Tool registration is managed centrally by the registry ParallelToolCalls: openaiv3.Bool(true), Store: openaiv3.Bool(false), // Must set to false, because we are construct our own chat history. + StreamOptions: openaiv3.ChatCompletionStreamOptionsParam{ + IncludeUsage: openaiv3.Bool(true), + }, } } diff --git a/internal/services/usage.go b/internal/services/usage.go new file mode 100644 index 00000000..3c61125a --- /dev/null +++ b/internal/services/usage.go @@ -0,0 +1,190 @@ +package services + +import ( + "context" + "time" + + "paperdebugger/internal/libs/cfg" + "paperdebugger/internal/libs/db" + "paperdebugger/internal/libs/logger" + "paperdebugger/internal/models" + + "go.mongodb.org/mongo-driver/v2/bson" + "go.mongodb.org/mongo-driver/v2/mongo" + "go.mongodb.org/mongo-driver/v2/mongo/options" +) + +type UsageService struct { + BaseService + hourlyCollection *mongo.Collection + weeklyCollection *mongo.Collection + lifetimeCollection *mongo.Collection +} + +func NewUsageService(db *db.DB, cfg *cfg.Cfg, logger *logger.Logger) *UsageService { + base := NewBaseService(db, cfg, logger) + hourlyCollection := base.db.Collection((models.HourlyUsage{}).CollectionName()) + weeklyCollection := base.db.Collection((models.WeeklyUsage{}).CollectionName()) + lifetimeCollection := base.db.Collection((models.LifetimeUsage{}).CollectionName()) + + // Hourly usage indexes + hourlyIndexModels := []mongo.IndexModel{ + { + Keys: bson.D{ + {Key: "user_id", Value: 1}, + {Key: "project_id", Value: 1}, + {Key: "hour_bucket", Value: 1}, + }, + Options: options.Index().SetUnique(true), + }, + { + Keys: bson.D{ + {Key: "project_id", Value: 1}, + {Key: "hour_bucket", Value: 1}, + }, + }, + { + Keys: bson.D{ + {Key: "hour_bucket", Value: 1}, + }, + Options: options.Index().SetExpireAfterSeconds(14 * 24 * 60 * 60), // 2 weeks TTL + }, + } + _, err := hourlyCollection.Indexes().CreateMany(context.Background(), hourlyIndexModels) + if err != nil { + logger.Error("Failed to create indexes for hourly_usages collection", err) + } + + // Weekly usage indexes + weeklyIndexModels := []mongo.IndexModel{ + { + Keys: bson.D{ + {Key: "user_id", Value: 1}, + {Key: "project_id", Value: 1}, + {Key: "week_bucket", Value: 1}, + }, + Options: options.Index().SetUnique(true), + }, + { + Keys: bson.D{ + {Key: "project_id", Value: 1}, + {Key: "week_bucket", Value: 1}, + }, + }, + { + Keys: bson.D{ + {Key: "week_bucket", Value: 1}, + }, + Options: options.Index().SetExpireAfterSeconds(14 * 24 * 60 * 60), // 2 weeks TTL + }, + } + _, err = weeklyCollection.Indexes().CreateMany(context.Background(), weeklyIndexModels) + if err != nil { + logger.Error("Failed to create indexes for weekly_usages collection", err) + } + + // Lifetime usage indexes (no TTL since it's lifetime) + lifetimeIndexModels := []mongo.IndexModel{ + { + Keys: bson.D{ + {Key: "user_id", Value: 1}, + {Key: "project_id", Value: 1}, + }, + Options: options.Index().SetUnique(true), + }, + { + Keys: bson.D{ + {Key: "project_id", Value: 1}, + }, + }, + } + _, err = lifetimeCollection.Indexes().CreateMany(context.Background(), lifetimeIndexModels) + if err != nil { + logger.Error("Failed to create indexes for lifetime_usages collection", err) + } + + return &UsageService{ + BaseService: base, + hourlyCollection: hourlyCollection, + weeklyCollection: weeklyCollection, + lifetimeCollection: lifetimeCollection, + } +} + +// TrackUsage increments cost for a user/project in hourly, weekly, and lifetime buckets. +// Uses upsert to create or update the usage records atomically. +// The success parameter indicates whether the request completed successfully. +// We will be charging only for successful requests, but we track failed requests for monitoring. +func (s *UsageService) TrackUsage(ctx context.Context, userID bson.ObjectID, projectID string, cost float64, success bool) error { + if cost == 0 { + return nil + } + + now := time.Now() + + // Track hourly usage + if err := s.trackHourlyUsage(ctx, userID, projectID, cost, success, now); err != nil { + return err + } + + // Track weekly usage + if err := s.trackWeeklyUsage(ctx, userID, projectID, cost, success, now); err != nil { + return err + } + + // Track lifetime usage + if err := s.trackLifetimeUsage(ctx, userID, projectID, cost, success, now); err != nil { + return err + } + + return nil +} + +func (s *UsageService) upsertUsage(ctx context.Context, collection *mongo.Collection, filter bson.M, cost float64, success bool, now time.Time) error { + costField := "failed_cost" + if success { + costField = "success_cost" + } + + update := bson.M{ + "$inc": bson.M{ + costField: cost, + }, + "$set": bson.M{ + "updated_at": bson.NewDateTimeFromTime(now), + }, + "$setOnInsert": bson.M{ + "_id": bson.NewObjectID(), + }, + } + + opts := options.UpdateOne().SetUpsert(true) + _, err := collection.UpdateOne(ctx, filter, update, opts) + return err +} + +func (s *UsageService) trackHourlyUsage(ctx context.Context, userID bson.ObjectID, projectID string, cost float64, success bool, now time.Time) error { + filter := bson.M{ + "user_id": userID, + "project_id": projectID, + "hour_bucket": bson.NewDateTimeFromTime(models.TruncateToHour(now)), + } + return s.upsertUsage(ctx, s.hourlyCollection, filter, cost, success, now) +} + +func (s *UsageService) trackWeeklyUsage(ctx context.Context, userID bson.ObjectID, projectID string, cost float64, success bool, now time.Time) error { + filter := bson.M{ + "user_id": userID, + "project_id": projectID, + "week_bucket": bson.NewDateTimeFromTime(models.TruncateToWeek(now)), + } + return s.upsertUsage(ctx, s.weeklyCollection, filter, cost, success, now) +} + +func (s *UsageService) trackLifetimeUsage(ctx context.Context, userID bson.ObjectID, projectID string, cost float64, success bool, now time.Time) error { + filter := bson.M{ + "user_id": userID, + "project_id": projectID, + } + return s.upsertUsage(ctx, s.lifetimeCollection, filter, cost, success, now) +} diff --git a/internal/services/usage_test.go b/internal/services/usage_test.go new file mode 100644 index 00000000..1ba0f76c --- /dev/null +++ b/internal/services/usage_test.go @@ -0,0 +1,136 @@ +package services_test + +import ( + "context" + "os" + "testing" + "time" + + "paperdebugger/internal/libs/cfg" + "paperdebugger/internal/libs/db" + "paperdebugger/internal/libs/logger" + "paperdebugger/internal/models" + "paperdebugger/internal/services" + + "github.com/stretchr/testify/assert" + "go.mongodb.org/mongo-driver/v2/bson" + "go.mongodb.org/mongo-driver/v2/mongo" +) + +func setupTestUsageService(t *testing.T) (*services.UsageService, *mongo.Database) { + os.Setenv("PD_MONGO_URI", "mongodb://localhost:27017") + dbInstance, err := db.NewDB(cfg.GetCfg(), logger.GetLogger()) + if err != nil { + t.Fatalf("failed to connect to test db: %v", err) + } + return services.NewUsageService(dbInstance, cfg.GetCfg(), logger.GetLogger()), + dbInstance.Database("paperdebugger") +} + +// TestTrackUsage_FailedCompletion verifies that when a completion fails +// (success=false), the cost is recorded under failed_cost (not success_cost) +// across all three buckets: hourly, weekly, and lifetime. +func TestTrackUsage_FailedCompletion(t *testing.T) { + us, database := setupTestUsageService(t) + ctx := context.Background() + + userID := bson.NewObjectID() + projectID := "test-project-" + bson.NewObjectID().Hex() + cost := 0.0125 + + // Clean up after the test + t.Cleanup(func() { + filter := bson.M{"user_id": userID, "project_id": projectID} + _, _ = database.Collection(models.HourlyUsage{}.CollectionName()).DeleteMany(ctx, filter) + _, _ = database.Collection(models.WeeklyUsage{}.CollectionName()).DeleteMany(ctx, filter) + _, _ = database.Collection(models.LifetimeUsage{}.CollectionName()).DeleteMany(ctx, filter) + }) + + err := us.TrackUsage(ctx, userID, projectID, cost, false) + assert.NoError(t, err) + + now := time.Now() + + // Hourly bucket: failed_cost incremented, success_cost untouched. + var hourly models.HourlyUsage + err = database.Collection(models.HourlyUsage{}.CollectionName()).FindOne(ctx, bson.M{ + "user_id": userID, + "project_id": projectID, + "hour_bucket": bson.NewDateTimeFromTime(models.TruncateToHour(now)), + }).Decode(&hourly) + assert.NoError(t, err) + assert.InDelta(t, cost, hourly.FailedCost, 1e-9) + assert.Equal(t, 0.0, hourly.SuccessCost) + + // Weekly bucket. + var weekly models.WeeklyUsage + err = database.Collection(models.WeeklyUsage{}.CollectionName()).FindOne(ctx, bson.M{ + "user_id": userID, + "project_id": projectID, + "week_bucket": bson.NewDateTimeFromTime(models.TruncateToWeek(now)), + }).Decode(&weekly) + assert.NoError(t, err) + assert.InDelta(t, cost, weekly.FailedCost, 1e-9) + assert.Equal(t, 0.0, weekly.SuccessCost) + + // Lifetime bucket. + var lifetime models.LifetimeUsage + err = database.Collection(models.LifetimeUsage{}.CollectionName()).FindOne(ctx, bson.M{ + "user_id": userID, + "project_id": projectID, + }).Decode(&lifetime) + assert.NoError(t, err) + assert.InDelta(t, cost, lifetime.FailedCost, 1e-9) + assert.Equal(t, 0.0, lifetime.SuccessCost) +} + +// TestTrackUsage_FailedThenSuccess verifies that failed and successful +// completions accumulate into separate fields on the same bucket document. +func TestTrackUsage_FailedThenSuccess(t *testing.T) { + us, database := setupTestUsageService(t) + ctx := context.Background() + + userID := bson.NewObjectID() + projectID := "test-project-" + bson.NewObjectID().Hex() + failedCost := 0.02 + successCost := 0.05 + + t.Cleanup(func() { + filter := bson.M{"user_id": userID, "project_id": projectID} + _, _ = database.Collection(models.HourlyUsage{}.CollectionName()).DeleteMany(ctx, filter) + _, _ = database.Collection(models.WeeklyUsage{}.CollectionName()).DeleteMany(ctx, filter) + _, _ = database.Collection(models.LifetimeUsage{}.CollectionName()).DeleteMany(ctx, filter) + }) + + assert.NoError(t, us.TrackUsage(ctx, userID, projectID, failedCost, false)) + assert.NoError(t, us.TrackUsage(ctx, userID, projectID, successCost, true)) + + var lifetime models.LifetimeUsage + err := database.Collection(models.LifetimeUsage{}.CollectionName()).FindOne(ctx, bson.M{ + "user_id": userID, + "project_id": projectID, + }).Decode(&lifetime) + assert.NoError(t, err) + assert.InDelta(t, failedCost, lifetime.FailedCost, 1e-9) + assert.InDelta(t, successCost, lifetime.SuccessCost, 1e-9) +} + +// TestTrackUsage_ZeroCostNoOp verifies that a zero-cost failed completion +// (e.g., the provider never returned a usage chunk) writes nothing. +func TestTrackUsage_ZeroCostNoOp(t *testing.T) { + us, database := setupTestUsageService(t) + ctx := context.Background() + + userID := bson.NewObjectID() + projectID := "test-project-" + bson.NewObjectID().Hex() + + err := us.TrackUsage(ctx, userID, projectID, 0, false) + assert.NoError(t, err) + + count, err := database.Collection(models.LifetimeUsage{}.CollectionName()).CountDocuments(ctx, bson.M{ + "user_id": userID, + "project_id": projectID, + }) + assert.NoError(t, err) + assert.Equal(t, int64(0), count) +} diff --git a/internal/wire.go b/internal/wire.go index f823bc2e..0ec32146 100644 --- a/internal/wire.go +++ b/internal/wire.go @@ -43,6 +43,7 @@ var Set = wire.NewSet( services.NewProjectService, services.NewPromptService, services.NewOAuthService, + services.NewUsageService, cfg.GetCfg, logger.GetLogger, diff --git a/internal/wire_gen.go b/internal/wire_gen.go index 75c4e91a..20ed866d 100644 --- a/internal/wire_gen.go +++ b/internal/wire_gen.go @@ -38,7 +38,8 @@ func InitializeApp() (*api.Server, error) { aiClient := client.NewAIClient(dbDB, reverseCommentService, projectService, cfgCfg, loggerLogger) chatService := services.NewChatService(dbDB, cfgCfg, loggerLogger) chatServiceServer := chat.NewChatServer(aiClient, chatService, projectService, userService, loggerLogger, cfgCfg) - aiClientV2 := client.NewAIClientV2(dbDB, reverseCommentService, projectService, cfgCfg, loggerLogger) + usageService := services.NewUsageService(dbDB, cfgCfg, loggerLogger) + aiClientV2 := client.NewAIClientV2(dbDB, reverseCommentService, projectService, usageService, cfgCfg, loggerLogger) chatServiceV2 := services.NewChatServiceV2(dbDB, cfgCfg, loggerLogger) chatv2ChatServiceServer := chat.NewChatServerV2(aiClientV2, chatServiceV2, projectService, userService, loggerLogger, cfgCfg) promptService := services.NewPromptService(dbDB, cfgCfg, loggerLogger) @@ -55,4 +56,4 @@ func InitializeApp() (*api.Server, error) { // wire.go: -var Set = wire.NewSet(api.NewServer, api.NewGrpcServer, api.NewGinServer, auth.NewOAuthHandler, auth.NewAuthServer, chat.NewChatServer, chat.NewChatServerV2, user.NewUserServer, project.NewProjectServer, comment.NewCommentServer, client.NewAIClient, client.NewAIClientV2, services.NewReverseCommentService, services.NewChatService, services.NewChatServiceV2, services.NewTokenService, services.NewUserService, services.NewProjectService, services.NewPromptService, services.NewOAuthService, cfg.GetCfg, logger.GetLogger, db.NewDB) +var Set = wire.NewSet(api.NewServer, api.NewGrpcServer, api.NewGinServer, auth.NewOAuthHandler, auth.NewAuthServer, chat.NewChatServer, chat.NewChatServerV2, user.NewUserServer, project.NewProjectServer, comment.NewCommentServer, client.NewAIClient, client.NewAIClientV2, services.NewReverseCommentService, services.NewChatService, services.NewChatServiceV2, services.NewTokenService, services.NewUserService, services.NewProjectService, services.NewPromptService, services.NewOAuthService, services.NewUsageService, cfg.GetCfg, logger.GetLogger, db.NewDB) diff --git a/webapp/_webapp/src/libs/apiclient.ts b/webapp/_webapp/src/libs/apiclient.ts index bf38e694..65343077 100644 --- a/webapp/_webapp/src/libs/apiclient.ts +++ b/webapp/_webapp/src/libs/apiclient.ts @@ -136,6 +136,13 @@ class ApiClient { } catch (error) { if (error instanceof AxiosError) { const errorData = error.response?.data; + if (!errorData || typeof errorData !== "object") { + const message = error.message || "Network error"; + if (!options?.ignoreErrorToast) { + errorToast(message, "Request Failed"); + } + throw new Error(message); + } const errorPayload = fromJson(ErrorSchema, errorData); if (!options?.ignoreErrorToast) { const message = this.cleanErrorMessage(errorPayload.message);