Skip to content

Commit

Permalink
chore: gopool
Browse files Browse the repository at this point in the history
  • Loading branch information
Calcium-Ion committed Jul 18, 2024
1 parent c9100b2 commit e84300f
Show file tree
Hide file tree
Showing 14 changed files with 77 additions and 115 deletions.
5 changes: 3 additions & 2 deletions controller/channel-test.go
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@ import (
"encoding/json"
"errors"
"fmt"
"github.com/bytedance/gopkg/util/gopool"
"io"
"math"
"net/http"
Expand Down Expand Up @@ -217,7 +218,7 @@ func testAllChannels(notify bool) error {
if disableThreshold == 0 {
disableThreshold = 10000000 // a impossible value
}
go func() {
gopool.Go(func() {
for _, channel := range channels {
isChannelEnabled := channel.Status == common.ChannelStatusEnabled
tik := time.Now()
Expand Down Expand Up @@ -265,7 +266,7 @@ func testAllChannels(notify bool) error {
common.SysError(fmt.Sprintf("failed to send email: %s", err.Error()))
}
}
}()
})
return nil
}

Expand Down
1 change: 1 addition & 0 deletions go.mod
Original file line number Diff line number Diff line change
Expand Up @@ -38,6 +38,7 @@ require (
github.com/aws/aws-sdk-go-v2/internal/configsources v1.3.5 // indirect
github.com/aws/aws-sdk-go-v2/internal/endpoints/v2 v2.6.5 // indirect
github.com/aws/smithy-go v1.20.2 // indirect
github.com/bytedance/gopkg v0.0.0-20220118071334-3db87571198b // indirect
github.com/bytedance/sonic v1.9.1 // indirect
github.com/cespare/xxhash/v2 v2.1.2 // indirect
github.com/chenzhuoyu/base64x v0.0.0-20221115062448-fe3a3abad311 // indirect
Expand Down
4 changes: 4 additions & 0 deletions go.sum
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,8 @@ github.com/aws/aws-sdk-go-v2/service/bedrockruntime v1.7.4 h1:JgHnonzbnA3pbqj76w
github.com/aws/aws-sdk-go-v2/service/bedrockruntime v1.7.4/go.mod h1:nZspkhg+9p8iApLFoyAqfyuMP0F38acy2Hm3r5r95Cg=
github.com/aws/smithy-go v1.20.2 h1:tbp628ireGtzcHDDmLT/6ADHidqnwgF57XOXZe6tp4Q=
github.com/aws/smithy-go v1.20.2/go.mod h1:krry+ya/rV9RDcV/Q16kpu6ypI4K2czasz0NC3qS14E=
github.com/bytedance/gopkg v0.0.0-20220118071334-3db87571198b h1:LTGVFpNmNHhj0vhOlfgWueFJ32eK9blaIlHR2ciXOT0=
github.com/bytedance/gopkg v0.0.0-20220118071334-3db87571198b/go.mod h1:2ZlV9BaUH4+NXIBF0aMdKKAnHTzqH+iMU4KUjAbL23Q=
github.com/bytedance/sonic v1.5.0/go.mod h1:ED5hyg4y6t3/9Ku1R6dU/4KyJ48DZ4jPhfY1O2AihPM=
github.com/bytedance/sonic v1.9.1 h1:6iJ6NqdoxCDr6mbY8h18oSO+cShGSMRGCEo7F2h0x8s=
github.com/bytedance/sonic v1.9.1/go.mod h1:i736AoUSYt75HyZLoJW9ERYxcy6eaN6h4BZXU064P/U=
Expand Down Expand Up @@ -198,6 +200,7 @@ golang.org/x/image v0.15.0/go.mod h1:HUYqC05R2ZcZ3ejNQsIHQDQiwWM4JBqmm6MKANTp4LE
golang.org/x/net v0.0.0-20210226172049-e18ecbb05110/go.mod h1:m0MpNAwzfU5UDzcl9v0D8zg8gWTRqZa9RBIspLL5mdg=
golang.org/x/net v0.21.0 h1:AQyQV4dYCvJ7vGmJyKki9+PBdyvhkSd8EIx/qb0AYv4=
golang.org/x/net v0.21.0/go.mod h1:bIjVDfnllIU7BJ2DNgfnXvpSvtn8VRwhlsaeUTyUS44=
golang.org/x/sync v0.0.0-20210220032951-036812b2e83c/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM=
golang.org/x/sync v0.7.0 h1:YsImfSBoP9QPYL0xyKJPq0gcaJdG3rInoqxTWbfQu9M=
golang.org/x/sync v0.7.0/go.mod h1:Czt+wKu1gCyEFDUtn0jG5QVvpJ6rzVqr5aXyt9drQfk=
golang.org/x/sys v0.0.0-20190916202348-b4ddaad3f8a3/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=
Expand All @@ -206,6 +209,7 @@ golang.org/x/sys v0.0.0-20201119102817-f84b799fce68/go.mod h1:h1NjWce9XRLGQEsW7w
golang.org/x/sys v0.0.0-20210615035016-665e8c7367d1/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
golang.org/x/sys v0.0.0-20210630005230-0f9fa26af87c/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
golang.org/x/sys v0.0.0-20210806184541-e5e7981a1069/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
golang.org/x/sys v0.0.0-20220110181412-a018aaa089fe/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
golang.org/x/sys v0.0.0-20220704084225-05e143d24a9e/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
golang.org/x/sys v0.6.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
golang.org/x/sys v0.8.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
Expand Down
5 changes: 3 additions & 2 deletions main.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@ package main
import (
"embed"
"fmt"
"github.com/bytedance/gopkg/util/gopool"
"github.com/gin-contrib/sessions"
"github.com/gin-contrib/sessions/cookie"
"github.com/gin-gonic/gin"
Expand Down Expand Up @@ -91,10 +92,10 @@ func main() {
go controller.AutomaticallyTestChannels(frequency)
}
if common.IsMasterNode && constant.UpdateTask {
common.SafeGoroutine(func() {
gopool.Go(func() {
controller.UpdateMidjourneyTaskBulk()
})
common.SafeGoroutine(func() {
gopool.Go(func() {
controller.UpdateTaskBulk()
})
}
Expand Down
3 changes: 2 additions & 1 deletion model/log.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@ package model
import (
"context"
"fmt"
"github.com/bytedance/gopkg/util/gopool"
"gorm.io/gorm"
"one-api/common"
"strings"
Expand Down Expand Up @@ -87,7 +88,7 @@ func RecordConsumeLog(ctx context.Context, userId int, channelId int, promptToke
common.LogError(ctx, "failed to record log: "+err.Error())
}
if common.DataExportEnabled {
common.SafeGoroutine(func() {
gopool.Go(func() {
LogQuotaData(userId, username, modelName, quota, common.GetTimestamp(), promptTokens+completionTokens)
})
}
Expand Down
5 changes: 3 additions & 2 deletions model/utils.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@ package model

import (
"errors"
"github.com/bytedance/gopkg/util/gopool"
"gorm.io/gorm"
"one-api/common"
"sync"
Expand All @@ -28,12 +29,12 @@ func init() {
}

func InitBatchUpdater() {
go func() {
gopool.Go(func() {
for {
time.Sleep(time.Duration(common.BatchUpdateInterval) * time.Second)
batchUpdate()
}
}()
})
}

func addNewRecord(type_ int, id int, value int) {
Expand Down
2 changes: 1 addition & 1 deletion relay/channel/ali/adaptor.go
Original file line number Diff line number Diff line change
Expand Up @@ -84,7 +84,7 @@ func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response, info *relaycom
err, usage = aliEmbeddingHandler(c, resp)
default:
if info.IsStream {
err, usage = openai.OpenaiStreamHandler(c, resp, info)
err, usage = openai.OaiStreamHandler(c, resp, info)
} else {
err, usage = openai.OpenaiHandler(c, resp, info.PromptTokens, info.UpstreamModelName)
}
Expand Down
129 changes: 46 additions & 83 deletions relay/channel/claude/relay-claude.go
Original file line number Diff line number Diff line change
Expand Up @@ -8,12 +8,10 @@ import (
"io"
"net/http"
"one-api/common"
"one-api/constant"
"one-api/dto"
relaycommon "one-api/relay/common"
"one-api/service"
"strings"
"time"
)

func stopReasonClaude2OpenAI(reason string) string {
Expand Down Expand Up @@ -332,91 +330,59 @@ func claudeStreamHandler(c *gin.Context, resp *http.Response, info *relaycommon.
responseText := ""
createdTime := common.GetTimestamp()
scanner := bufio.NewScanner(resp.Body)
scanner.Split(func(data []byte, atEOF bool) (advance int, token []byte, err error) {
if atEOF && len(data) == 0 {
return 0, nil, nil
}
if i := strings.Index(string(data), "\n"); i >= 0 {
return i + 1, data[0:i], nil
scanner.Split(bufio.ScanLines)
service.SetEventStreamHeaders(c)

for scanner.Scan() {
data := scanner.Text()
info.SetFirstResponseTime()
if len(data) < 6 || !strings.HasPrefix(data, "data:") {
continue
}
if atEOF {
return len(data), data, nil
data = strings.TrimPrefix(data, "data:")
data = strings.TrimSpace(data)
var claudeResponse ClaudeResponse
err := json.Unmarshal([]byte(data), &claudeResponse)
if err != nil {
common.SysError("error unmarshalling stream response: " + err.Error())
continue
}
return 0, nil, nil
})
dataChan := make(chan string, 5)
stopChan := make(chan bool, 2)
go func() {
for scanner.Scan() {
data := scanner.Text()
if !strings.HasPrefix(data, "data: ") {
continue
}
data = strings.TrimPrefix(data, "data: ")
if !common.SafeSendStringTimeout(dataChan, data, constant.StreamingTimeout) {
// send data timeout, stop the stream
common.LogError(c, "send data timeout, stop the stream")
break
}

response, claudeUsage := StreamResponseClaude2OpenAI(requestMode, &claudeResponse)
if response == nil {
continue
}
stopChan <- true
}()
isFirst := true
service.SetEventStreamHeaders(c)
c.Stream(func(w io.Writer) bool {
select {
case data := <-dataChan:
if isFirst {
isFirst = false
info.FirstResponseTime = time.Now()
}
// some implementations may add \r at the end of data
data = strings.TrimSuffix(data, "\r")
var claudeResponse ClaudeResponse
err := json.Unmarshal([]byte(data), &claudeResponse)
if err != nil {
common.SysError("error unmarshalling stream response: " + err.Error())
return true
}
if requestMode == RequestModeCompletion {
responseText += claudeResponse.Completion
responseId = response.Id
} else {
if claudeResponse.Type == "message_start" {
// message_start, 获取usage
responseId = claudeResponse.Message.Id
info.UpstreamModelName = claudeResponse.Message.Model
usage.PromptTokens = claudeUsage.InputTokens
} else if claudeResponse.Type == "content_block_delta" {
responseText += claudeResponse.Delta.Text
} else if claudeResponse.Type == "message_delta" {
usage.CompletionTokens = claudeUsage.OutputTokens
usage.TotalTokens = claudeUsage.InputTokens + claudeUsage.OutputTokens
} else if claudeResponse.Type == "content_block_start" {

response, claudeUsage := StreamResponseClaude2OpenAI(requestMode, &claudeResponse)
if response == nil {
return true
}
if requestMode == RequestModeCompletion {
responseText += claudeResponse.Completion
responseId = response.Id
} else {
if claudeResponse.Type == "message_start" {
// message_start, 获取usage
responseId = claudeResponse.Message.Id
info.UpstreamModelName = claudeResponse.Message.Model
usage.PromptTokens = claudeUsage.InputTokens
} else if claudeResponse.Type == "content_block_delta" {
responseText += claudeResponse.Delta.Text
} else if claudeResponse.Type == "message_delta" {
usage.CompletionTokens = claudeUsage.OutputTokens
usage.TotalTokens = claudeUsage.InputTokens + claudeUsage.OutputTokens
} else if claudeResponse.Type == "content_block_start" {

} else {
return true
}
continue
}
//response.Id = responseId
response.Id = responseId
response.Created = createdTime
response.Model = info.UpstreamModelName
}
//response.Id = responseId
response.Id = responseId
response.Created = createdTime
response.Model = info.UpstreamModelName

err = service.ObjectData(c, response)
if err != nil {
common.SysError(err.Error())
}
return true
case <-stopChan:
return false
err = service.ObjectData(c, response)
if err != nil {
common.LogError(c, "send_stream_response_failed: "+err.Error())
}
})
}

if requestMode == RequestModeCompletion {
usage, _ = service.ResponseText2Usage(responseText, info.UpstreamModelName, info.PromptTokens)
} else {
Expand All @@ -435,10 +401,7 @@ func claudeStreamHandler(c *gin.Context, resp *http.Response, info *relaycommon.
}
}
service.Done(c)
err := resp.Body.Close()
if err != nil {
return service.OpenAIErrorWrapperLocal(err, "close_response_body_failed", http.StatusInternalServerError), nil
}
resp.Body.Close()
return nil, usage
}

Expand Down
2 changes: 1 addition & 1 deletion relay/channel/ollama/adaptor.go
Original file line number Diff line number Diff line change
Expand Up @@ -64,7 +64,7 @@ func (a *Adaptor) DoRequest(c *gin.Context, info *relaycommon.RelayInfo, request

func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo) (usage *dto.Usage, err *dto.OpenAIErrorWithStatusCode) {
if info.IsStream {
err, usage = openai.OpenaiStreamHandler(c, resp, info)
err, usage = openai.OaiStreamHandler(c, resp, info)
} else {
if info.RelayMode == relayconstant.RelayModeEmbeddings {
err, usage = ollamaEmbeddingHandler(c, resp, info.PromptTokens, info.UpstreamModelName, info.RelayMode)
Expand Down
2 changes: 1 addition & 1 deletion relay/channel/openai/adaptor.go
Original file line number Diff line number Diff line change
Expand Up @@ -145,7 +145,7 @@ func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response, info *relaycom
err, usage = OpenaiTTSHandler(c, resp, info)
default:
if info.IsStream {
err, usage = OpenaiStreamHandler(c, resp, info)
err, usage = OaiStreamHandler(c, resp, info)
} else {
err, usage = OpenaiHandler(c, resp, info.PromptTokens, info.UpstreamModelName)
}
Expand Down
17 changes: 9 additions & 8 deletions relay/channel/openai/relay-openai.go
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@ import (
"bytes"
"encoding/json"
"fmt"
"github.com/bytedance/gopkg/util/gopool"
"github.com/gin-gonic/gin"
"io"
"net/http"
Expand All @@ -18,8 +19,8 @@ import (
"time"
)

func OpenaiStreamHandler(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo) (*dto.OpenAIErrorWithStatusCode, *dto.Usage) {
hasStreamUsage := false
func OaiStreamHandler(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo) (*dto.OpenAIErrorWithStatusCode, *dto.Usage) {
containStreamUsage := false
responseId := ""
var createAt int64 = 0
var systemFingerprint string
Expand All @@ -41,7 +42,7 @@ func OpenaiStreamHandler(c *gin.Context, resp *http.Response, info *relaycommon.
stopChan := make(chan bool)
defer close(stopChan)

go func() {
gopool.Go(func() {
for scanner.Scan() {
info.SetFirstResponseTime()
ticker.Reset(time.Duration(constant.StreamingTimeout) * time.Second)
Expand All @@ -62,7 +63,7 @@ func OpenaiStreamHandler(c *gin.Context, resp *http.Response, info *relaycommon.
}
}
common.SafeSendBool(stopChan, true)
}()
})

select {
case <-ticker.C:
Expand Down Expand Up @@ -91,7 +92,7 @@ func OpenaiStreamHandler(c *gin.Context, resp *http.Response, info *relaycommon.
model = streamResponse.Model
if service.ValidUsage(streamResponse.Usage) {
usage = streamResponse.Usage
hasStreamUsage = true
containStreamUsage = true
}
for _, choice := range streamResponse.Choices {
responseTextBuilder.WriteString(choice.Delta.GetContentString())
Expand All @@ -115,7 +116,7 @@ func OpenaiStreamHandler(c *gin.Context, resp *http.Response, info *relaycommon.
model = streamResponse.Model
if service.ValidUsage(streamResponse.Usage) {
usage = streamResponse.Usage
hasStreamUsage = true
containStreamUsage = true
}
for _, choice := range streamResponse.Choices {
responseTextBuilder.WriteString(choice.Delta.GetContentString())
Expand Down Expand Up @@ -155,12 +156,12 @@ func OpenaiStreamHandler(c *gin.Context, resp *http.Response, info *relaycommon.
}
}

if !hasStreamUsage {
if !containStreamUsage {
usage, _ = service.ResponseText2Usage(responseTextBuilder.String(), info.UpstreamModelName, info.PromptTokens)
usage.CompletionTokens += toolCount * 7
}

if info.ShouldIncludeUsage && !hasStreamUsage {
if info.ShouldIncludeUsage && !containStreamUsage {
response := service.GenerateFinalUsageResponse(responseId, createAt, model, *usage)
response.SetSystemFingerprint(systemFingerprint)
service.ObjectData(c, response)
Expand Down
2 changes: 1 addition & 1 deletion relay/channel/perplexity/adaptor.go
Original file line number Diff line number Diff line change
Expand Up @@ -58,7 +58,7 @@ func (a *Adaptor) DoRequest(c *gin.Context, info *relaycommon.RelayInfo, request

func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo) (usage *dto.Usage, err *dto.OpenAIErrorWithStatusCode) {
if info.IsStream {
err, usage = openai.OpenaiStreamHandler(c, resp, info)
err, usage = openai.OaiStreamHandler(c, resp, info)
} else {
err, usage = openai.OpenaiHandler(c, resp, info.PromptTokens, info.UpstreamModelName)
}
Expand Down
Loading

0 comments on commit e84300f

Please sign in to comment.