From a512091acba02b6088304a46ef41bff7791d77d3 Mon Sep 17 00:00:00 2001 From: jiahui Date: Sat, 14 Sep 2024 14:40:18 +0800 Subject: [PATCH] fix get payment with invoice arg --- controllers/pkg/database/cockroach/accountv2.go | 8 ++++---- service/account/dao/interface.go | 14 ++++++++++++-- service/account/helper/request.go | 3 +-- service/account/router/router.go | 6 ++++++ 4 files changed, 23 insertions(+), 8 deletions(-) diff --git a/controllers/pkg/database/cockroach/accountv2.go b/controllers/pkg/database/cockroach/accountv2.go index 19233ebd91a..192585a6b4c 100644 --- a/controllers/pkg/database/cockroach/accountv2.go +++ b/controllers/pkg/database/cockroach/accountv2.go @@ -651,7 +651,7 @@ func (c *Cockroach) GetPaymentWithID(paymentID string) (*types.Payment, error) { return &payment, nil } -func (c *Cockroach) GetPaymentWithLimit(ops *types.UserQueryOpts, req types.LimitReq, invoiced bool) ([]types.Payment, types.LimitResp, error) { +func (c *Cockroach) GetPaymentWithLimit(ops *types.UserQueryOpts, req types.LimitReq, invoiced *bool) ([]types.Payment, types.LimitResp, error) { var payment []types.Payment var total int64 var limitResp types.LimitResp @@ -662,10 +662,10 @@ func (c *Cockroach) GetPaymentWithLimit(ops *types.UserQueryOpts, req types.Limi } queryPayment := types.Payment{PaymentRaw: types.PaymentRaw{UserUID: userUID}} - if invoiced { - queryPayment.InvoicedAt = true - } query := c.DB.Model(&types.Payment{}).Where(queryPayment) + if invoiced != nil { + query = query.Where("invoiced_at = ?", *invoiced) + } if !req.StartTime.IsZero() { query = query.Where("created_at >= ?", req.StartTime) } diff --git a/service/account/dao/interface.go b/service/account/dao/interface.go index 2f0d7a17e99..dd392fea5c5 100644 --- a/service/account/dao/interface.go +++ b/service/account/dao/interface.go @@ -1055,11 +1055,21 @@ func (m *MongoDB) getAppStoreList(req helper.GetCostAppListReq, skip, pageSize i return } -func (m *MongoDB) Disconnect(ctx context.Context) error { +func (m *Account) Disconnect(ctx context.Context) error { if m == nil { return nil } - return m.Client.Disconnect(ctx) + if m.MongoDB != nil && m.MongoDB.Client != nil { + if err := m.MongoDB.Client.Disconnect(ctx); err != nil { + return fmt.Errorf("failed to close mongodb client: %v", err) + } + } + if m.Cockroach != nil && m.Cockroach.ck != nil { + if err := m.ck.Close(); err != nil { + return fmt.Errorf("failed to close cockroach client: %v", err) + } + } + return nil } func (m *MongoDB) GetConsumptionAmount(req helper.ConsumptionRecordReq) (int64, error) { diff --git a/service/account/helper/request.go b/service/account/helper/request.go index ffda39c369c..a61612d70dc 100644 --- a/service/account/helper/request.go +++ b/service/account/helper/request.go @@ -142,7 +142,7 @@ type GetPaymentReq struct { // @Summary Invoiced // @Description Invoiced // @JSONSchema - Invoiced bool `json:"invoiced,omitempty" bson:"invoiced" example:"true"` + Invoiced *bool `json:"invoiced,omitempty" bson:"invoiced" example:"true"` // @Summary Authentication information // @Description Authentication information @@ -388,7 +388,6 @@ func ParseAppCostsReq(c *gin.Context) (*AppCostsReq, error) { return nil, fmt.Errorf("bind json error: %v", err) } setDefaultTimeRange(&userCosts.TimeRange) - userCosts.Owner = strings.TrimPrefix(userCosts.Owner, "ns-") return userCosts, nil } diff --git a/service/account/router/router.go b/service/account/router/router.go index 653199e8b11..9f57c470d4d 100644 --- a/service/account/router/router.go +++ b/service/account/router/router.go @@ -1,6 +1,7 @@ package router import ( + "context" "fmt" "log" "os" @@ -29,6 +30,11 @@ func RegisterPayRouter() { if err := dao.InitDB(); err != nil { log.Fatalf("Error initializing database: %v", err) } + defer func() { + if err := dao.DBClient.Disconnect(context.Background()); err != nil { + log.Fatalf("Error disconnecting database: %v", err) + } + }() // /account/v1alpha1/{/namespaces | /properties | {/costs | /costs/recharge | /costs/consumption | /costs/properties}} router.Group(helper.GROUP). POST(helper.GetHistoryNamespaces, api.GetBillingHistoryNamespaceList).