Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
51 changes: 50 additions & 1 deletion api/endpoint.go
Original file line number Diff line number Diff line change
Expand Up @@ -33,12 +33,13 @@ type CreateEndpointInput struct {
WorkersMin int `json:"workersMin"`
WorkersMax int `json:"workersMax"`
FlashBootType string `json:"flashBootType"`
ModelReferences []string `json:"modelReferences"`
}

// there are many more fields in the result of the query but I just care about these for CLI port
type Endpoint struct {
Name string `json:"name"`
Id string
Id string `json:"id"`
}
type EndpointOut struct {
Data *EndpointData `json:"data"`
Expand Down Expand Up @@ -225,6 +226,53 @@ func UpdateEndpointTemplate(endpointId string, templateId string) (err error) {
return
}

func UpdateEndpointModel(endpointId string, endpointName string, modelReferences []string) (err error) {
input := Input{
Query: `
mutation saveEndpoint($input: EndpointInput!) {
saveEndpoint(input: $input) {
id
modelReferences
}
}
`,
Variables: map[string]interface{}{"input": map[string]interface{}{
"id": endpointId,
"name": endpointName,
"modelReferences": modelReferences,
}},
}
res, err := Query(input)
if err != nil {
return
}
defer res.Body.Close()
rawData, err := io.ReadAll(res.Body)
if err != nil {
return
}
if res.StatusCode != 200 {
err = fmt.Errorf("statuscode %d: %s", res.StatusCode, string(rawData))
return
}
data := make(map[string]interface{})
if err = json.Unmarshal(rawData, &data); err != nil {
return
}
gqlErrors, ok := data["errors"].([]interface{})
if ok && len(gqlErrors) > 0 {
firstErr, _ := gqlErrors[0].(map[string]interface{})
err = errors.New(firstErr["message"].(string))
return
}
gqldata, ok := data["data"].(map[string]interface{})
if !ok || gqldata == nil {
err = fmt.Errorf("data is nil: %s", string(rawData))
return
}
return
}

func GetEndpoints() (endpoints []*Endpoint, err error) {
input := Input{
Query: `
Expand All @@ -248,6 +296,7 @@ func GetEndpoints() (endpoints []*Endpoint, err error) {
workersMin
workersStandby
gpuCount
modelReferences
env {
key
value
Expand Down
15 changes: 15 additions & 0 deletions cmd/endpoint/endpoint.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,15 @@
package endpoint

import (
"github.com/spf13/cobra"
)

var Cmd = &cobra.Command{
Use: "endpoint",
Short: "manage serverless endpoints",
Long: "manage serverless endpoints on runpod",
}

func init() {
Cmd.AddCommand(modelCmd)
}
63 changes: 63 additions & 0 deletions cmd/endpoint/update_model.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,63 @@
package endpoint

import (
"fmt"

"github.com/runpod/runpodctl/api"
"github.com/spf13/cobra"
)

var clearModels bool

var modelCmd = &cobra.Command{
Use: "model <endpoint-id> [model-ref...]",
Short: "update model references on an endpoint",
Long: "set the model references (cached models) for a serverless endpoint; use --clear to remove all cached models",
Args: func(cmd *cobra.Command, args []string) error {
if clearModels {
return cobra.ExactArgs(1)(cmd, args)
}
return cobra.MinimumNArgs(2)(cmd, args)
},
RunE: runModel,
}

func init() {
modelCmd.Flags().BoolVar(&clearModels, "clear", false, "remove all model references from the endpoint")
}

func runModel(cmd *cobra.Command, args []string) error {
endpointID := args[0]

var modelRefs []string
if !clearModels {
modelRefs = args[1:]
}

endpoints, err := api.GetEndpoints()
if err != nil {
return fmt.Errorf("failed to list endpoints: %w", err)
}

var endpointName string
for _, ep := range endpoints {
if ep.Id == endpointID {
endpointName = ep.Name
break
}
}
if endpointName == "" {
return fmt.Errorf("endpoint %s not found", endpointID)
}

if err := api.UpdateEndpointModel(endpointID, endpointName, modelRefs); err != nil {
return fmt.Errorf("failed to update endpoint model: %w", err)
}

if clearModels {
fmt.Printf("cleared model references for endpoint %s (%s)\n", endpointName, endpointID)
} else {
fmt.Printf("updated model references for endpoint %s (%s)\n", endpointName, endpointID)
}
return nil
}
9 changes: 9 additions & 0 deletions cmd/project/functions.go
Original file line number Diff line number Diff line change
Expand Up @@ -581,6 +581,7 @@ func deployProject(networkVolumeId string) (endpointId string, err error) {
flashBootType := "FLASHBOOT"
idleTimeout := 5
endpointConfig, ok := config.Get("endpoint").(*toml.Tree)
var modelRefs []string
if ok {
if min, ok := endpointConfig.Get("active_workers").(int64); ok {
minWorkers = int(min)
Expand All @@ -597,6 +598,13 @@ func deployProject(networkVolumeId string) (endpointId string, err error) {
if idle, ok := endpointConfig.Get("idle_timeout").(int64); ok {
idleTimeout = int(idle)
}
if refs, ok := endpointConfig.Get("model_refs").([]interface{}); ok {
for _, r := range refs {
if str, ok := r.(string); ok {
modelRefs = append(modelRefs, str)
}
}
}
}
if err != nil {
deployedEndpointId, err = api.CreateEndpoint(&api.CreateEndpointInput{
Expand All @@ -610,6 +618,7 @@ func deployProject(networkVolumeId string) (endpointId string, err error) {
WorkersMin: minWorkers,
WorkersMax: maxWorkers,
FlashBootType: flashBootType,
ModelReferences: modelRefs,
})
if err != nil {
fmt.Println("error making endpoint")
Expand Down
5 changes: 5 additions & 0 deletions cmd/project/tomlBuilder.go
Original file line number Diff line number Diff line change
Expand Up @@ -77,6 +77,11 @@ active_workers = 0
max_workers = 3
flashboot = true

# model_refs - List of model references to cache on the endpoint workers.
# Format: "owner/model-name" or "owner/model-name:branch".
# Example: ["runpod/stable-diffusion-v1-5", "meta-llama/Llama-2-7b-chat-hf"]
# model_refs = []

[runtime]
# python_version - Python version to use for the project.
#
Expand Down
2 changes: 2 additions & 0 deletions cmd/root.go
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@ import (
"github.com/runpod/runpodctl/cmd/billing"
"github.com/runpod/runpodctl/cmd/config"
"github.com/runpod/runpodctl/cmd/datacenter"
"github.com/runpod/runpodctl/cmd/endpoint"
"github.com/runpod/runpodctl/cmd/doctor"
"github.com/runpod/runpodctl/cmd/gpu"
"github.com/runpod/runpodctl/cmd/hub"
Expand Down Expand Up @@ -86,6 +87,7 @@ func registerCommands() {
rootCmd.AddCommand(pod.Cmd)
rootCmd.AddCommand(serverless.Cmd)
rootCmd.AddCommand(template.Cmd)
rootCmd.AddCommand(endpoint.Cmd)
rootCmd.AddCommand(model.Cmd)
rootCmd.AddCommand(volume.Cmd)
rootCmd.AddCommand(registry.Cmd)
Expand Down
Loading