Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
127 changes: 115 additions & 12 deletions backend/bizpkg/config/modelmgr/deprecate_model_get.go
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@ import (
"os"
"path/filepath"
"strings"
"sync"
"time"

"github.com/cloudwego/eino-ext/components/model/deepseek"
Expand All @@ -37,13 +38,17 @@ import (
"github.com/coze-dev/coze-studio/backend/pkg/kvstore"
"github.com/coze-dev/coze-studio/backend/pkg/lang/conv"
"github.com/coze-dev/coze-studio/backend/pkg/logs"
"github.com/coze-dev/coze-studio/backend/pkg/safego"
"github.com/coze-dev/coze-studio/backend/types/consts"

"google.golang.org/genai"
"gopkg.in/yaml.v3"
)

var oldModels []*Model
var (
oldModels []*Model
oldModelsMu sync.RWMutex
)

func initOldModelConf(ctx context.Context, oss storage.Storage, c *ModelConfig) error {
wd, err := os.Getwd()
Expand All @@ -67,17 +72,6 @@ func initOldModelConf(ctx context.Context, oss storage.Storage, c *ModelConfig)
oldModels = append(oldModels, envModel)
}

for _, q := range oldModels {
if q.Provider.IconURI != "" {
url, err := oss.GetObjectUrl(ctx, q.Provider.IconURI)
if err != nil {
logs.CtxWarnf(ctx, "get model icon url failed, err: %v", err)
} else {
q.Provider.IconURL = url
}
}
}

for _, old := range oldModels {
if old.ID <= 0 {
logs.CtxWarnf(ctx, "model id is invalid, model: %v", old.ID)
Expand All @@ -104,9 +98,118 @@ func initOldModelConf(ctx context.Context, oss storage.Storage, c *ModelConfig)
return fmt.Errorf("get model by id failed, err: %w", err)
}

refreshOldModelIconURLCache(ctx, oss)
startOldModelIconURLCacheRefresh(ctx, oss)

return nil
}

func cloneOldModel(m *Model) *Model {
if m == nil || m.Model == nil {
return nil
}

clonedModel := *m.Model
if m.Model.Provider != nil {
clonedProvider := *m.Model.Provider
clonedModel.Provider = &clonedProvider
}

return &Model{Model: &clonedModel}
}

func getOldModelsForResponse() []*Model {
oldModelsMu.RLock()
defer oldModelsMu.RUnlock()

result := make([]*Model, 0, len(oldModels))
for _, old := range oldModels {
cloned := cloneOldModel(old)
if cloned == nil {
continue
}
result = append(result, cloned)
}

return result
}

func getOldModelsForResponseWithLimit(limit int) []*Model {
if limit <= 0 {
return []*Model{}
}

oldModelsMu.RLock()
defer oldModelsMu.RUnlock()

if limit > len(oldModels) {
limit = len(oldModels)
}

result := make([]*Model, 0, limit)
for idx := 0; idx < len(oldModels) && len(result) < limit; idx++ {
cloned := cloneOldModel(oldModels[idx])
if cloned == nil {
continue
}
result = append(result, cloned)
}

return result
}

func getOldModelByIDForResponse(id int64) (*Model, bool) {
oldModelsMu.RLock()
defer oldModelsMu.RUnlock()

for _, old := range oldModels {
if old == nil || old.Model == nil {
continue
}
if old.ID == id {
cloned := cloneOldModel(old)
if cloned == nil {
return nil, false
}
return cloned, true
}
}

return nil, false
}

func startOldModelIconURLCacheRefresh(ctx context.Context, oss storage.Storage) {
safego.Go(ctx, func() {
ticker := time.NewTicker(24 * time.Hour)
defer ticker.Stop()
for {
select {
case <-ticker.C:
refreshOldModelIconURLCache(context.Background(), oss)
}
}
})
}

func refreshOldModelIconURLCache(ctx context.Context, oss storage.Storage) {
oldModelsMu.Lock()
defer oldModelsMu.Unlock()

for _, q := range oldModels {
if q == nil || q.Provider == nil || q.Provider.IconURI == "" {
continue
}

url, err := oss.GetObjectUrl(ctx, q.Provider.IconURI)
if err != nil {
logs.CtxWarnf(ctx, "get model icon url failed, err: %v", err)
continue
}

q.Provider.IconURL = url
}
}

func initModelByEnv() (*Model, error) {
if os.Getenv("MODEL_PROTOCOL_0") == "" || os.Getenv("MODEL_OPENCOZE_ID_0") == "" {
return nil, nil
Expand Down
22 changes: 8 additions & 14 deletions backend/bizpkg/config/modelmgr/model_get.go
Original file line number Diff line number Diff line change
Expand Up @@ -95,7 +95,7 @@ func (c *ModelConfig) getModelList(ctx context.Context, includeDeleteModel bool)
}

if useOldModel {
return oldModels, nil
return getOldModelsForResponse(), nil
}

var allModels []*model.ModelInstance
Expand Down Expand Up @@ -126,10 +126,7 @@ func (c *ModelConfig) GetOnlineModelListWithLimit(ctx context.Context, limit int
}

if useOldModel {
if limit > len(oldModels) {
limit = len(oldModels)
}
return oldModels[:limit], nil
return getOldModelsForResponseWithLimit(limit), nil
}

allModels, err := query.ModelInstance.WithContext(ctx).Limit(limit).Find()
Expand All @@ -155,11 +152,9 @@ func (c *ModelConfig) MGetModelByID(ctx context.Context, ids []int64) ([]*Model,
if useOldModel {
modelList := make([]*Model, 0, len(ids))
for _, id := range ids {
for _, old := range oldModels {
if old.ID == id {
modelList = append(modelList, old)
break
}
old, ok := getOldModelByIDForResponse(id)
if ok {
modelList = append(modelList, old)
}
}
return modelList, nil
Expand Down Expand Up @@ -188,10 +183,9 @@ func (c *ModelConfig) GetModelByID(ctx context.Context, modelID int64) (*Model,
}

if useOldModel {
for _, old := range oldModels {
if old.ID == modelID {
return old, nil
}
old, ok := getOldModelByIDForResponse(modelID)
if ok {
return old, nil
}
return nil, fmt.Errorf("model %d not found", modelID)
}
Expand Down