diff --git a/.github/workflows/ci.yml b/.github/workflows/ci.yml index 4660bf34..870751bf 100644 --- a/.github/workflows/ci.yml +++ b/.github/workflows/ci.yml @@ -5,10 +5,10 @@ jobs: runs-on: ubuntu-latest steps: - uses: actions/checkout@v4 - - name: Setup Go 1.26.3 + - name: Setup Go 1.26.4 uses: actions/setup-go@v5 with: - go-version: 1.26.3 + go-version: 1.26.4 # You can test your matrix by printing the current Go version - name: Display Go version run: go version diff --git a/.github/workflows/release.yml b/.github/workflows/release.yml index 3e666ed5..708f74dc 100644 --- a/.github/workflows/release.yml +++ b/.github/workflows/release.yml @@ -34,7 +34,7 @@ jobs: - name: Set up Go uses: actions/setup-go@v5 with: - go-version: "1.26.3" + go-version: "1.26.4" - name: Run GoReleaser uses: goreleaser/goreleaser-action@v6 diff --git a/api/model.go b/api/model.go index a9b4e928..c0637c11 100644 --- a/api/model.go +++ b/api/model.go @@ -290,7 +290,7 @@ func AddModelToRepo(input *AddModelToRepoInput) (*Model, error) { if err != nil { return nil, err } - if res.StatusCode != 200 { + if res.StatusCode != http.StatusOK { return nil, fmt.Errorf("statuscode %d: %s", res.StatusCode, string(rawData)) } @@ -374,7 +374,7 @@ query %s { if err != nil { return nil, err } - if res.StatusCode != 200 { + if res.StatusCode != http.StatusOK { return nil, fmt.Errorf("statuscode %d: %s", res.StatusCode, string(rawData)) } @@ -404,7 +404,7 @@ query %s { } if models == nil { - return nil, fmt.Errorf("data is nil: %s", string(rawData)) + models = []*Model{} } if input != nil { if input.Provider != "" { @@ -492,7 +492,7 @@ func GetModel(input *GetModelInput) (*Model, error) { if err != nil { return nil, err } - if res.StatusCode != 200 { + if res.StatusCode != http.StatusOK { return nil, fmt.Errorf("statuscode %d: %s", res.StatusCode, string(rawData)) } @@ -580,7 +580,7 @@ func RemoveModel(input *RemoveModelInput) (*ModelRepoMutationResult, error) { if err != nil { return nil, err } - if 
res.StatusCode != 200 { + if res.StatusCode != http.StatusOK { return nil, fmt.Errorf("statuscode %d: %s", res.StatusCode, string(rawData)) } @@ -714,7 +714,7 @@ func CreateModelRepoUpload(input *CreateModelRepoUploadInput) (*ModelRepoMutatio if err != nil { return nil, err } - if res.StatusCode != 200 { + if res.StatusCode != http.StatusOK { return nil, fmt.Errorf("statuscode %d: %s", res.StatusCode, string(rawData)) } @@ -863,7 +863,7 @@ func UpdateModelVersionStatus(hash, status string) (*ModelVersion, error) { if err != nil { return nil, err } - if res.StatusCode != 200 { + if res.StatusCode != http.StatusOK { return nil, fmt.Errorf("statuscode %d: %s", res.StatusCode, string(rawData)) } diff --git a/cmd/model/addModelToRepo.go b/cmd/model/addModelToRepo.go index c0881166..ca99dad4 100644 --- a/cmd/model/addModelToRepo.go +++ b/cmd/model/addModelToRepo.go @@ -56,7 +56,7 @@ func setModelGraphQLTimeout(cmd *cobra.Command) { } viper.Set(api.GraphQLTimeoutKey, modelGraphQLTimeoutValue) - fmt.Printf("--graphql-timeout not set; defaulting to %s for model creation operations\n", modelGraphQLTimeoutValue) + fmt.Fprintf(os.Stderr, "--graphql-timeout not set; defaulting to %s for model creation operations\n", modelGraphQLTimeoutValue) return } @@ -68,7 +68,7 @@ func setModelGraphQLTimeout(cmd *cobra.Command) { } viper.Set(api.GraphQLTimeoutKey, modelGraphQLTimeoutValue) - fmt.Printf("defaulting graphql timeout to %s for model creation operations\n", modelGraphQLTimeoutValue) + fmt.Fprintf(os.Stderr, "defaulting graphql timeout to %s for model creation operations\n", modelGraphQLTimeoutValue) } type completedPart struct { @@ -207,7 +207,7 @@ func runAddModel(cmd *cobra.Command, args []string) { uploadInput.Name = addModelName if len(modelFiles) > 0 { - err := uploadModelFiles(modelFiles, uploadInput) + err = uploadModelFiles(modelFiles, uploadInput) cobra.CheckErr(err) return } @@ -400,7 +400,7 @@ func completeModelUpload(upload *api.ModelRepoUpload, artifactPath string) 
error err = fmt.Errorf("upload part %d missing ETag", part.PartNumber) return } - completed = append(completed, completedPart{PartNumber: part.PartNumber, ETag: fmt.Sprintf("\"%s\"", etag)}) + completed = append(completed, completedPart{PartNumber: part.PartNumber, ETag: fmt.Sprintf("%q", etag)}) }() if err != nil { return err diff --git a/cmd/model/addModelToRepo_timeout_test.go b/cmd/model/addModelToRepo_timeout_test.go index f3a28417..757e85d8 100644 --- a/cmd/model/addModelToRepo_timeout_test.go +++ b/cmd/model/addModelToRepo_timeout_test.go @@ -13,11 +13,21 @@ func TestSetModelGraphQLTimeoutWithoutInheritedFlag(t *testing.T) { viper.Reset() cmd := &cobra.Command{Use: "add"} - setModelGraphQLTimeout(cmd) + // CLAUDE.md: informational output must not corrupt stdout for JSON + // consumers — the "defaulting graphql timeout" notice must land on stderr. + stdout, stderr := captureStdStreams(t, func() { + setModelGraphQLTimeout(cmd) + }) if got := viper.GetDuration(api.GraphQLTimeoutKey); got != modelGraphQLTimeoutValue { t.Fatalf("expected graphql timeout %s, got %s", modelGraphQLTimeoutValue, got) } + if stdout != "" { + t.Fatalf("stdout must remain empty, got %q", stdout) + } + if stderr == "" { + t.Fatal("expected timeout-default notice on stderr, got empty") + } } func TestSetModelGraphQLTimeoutRespectsExistingConfiguredValue(t *testing.T) { diff --git a/cmd/model/errors.go b/cmd/model/errors.go index f03caf86..7c8da2fe 100644 --- a/cmd/model/errors.go +++ b/cmd/model/errors.go @@ -2,10 +2,10 @@ package model import ( "errors" - "fmt" "strings" "github.com/runpod/runpodctl/api" + "github.com/runpod/runpodctl/internal/output" ) func handleModelRepoError(err error) bool { @@ -13,11 +13,11 @@ func handleModelRepoError(err error) bool { return false } if errors.Is(err, api.ErrModelRepoNotImplemented) { - fmt.Println(api.ErrModelRepoNotImplemented.Error()) + output.Error(api.ErrModelRepoNotImplemented) return true } if strings.Contains(err.Error(), "Model Repo 
feature is not enabled for this user") { - fmt.Println(err.Error()) + output.Error(err) return true } return false diff --git a/cmd/model/errors_test.go b/cmd/model/errors_test.go new file mode 100644 index 00000000..ddaba2d0 --- /dev/null +++ b/cmd/model/errors_test.go @@ -0,0 +1,122 @@ +package model + +import ( + "errors" + "io" + "os" + "strings" + "sync" + "testing" + + "github.com/runpod/runpodctl/api" +) + +// captureStdStreams runs fn with os.Stdout and os.Stderr replaced by pipes and +// returns whatever each stream received. It exists to assert that informational +// and error output goes to the correct stream — a regression class CLAUDE.md +// explicitly calls out (legacy commands losing stderr ⇒ corrupts stdout for +// JSON-consuming agents). +func captureStdStreams(t *testing.T, fn func()) (stdout, stderr string) { + t.Helper() + + origStdout, origStderr := os.Stdout, os.Stderr + stdoutR, stdoutW, err := os.Pipe() + if err != nil { + t.Fatalf("os.Pipe stdout: %v", err) + } + stderrR, stderrW, err := os.Pipe() + if err != nil { + t.Fatalf("os.Pipe stderr: %v", err) + } + os.Stdout, os.Stderr = stdoutW, stderrW + + var ( + wg sync.WaitGroup + stdoutBuf, stderrBuf strings.Builder + ) + wg.Add(2) + go func() { + defer wg.Done() + _, _ = io.Copy(&stdoutBuf, stdoutR) + }() + go func() { + defer wg.Done() + _, _ = io.Copy(&stderrBuf, stderrR) + }() + + defer func() { + os.Stdout, os.Stderr = origStdout, origStderr + }() + + fn() + + _ = stdoutW.Close() + _ = stderrW.Close() + wg.Wait() + _ = stdoutR.Close() + _ = stderrR.Close() + + return stdoutBuf.String(), stderrBuf.String() +} + +func TestHandleModelRepoError(t *testing.T) { + tests := []struct { + name string + err error + wantHandled bool + wantStderrSub string // empty = stderr must be empty + wantStdoutSub string // empty = stdout must be empty + }{ + { + name: "nil error is a no-op", + err: nil, + wantHandled: false, + }, + { + name: "ErrModelRepoNotImplemented routes to stderr", + err: 
api.ErrModelRepoNotImplemented, + wantHandled: true, + wantStderrSub: api.ErrModelRepoNotImplemented.Error(), + }, + { + name: "feature-not-enabled message routes to stderr", + err: errors.New("Model Repo feature is not enabled for this user"), + wantHandled: true, + wantStderrSub: "Model Repo feature is not enabled for this user", + }, + { + name: "unrelated error is not handled", + err: errors.New("some other failure"), + wantHandled: false, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + var handled bool + stdout, stderr := captureStdStreams(t, func() { + handled = handleModelRepoError(tt.err) + }) + + if handled != tt.wantHandled { + t.Fatalf("handled = %v, want %v", handled, tt.wantHandled) + } + + // CLAUDE.md: deprecation warnings / handled errors must go to stderr + // only; stdout must stay clean for JSON-consuming agents. + if stdout != "" { + t.Fatalf("stdout must remain empty, got %q", stdout) + } + + if tt.wantStderrSub == "" { + if stderr != "" { + t.Fatalf("expected empty stderr, got %q", stderr) + } + return + } + if !strings.Contains(stderr, tt.wantStderrSub) { + t.Fatalf("stderr = %q, want substring %q", stderr, tt.wantStderrSub) + } + }) + } +} diff --git a/cmd/model/getModels_test.go b/cmd/model/getModels_test.go index a558c06d..6d21a7f9 100644 --- a/cmd/model/getModels_test.go +++ b/cmd/model/getModels_test.go @@ -1,7 +1,9 @@ package model import ( + "strconv" "testing" + "time" "github.com/runpod/runpodctl/api" ) @@ -22,6 +24,23 @@ func TestModelVersionHash(t *testing.T) { }}, want: "hash-1", }, + { + name: "skips nil version entries", + model: &api.Model{Versions: []*api.ModelVersion{ + nil, + {Hash: "hash-after-nil"}, + }}, + want: "hash-after-nil", + }, + { + name: "all hashes blank or whitespace", + model: &api.Model{Versions: []*api.ModelVersion{ + {Hash: ""}, + {Hash: " "}, + {Hash: "\t\n"}, + }}, + want: "", + }, { name: "no versions", model: &api.Model{}, @@ -42,3 +61,94 @@ func TestModelVersionHash(t 
*testing.T) { }) } } + +func TestValueOrDash(t *testing.T) { + tests := []struct { + name string + in string + want string + }{ + {name: "empty -> dash", in: "", want: "-"}, + {name: "whitespace only -> dash", in: " ", want: "-"}, + {name: "tabs and newlines -> dash", in: "\t\n ", want: "-"}, + {name: "value is preserved", in: "hello", want: "hello"}, + {name: "value with surrounding whitespace is trimmed", in: " hi ", want: "hi"}, + {name: "internal whitespace preserved", in: "a b c", want: "a b c"}, + {name: "single char", in: "x", want: "x"}, + {name: "unicode preserved", in: "café", want: "café"}, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + if got := valueOrDash(tt.in); got != tt.want { + t.Fatalf("valueOrDash(%q) = %q, want %q", tt.in, got, tt.want) + } + }) + } +} + +func TestFormatTimestamp(t *testing.T) { + // Reference epoch values for 2021-01-01T00:00:00Z (chosen for stable RFC3339 + // formatting across precision tiers, no fractional seconds). + const ( + secs = int64(1609459200) // 10 digits + millis = int64(1609459200000) // 13 digits + micros = int64(1609459200000000) // 16 digits + nanos = int64(1609459200000000000) // 19 digits + want = "2021-01-01T00:00:00Z" + ) + + tests := []struct { + name string + in string + want string + }{ + // positive: each precision tier should round to the same RFC3339 second. + {name: "seconds (10 digits)", in: "1609459200", want: want}, + {name: "milliseconds (13 digits)", in: "1609459200000", want: want}, + {name: "microseconds (16 digits)", in: "1609459200000000", want: want}, + {name: "nanoseconds (19 digits)", in: "1609459200000000000", want: want}, + + // positive: whitespace is trimmed before parsing. + {name: "trims surrounding whitespace", in: " 1609459200 ", want: want}, + + // negative: empty / blank renders as dash. 
+ {name: "empty -> dash", in: "", want: "-"}, + {name: "whitespace only -> dash", in: " ", want: "-"}, + + // negative: non-numeric is passed through unchanged (ISO strings from + // the API should not be mangled). + {name: "ISO 8601 passthrough", in: "2024-06-10T12:00:00Z", want: "2024-06-10T12:00:00Z"}, + {name: "garbage passthrough", in: "not-a-number", want: "not-a-number"}, + + // corner: epoch zero across precisions. + {name: "epoch zero seconds", in: "0", want: "1970-01-01T00:00:00Z"}, + + // corner: negative timestamp (pre-epoch). + {name: "negative seconds (pre-epoch)", in: "-1", want: "1969-12-31T23:59:59Z"}, + + // boundary: default-branch path. A length-11 millisecond timestamp + // (e.g. 99999999999 ms) is < 1e12 so the default branch treats it as + // seconds — documenting the current (lossy) behavior so a future change + // of the heuristic is intentional. + {name: "11-digit value falls through default as seconds", in: "10000000000", want: time.Unix(10000000000, 0).UTC().Format(time.RFC3339)}, + {name: "12-digit value above default threshold treated as ms", in: "1000000000001", want: time.Unix(1000000000, 1*int64(time.Millisecond)).UTC().Format(time.RFC3339)}, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + if got := formatTimestamp(tt.in); got != tt.want { + t.Fatalf("formatTimestamp(%q) = %q, want %q", tt.in, got, tt.want) + } + }) + } + + // Sanity-check the precision-tier constants stay aligned: every tier must + // resolve to the same RFC3339 instant when fed through formatTimestamp. 
+ for _, v := range []int64{secs, millis, micros, nanos} { + got := formatTimestamp(strconv.FormatInt(v, 10)) + if got != want { + t.Fatalf("precision tier %d produced %q, want %q", v, got, want) + } + } +} diff --git a/cmd/serverless/create.go b/cmd/serverless/create.go index f2cb35bd..555e3515 100644 --- a/cmd/serverless/create.go +++ b/cmd/serverless/create.go @@ -3,7 +3,7 @@ package serverless import ( "encoding/json" "fmt" - "math/rand" + "math/rand/v2" "strings" "github.com/runpod/runpodctl/internal/api" @@ -24,6 +24,9 @@ examples: # create from a template runpodctl serverless create --template-id --gpu-id "NVIDIA GeForce RTX 4090" + # create from a template and attach a model + runpodctl serverless create --template-id --gpu-id ADA_24 --model-reference https://huggingface.co/Qwen/Qwen2.5-0.5B-Instruct:main + # create from a hub repo runpodctl hub search vllm # find the hub id runpodctl serverless create --hub-id --gpu-id "NVIDIA GeForce RTX 4090" @@ -53,6 +56,7 @@ var ( createFlashBoot bool createExecutionTimeout int createNetworkVolumeIDs string + createModelReferences []string ) func init() { @@ -74,6 +78,7 @@ func init() { createCmd.Flags().BoolVar(&createFlashBoot, "flash-boot", true, "enable flash boot") createCmd.Flags().IntVar(&createExecutionTimeout, "execution-timeout", -1, "max seconds per request") createCmd.Flags().StringVar(&createNetworkVolumeIDs, "network-volume-ids", "", "comma-separated network volume ids for multi-region") + createCmd.Flags().StringArrayVar(&createModelReferences, "model-reference", nil, "model reference to attach to the endpoint (repeatable)") } func runCreate(cmd *cobra.Command, args []string) error { @@ -83,6 +88,13 @@ func runCreate(cmd *cobra.Command, args []string) error { if createTemplateID != "" && createHubID != "" { return fmt.Errorf("--template-id and --hub-id are mutually exclusive; use one or the other") } + if createHubID != "" && len(createModelReferences) > 0 { + return fmt.Errorf("--model-reference is only 
supported with --template-id") + } + computeType := strings.ToUpper(strings.TrimSpace(createComputeType)) + if len(createModelReferences) > 0 && computeType != "" && computeType != "GPU" { + return fmt.Errorf("--model-reference is only supported with --compute-type GPU") + } client, err := api.NewClient() if err != nil { @@ -93,7 +105,7 @@ func runCreate(cmd *cobra.Command, args []string) error { req := &api.EndpointCreateRequest{ Name: createName, TemplateID: createTemplateID, - ComputeType: strings.ToUpper(strings.TrimSpace(createComputeType)), + ComputeType: computeType, GpuCount: createGpuCount, WorkersMin: createWorkersMin, WorkersMax: createWorkersMax, @@ -171,7 +183,6 @@ func runCreate(cmd *cobra.Command, args []string) error { endpointName = listing.Title } - //nolint:gosec templateName := fmt.Sprintf("%s__template__%s", endpointName, randomString(7)) gqlReq := &api.EndpointCreateGQLInput{ @@ -259,6 +270,18 @@ func runCreate(cmd *cobra.Command, args []string) error { req.NetworkVolumeIDs = strings.Split(createNetworkVolumeIDs, ",") } + if len(createModelReferences) > 0 { + gqlReq := buildTemplateEndpointGQLInput(req, gpuTypeID, createDataCenterIDs, createModelReferences) + gqlEndpoint, gqlErr := client.CreateEndpointGQL(gqlReq) + if gqlErr != nil { + output.Error(gqlErr) + return fmt.Errorf("failed to create endpoint: %w", gqlErr) + } + + format := output.ParseFormat(cmd.Flag("output").Value.String()) + return output.Print(gqlEndpoint, &output.Config{Format: format}) + } + endpoint, err := client.CreateEndpoint(req) if err != nil { output.Error(err) @@ -280,11 +303,31 @@ func runCreate(cmd *cobra.Command, args []string) error { return output.Print(endpoint, &output.Config{Format: format}) } +func buildTemplateEndpointGQLInput(req *api.EndpointCreateRequest, gpuTypeID, locations string, modelReferences []string) *api.EndpointCreateGQLInput { + // saveEndpoint derives GPU by default when instanceIds is omitted; computeType is not part of EndpointInput. 
+ return &api.EndpointCreateGQLInput{ + Name: req.Name, + TemplateID: req.TemplateID, + GpuIDs: gpuTypeID, + GpuCount: req.GpuCount, + WorkersMin: req.WorkersMin, + WorkersMax: req.WorkersMax, + Locations: locations, + NetworkVolumeID: req.NetworkVolumeID, + ModelReferences: modelReferences, + } +} + +// randomString returns an n-character lowercase-alphanumeric suffix. +// The suffix is used only to disambiguate generated template names; it is not +// a token, secret, or anything an attacker benefits from predicting. gosec +// G404 flags math/rand/v2 generically, so the directive below is a deliberate +// suppression rather than a missed crypto/rand call. func randomString(n int) string { const letters = "abcdefghijklmnopqrstuvwxyz0123456789" b := make([]byte, n) - for i := range b { - b[i] = letters[rand.Intn(len(letters))] //nolint:gosec + for i := range n { + b[i] = letters[rand.IntN(len(letters))] //nolint:gosec // non-crypto: template-name uniqueness suffix, not a token } return string(b) } diff --git a/cmd/serverless/randomstring_bench_test.go b/cmd/serverless/randomstring_bench_test.go new file mode 100644 index 00000000..a30c0327 --- /dev/null +++ b/cmd/serverless/randomstring_bench_test.go @@ -0,0 +1,137 @@ +package serverless + +import ( + "math/rand/v2" + "testing" +) + +// Three implementations of the same alphanumeric ID generator (36-character alphabet). +// +// randomString (current, the unexported helper in create.go) — one rand.IntN call per +// output character. Simple and unbiased. +// +// randomStringBitmask — pulls 64 random bits at a time via rand.Uint64 and +// extracts indices using a rejection-sampling loop (6 bits per attempt; reject +// values >= 36). Amortises source reads across the whole string but does extra +// branching and shift work per index. +// +// randomStringSingleUint64 — pulls one rand.Uint64 and uses repeated %36. Has +// modulo bias (36 doesn't divide 2^64 evenly), but n=7 makes the bias +// negligible for a non-cryptographic uniqueness tag. 
+// +// All three produce strings from the same alphabet; only the underlying source +// reads and per-character work differ. + +const benchLetters = "abcdefghijklmnopqrstuvwxyz0123456789" + +func randomStringBitmask(n int) string { + const ( + alphabetLen = uint64(36) + bitsPerIdx = 6 + mask = uint64(1<>= bitsPerIdx + left-- + if idx >= alphabetLen { + continue + } + b[i] = benchLetters[idx] + i++ + } + return string(b) +} + +func randomStringSingleUint64(n int) string { + const alphabetLen = uint64(36) + b := make([]byte, n) + x := rand.Uint64() //nolint:gosec // non-crypto: template-name uniqueness suffix + for i := range n { + b[i] = benchLetters[x%alphabetLen] + x /= alphabetLen + if x == 0 && i+1 < n { + // Pull more bits if we've consumed all the entropy from the first + // Uint64. log_36(2^64) ≈ 12.36, so for n <= 12 a refill needs an unusually small draw. + x = rand.Uint64() //nolint:gosec // non-crypto: template-name uniqueness suffix + } + } + return string(b) +} + +func BenchmarkRandomString_Current_7(b *testing.B) { + b.ReportAllocs() + for b.Loop() { + _ = randomString(7) + } +} + +func BenchmarkRandomString_Bitmask_7(b *testing.B) { + b.ReportAllocs() + for b.Loop() { + _ = randomStringBitmask(7) + } +} + +func BenchmarkRandomString_SingleUint64_7(b *testing.B) { + b.ReportAllocs() + for b.Loop() { + _ = randomStringSingleUint64(7) + } +} + +// Larger N to amplify per-character cost differences in case the n=7 +// measurements are dominated by allocation noise. +func BenchmarkRandomString_Current_64(b *testing.B) { + b.ReportAllocs() + for b.Loop() { + _ = randomString(64) + } +} + +func BenchmarkRandomString_Bitmask_64(b *testing.B) { + b.ReportAllocs() + for b.Loop() { + _ = randomStringBitmask(64) + } +} + +// Correctness check: every produced rune must be in the documented alphabet, +// and the length must match. Doesn't assert uniformity (out of scope), just +// that the alternative implementations don't emit garbage. 
+func TestRandomStringAlternativesProduceValidChars(t *testing.T) { + const iterations = 2000 + for _, tc := range []struct { + name string + fn func(int) string + }{ + {"current", randomString}, + {"bitmask", randomStringBitmask}, + {"single-uint64", randomStringSingleUint64}, + } { + t.Run(tc.name, func(t *testing.T) { + for range iterations { + s := tc.fn(7) + if len(s) != 7 { + t.Fatalf("len = %d, want 7 (%q)", len(s), s) + } + for _, r := range s { + if !((r >= 'a' && r <= 'z') || (r >= '0' && r <= '9')) { + t.Fatalf("invalid rune %q in %q", r, s) + } + } + } + }) + } +} diff --git a/cmd/serverless/serverless_test.go b/cmd/serverless/serverless_test.go index 6e574438..49255069 100644 --- a/cmd/serverless/serverless_test.go +++ b/cmd/serverless/serverless_test.go @@ -2,8 +2,11 @@ package serverless import ( "bytes" + "strings" "testing" + "github.com/runpod/runpodctl/internal/api" + "github.com/spf13/cobra" ) @@ -65,6 +68,313 @@ func TestCreateCmd_Flags(t *testing.T) { if flags.Lookup("workers-max") == nil { t.Error("expected --workers-max flag") } + if flags.Lookup("model-reference") == nil { + t.Error("expected --model-reference flag") + } +} + +func TestCreateCmd_RejectsHubWithModelReference(t *testing.T) { + oldTemplateID := createTemplateID + oldHubID := createHubID + oldModelReferences := createModelReferences + t.Cleanup(func() { + createTemplateID = oldTemplateID + createHubID = oldHubID + createModelReferences = oldModelReferences + }) + + createTemplateID = "" + createHubID = "hub-123" + createModelReferences = []string{"https://local/user/model:hash"} + + err := runCreate(&cobra.Command{}, nil) + if err == nil { + t.Fatal("expected error") + } + if !strings.Contains(err.Error(), "--model-reference is only supported with --template-id") { + t.Fatalf("unexpected error: %v", err) + } +} + +func TestCreateCmd_RejectsCPUWithModelReference(t *testing.T) { + oldTemplateID := createTemplateID + oldHubID := createHubID + oldComputeType := createComputeType + 
oldModelReferences := createModelReferences + t.Cleanup(func() { + createTemplateID = oldTemplateID + createHubID = oldHubID + createComputeType = oldComputeType + createModelReferences = oldModelReferences + }) + + createTemplateID = "tpl-123" + createHubID = "" + createComputeType = "CPU" + createModelReferences = []string{"https://local/user/model:hash"} + + err := runCreate(&cobra.Command{}, nil) + if err == nil { + t.Fatal("expected error") + } + if !strings.Contains(err.Error(), "--model-reference is only supported with --compute-type GPU") { + t.Fatalf("unexpected error: %v", err) + } +} + +// TestCreateCmd_FlagValidation exercises runCreate's flag-validation guards +// (the only logic that runs before the API client is constructed). Each row +// must fail with a specific error before any network call is attempted. +func TestCreateCmd_FlagValidation(t *testing.T) { + type fields struct { + templateID string + hubID string + computeType string + modelReferences []string + } + + tests := []struct { + name string + fields fields + wantErrSub string + }{ + { + name: "negative: template and hub both set", + fields: fields{ + templateID: "tpl-1", + hubID: "hub-1", + }, + wantErrSub: "mutually exclusive", + }, + { + name: "negative: hub + model-reference", + fields: fields{ + hubID: "hub-1", + modelReferences: []string{"hf://m"}, + }, + wantErrSub: "--model-reference is only supported with --template-id", + }, + { + name: "negative: CPU + model-reference (lowercase)", + fields: fields{ + templateID: "tpl-1", + computeType: "cpu", + modelReferences: []string{"hf://m"}, + }, + wantErrSub: "--model-reference is only supported with --compute-type GPU", + }, + { + name: "boundary: whitespace-only compute-type is treated as unspecified", + fields: fields{ + templateID: "tpl-1", + computeType: " ", + modelReferences: []string{"hf://m"}, + }, + // "" / whitespace passes the compute-type guard; the test should NOT + // fail on the compute-type message. 
It WILL fail later when the API + // client is constructed without RUNPOD_API_KEY in env, so we assert + // the error is not the compute-type guard's message. + wantErrSub: "", + }, + { + name: "corner: explicit GPU + model-reference passes validation", + fields: fields{ + templateID: "tpl-1", + computeType: "GPU", + modelReferences: []string{"hf://m"}, + }, + wantErrSub: "", + }, + { + name: "corner: mixed-case GPU normalises to GPU", + fields: fields{ + templateID: "tpl-1", + computeType: "gPu", + modelReferences: []string{"hf://m"}, + }, + wantErrSub: "", + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + oldTemplateID := createTemplateID + oldHubID := createHubID + oldComputeType := createComputeType + oldModelReferences := createModelReferences + t.Cleanup(func() { + createTemplateID = oldTemplateID + createHubID = oldHubID + createComputeType = oldComputeType + createModelReferences = oldModelReferences + }) + + createTemplateID = tt.fields.templateID + createHubID = tt.fields.hubID + createComputeType = tt.fields.computeType + createModelReferences = tt.fields.modelReferences + + err := runCreate(&cobra.Command{}, nil) + + if tt.wantErrSub != "" { + if err == nil { + t.Fatalf("expected error containing %q, got nil", tt.wantErrSub) + } + if !strings.Contains(err.Error(), tt.wantErrSub) { + t.Fatalf("expected error containing %q, got %v", tt.wantErrSub, err) + } + return + } + + // rows with wantErrSub == "" should pass the guard layer; any error + // returned must come from a *later* layer (e.g. api.NewClient or the + // API call), not from the guard itself. Asserting absence of the + // guard-layer substrings keeps the test resilient to env differences. 
+ if err != nil { + for _, guardMsg := range []string{ + "mutually exclusive", + "--model-reference is only supported with --template-id", + "--model-reference is only supported with --compute-type GPU", + } { + if strings.Contains(err.Error(), guardMsg) { + t.Fatalf("guard %q tripped unexpectedly: %v", guardMsg, err) + } + } + } + }) + } +} + +func TestBuildTemplateEndpointGQLInput(t *testing.T) { + tests := []struct { + name string + req *api.EndpointCreateRequest + gpuTypeID string + locations string + modelReferences []string + want api.EndpointCreateGQLInput + }{ + { + name: "positive: all fields populated", + req: &api.EndpointCreateRequest{ + Name: "test-endpoint", + TemplateID: "tpl-123", + GpuCount: 2, + WorkersMin: 1, + WorkersMax: 3, + NetworkVolumeID: "vol-123", + }, + gpuTypeID: "ADA_24", + locations: "US-KS-2", + modelReferences: []string{"https://local/user/model:hash"}, + want: api.EndpointCreateGQLInput{ + Name: "test-endpoint", + TemplateID: "tpl-123", + GpuIDs: "ADA_24", + GpuCount: 2, + WorkersMin: 1, + WorkersMax: 3, + Locations: "US-KS-2", + NetworkVolumeID: "vol-123", + ModelReferences: []string{"https://local/user/model:hash"}, + }, + }, + { + name: "boundary: zero workers and zero gpu count", + req: &api.EndpointCreateRequest{ + Name: "ep", + TemplateID: "tpl-z", + }, + gpuTypeID: "L40", + locations: "", + modelReferences: []string{"hf://m"}, + want: api.EndpointCreateGQLInput{ + Name: "ep", + TemplateID: "tpl-z", + GpuIDs: "L40", + ModelReferences: []string{"hf://m"}, + }, + }, + { + name: "corner: multiple model references preserved in order", + req: &api.EndpointCreateRequest{ + TemplateID: "tpl-multi", + }, + gpuTypeID: "A100", + modelReferences: []string{ + "hf://a:v1", + "hf://b:v2", + "hf://c:v3", + }, + want: api.EndpointCreateGQLInput{ + TemplateID: "tpl-multi", + GpuIDs: "A100", + ModelReferences: []string{"hf://a:v1", "hf://b:v2", "hf://c:v3"}, + }, + }, + { + name: "corner: empty input + nil model refs", + req: 
&api.EndpointCreateRequest{}, + gpuTypeID: "", + locations: "", + modelReferences: nil, + want: api.EndpointCreateGQLInput{}, + }, + { + name: "negative: name omitted is empty (server auto-generates)", + req: &api.EndpointCreateRequest{ + TemplateID: "tpl-noname", + }, + gpuTypeID: "RTX4090", + modelReferences: []string{"hf://m"}, + want: api.EndpointCreateGQLInput{ + TemplateID: "tpl-noname", + GpuIDs: "RTX4090", + ModelReferences: []string{"hf://m"}, + }, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + got := buildTemplateEndpointGQLInput(tt.req, tt.gpuTypeID, tt.locations, tt.modelReferences) + if got == nil { + t.Fatal("expected non-nil result") + } + if got.Name != tt.want.Name { + t.Errorf("Name = %q, want %q", got.Name, tt.want.Name) + } + if got.TemplateID != tt.want.TemplateID { + t.Errorf("TemplateID = %q, want %q", got.TemplateID, tt.want.TemplateID) + } + if got.GpuIDs != tt.want.GpuIDs { + t.Errorf("GpuIDs = %q, want %q", got.GpuIDs, tt.want.GpuIDs) + } + if got.GpuCount != tt.want.GpuCount { + t.Errorf("GpuCount = %d, want %d", got.GpuCount, tt.want.GpuCount) + } + if got.WorkersMin != tt.want.WorkersMin { + t.Errorf("WorkersMin = %d, want %d", got.WorkersMin, tt.want.WorkersMin) + } + if got.WorkersMax != tt.want.WorkersMax { + t.Errorf("WorkersMax = %d, want %d", got.WorkersMax, tt.want.WorkersMax) + } + if got.Locations != tt.want.Locations { + t.Errorf("Locations = %q, want %q", got.Locations, tt.want.Locations) + } + if got.NetworkVolumeID != tt.want.NetworkVolumeID { + t.Errorf("NetworkVolumeID = %q, want %q", got.NetworkVolumeID, tt.want.NetworkVolumeID) + } + if len(got.ModelReferences) != len(tt.want.ModelReferences) { + t.Fatalf("ModelReferences len = %d, want %d (%#v)", len(got.ModelReferences), len(tt.want.ModelReferences), got.ModelReferences) + } + for i := range got.ModelReferences { + if got.ModelReferences[i] != tt.want.ModelReferences[i] { + t.Errorf("ModelReferences[%d] = %q, want %q", i, 
got.ModelReferences[i], tt.want.ModelReferences[i]) + } + } + }) + } } func TestUpdateCmd_Flags(t *testing.T) { diff --git a/docs/runpodctl_serverless_create.md b/docs/runpodctl_serverless_create.md index 4a3523b4..3686a6d7 100644 --- a/docs/runpodctl_serverless_create.md +++ b/docs/runpodctl_serverless_create.md @@ -13,6 +13,9 @@ examples: # create from a template runpodctl serverless create --template-id --gpu-id "NVIDIA GeForce RTX 4090" + # create from a template and attach a model + runpodctl serverless create --template-id --gpu-id ADA_24 --model-reference https://huggingface.co/Qwen/Qwen2.5-0.5B-Instruct:main + # create from a hub repo runpodctl hub search vllm # find the hub id runpodctl serverless create --hub-id --gpu-id "NVIDIA GeForce RTX 4090" @@ -27,25 +30,26 @@ runpodctl serverless create [flags] ### Options ``` - --compute-type string compute type (GPU or CPU) (default "GPU") - --data-center-ids string comma-separated list of data center ids - --env strings env vars in KEY=VALUE format; overrides hub defaults (repeatable) - --execution-timeout int max seconds per request (default -1) - --flash-boot enable flash boot (default true) - --gpu-count int number of gpus per worker (default 1) - --gpu-id string gpu id (from 'runpodctl gpu list') - -h, --help help for create - --hub-id string hub listing id; accepts both SERVERLESS and POD types (alternative to --template-id) - --idle-timeout int seconds before idle worker scales down (1-3600) (default -1) - --min-cuda-version string minimum cuda version (e.g., 12.6) - --name string endpoint name - --network-volume-id string network volume id to attach - --network-volume-ids string comma-separated network volume ids for multi-region - --scale-by string autoscale strategy: delay (seconds of queue wait) or requests (pending request count) - --scale-threshold int trigger point for autoscaler (delay: seconds, requests: count) (default -1) - --template-id string template id (required if no --hub-id) - 
--workers-max int maximum number of workers (default 3) - --workers-min int minimum number of workers + --compute-type string compute type (GPU or CPU) (default "GPU") + --data-center-ids string comma-separated list of data center ids + --env strings env vars in KEY=VALUE format; overrides hub defaults (repeatable) + --execution-timeout int max seconds per request (default -1) + --flash-boot enable flash boot (default true) + --gpu-count int number of gpus per worker (default 1) + --gpu-id string gpu id (from 'runpodctl gpu list') + -h, --help help for create + --hub-id string hub listing id; accepts both SERVERLESS and POD types (alternative to --template-id) + --idle-timeout int seconds before idle worker scales down (1-3600) (default -1) + --min-cuda-version string minimum cuda version (e.g., 12.6) + --model-reference stringArray model reference to attach to the endpoint (repeatable) + --name string endpoint name + --network-volume-id string network volume id to attach + --network-volume-ids string comma-separated network volume ids for multi-region + --scale-by string autoscale strategy: delay (seconds of queue wait) or requests (pending request count) + --scale-threshold int trigger point for autoscaler (delay: seconds, requests: count) (default -1) + --template-id string template id (required if no --hub-id) + --workers-max int maximum number of workers (default 3) + --workers-min int minimum number of workers ``` ### Options inherited from parent commands diff --git a/go.mod b/go.mod index e172bdcb..d2b5e66b 100644 --- a/go.mod +++ b/go.mod @@ -1,14 +1,12 @@ module github.com/runpod/runpodctl -go 1.25.9 - -toolchain go1.26.3 +go 1.26.4 require ( github.com/denisbrodbeck/machineid v1.0.1 github.com/fatih/color v1.16.0 github.com/gobwas/glob v0.2.3 - github.com/google/uuid v1.4.0 + github.com/google/uuid v1.6.0 github.com/manifoldco/promptui v0.9.0 github.com/olekukonko/tablewriter v0.0.5 github.com/pelletier/go-toml v1.9.5 @@ -19,9 +17,9 @@ require ( 
github.com/schollz/progressbar/v3 v3.14.3 github.com/spf13/cobra v1.8.1 github.com/spf13/viper v1.19.0 - golang.org/x/crypto v0.50.0 + golang.org/x/crypto v0.53.0 golang.org/x/crypto/x509roots/fallback v0.0.0-20260213171211-a408498e5541 - golang.org/x/mod v0.34.0 + golang.org/x/mod v0.37.0 golang.org/x/time v0.5.0 gopkg.in/yaml.v3 v3.0.1 ) @@ -57,9 +55,9 @@ require ( go.uber.org/atomic v1.9.0 // indirect go.uber.org/multierr v1.9.0 // indirect golang.org/x/exp v0.0.0-20230905200255-921286631fa9 // indirect - golang.org/x/net v0.53.0 // indirect - golang.org/x/sys v0.43.0 // indirect - golang.org/x/term v0.42.0 // indirect - golang.org/x/text v0.36.0 // indirect + golang.org/x/net v0.55.0 // indirect + golang.org/x/sys v0.46.0 // indirect + golang.org/x/term v0.44.0 // indirect + golang.org/x/text v0.38.0 // indirect gopkg.in/ini.v1 v1.67.0 // indirect ) diff --git a/go.sum b/go.sum index dc5864f3..3039e503 100644 --- a/go.sum +++ b/go.sum @@ -32,6 +32,8 @@ github.com/google/go-cmp v0.5.9 h1:O2Tfq5qg4qc4AmwVlvv0oLiVAGB7enBSJ2x2DqQFi38= github.com/google/go-cmp v0.5.9/go.mod h1:17dUlkBOakJ0+DkrSSNjCkIjxS6bF9zb3elmeNGIjoY= github.com/google/uuid v1.4.0 h1:MtMxsa51/r9yyhkyLsVeVt0B+BGQZzpQiTQ4eHZ8bc4= github.com/google/uuid v1.4.0/go.mod h1:TIyPZe4MgqvfeYDBFedMoGGpEw/LqOeaOT+nhxU+yHo= +github.com/google/uuid v1.6.0 h1:NIvaJDMOsjHA8n1jAhLSgzrAzy1Hgr+hNrb57e+94F0= +github.com/google/uuid v1.6.0/go.mod h1:TIyPZe4MgqvfeYDBFedMoGGpEw/LqOeaOT+nhxU+yHo= github.com/hashicorp/hcl v1.0.0 h1:0Anlzjpi4vEasTeNFn2mLJgTSwt0+6sfsiTG8qcWGx4= github.com/hashicorp/hcl v1.0.0/go.mod h1:E5yfLk+7swimpb2L/Alb/PJmXilQ/rhwaUYs4T20WEQ= github.com/inconshreveable/mousetrap v1.1.0 h1:wN+x4NVGpMsO7ErUn/mUI3vEoE6Jt13X2s0bqwp9tc8= @@ -140,6 +142,8 @@ golang.org/x/crypto v0.35.0 h1:b15kiHdrGCHrP6LvwaQ3c03kgNhhiMgvlhxHQhmg2Xs= golang.org/x/crypto v0.35.0/go.mod h1:dy7dXNW32cAb/6/PRuTNsix8T+vJAqvuIy5Bli/x0YQ= golang.org/x/crypto v0.50.0 h1:zO47/JPrL6vsNkINmLoo/PH1gcxpls50DNogFvB5ZGI= golang.org/x/crypto 
v0.50.0/go.mod h1:3muZ7vA7PBCE6xgPX7nkzzjiUq87kRItoJQM1Yo8S+Q= +golang.org/x/crypto v0.53.0 h1:QZ4Muo8THX6CizN2vPPd5fBGHyogrdK9fG4wLPFUsto= +golang.org/x/crypto v0.53.0/go.mod h1:DNLU434OwVakk9PzuwV8w62mAJpRJL3vsgcfp4Qnsio= golang.org/x/crypto/x509roots/fallback v0.0.0-20260213171211-a408498e5541 h1:FmKxj9ocLKn45jiR2jQMwCVhDvaK7fKQFzfuT9GvyK8= golang.org/x/crypto/x509roots/fallback v0.0.0-20260213171211-a408498e5541/go.mod h1:+UoQFNBq2p2wO+Q6ddVtYc25GZ6VNdOMyyrd4nrqrKs= golang.org/x/exp v0.0.0-20230905200255-921286631fa9 h1:GoHiUyI/Tp2nVkLI2mCxVkOjsbSXD66ic0XW0js0R9g= @@ -150,6 +154,8 @@ golang.org/x/mod v0.22.0 h1:D4nJWe9zXqHOmWqj4VMOJhvzj7bEZg4wEYa759z1pH4= golang.org/x/mod v0.22.0/go.mod h1:6SkKJ3Xj0I0BrPOZoBy3bdMptDDU9oJrpohJ3eWZ1fY= golang.org/x/mod v0.34.0 h1:xIHgNUUnW6sYkcM5Jleh05DvLOtwc6RitGHbDk4akRI= golang.org/x/mod v0.34.0/go.mod h1:ykgH52iCZe79kzLLMhyCUzhMci+nQj+0XkbXpNYtVjY= +golang.org/x/mod v0.37.0 h1:vF1DjpVEshcIqoEaauuHebaLk1O1forxjxBaVn884JQ= +golang.org/x/mod v0.37.0/go.mod h1:m8S8VeM9r4dzDwjrKO0a1sZP3YjeMamRRlD+fmR2Q/0= golang.org/x/net v0.0.0-20190620200207-3b0461eec859/go.mod h1:z5CRVTTTmAJ677TzLLGU+0bjPO0LkuOLi4/5GtJWs/s= golang.org/x/net v0.0.0-20210226172049-e18ecbb05110/go.mod h1:m0MpNAwzfU5UDzcl9v0D8zg8gWTRqZa9RBIspLL5mdg= golang.org/x/net v0.0.0-20220722155237-a158d28d115b/go.mod h1:XRhObCWvk6IyKnWLug+ECip1KBveYUHfp+8e9klMJ9c= @@ -161,6 +167,8 @@ golang.org/x/net v0.34.0 h1:Mb7Mrk043xzHgnRM88suvJFwzVrRfHEHJEl5/71CKw0= golang.org/x/net v0.34.0/go.mod h1:di0qlW3YNM5oh6GqDGQr92MyTozJPmybPK4Ev/Gm31k= golang.org/x/net v0.53.0 h1:d+qAbo5L0orcWAr0a9JweQpjXF19LMXJE8Ey7hwOdUA= golang.org/x/net v0.53.0/go.mod h1:JvMuJH7rrdiCfbeHoo3fCQU24Lf5JJwT9W3sJFulfgs= +golang.org/x/net v0.55.0 h1:bcvxaJn3e1U6InsFWt1JUq1aSjnRxLzT2rtD2KfkDF8= +golang.org/x/net v0.55.0/go.mod h1:L5U2KuzuOe1lY7Z+aWVIKK6qEeJXnXV9yzGA+WCHJww= golang.org/x/sync v0.0.0-20190423024810-112230192c58/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= golang.org/x/sync 
v0.0.0-20220722155255-886fb9371eb4/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= golang.org/x/sync v0.1.0/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= @@ -181,6 +189,8 @@ golang.org/x/sys v0.30.0 h1:QjkSwP/36a20jFYWkSue1YwXzLmsV5Gfq7Eiy72C1uc= golang.org/x/sys v0.30.0/go.mod h1:/VUhepiaJMQUp4+oa/7Zr1D23ma6VTLIYjOOTFZPUcA= golang.org/x/sys v0.43.0 h1:Rlag2XtaFTxp19wS8MXlJwTvoh8ArU6ezoyFsMyCTNI= golang.org/x/sys v0.43.0/go.mod h1:4GL1E5IUh+htKOUEOaiffhrAeqysfVGipDYzABqnCmw= +golang.org/x/sys v0.46.0 h1:noSf2Fq6F8DBgS+LysIkx7rIExoNHJsxOAtPp4rthXw= +golang.org/x/sys v0.46.0/go.mod h1:4GL1E5IUh+htKOUEOaiffhrAeqysfVGipDYzABqnCmw= golang.org/x/term v0.0.0-20201126162022-7de9c90e9dd1/go.mod h1:bj7SfCRtBDWHUb9snDiAeCFNEtKQo2Wmx5Cou7ajbmo= golang.org/x/term v0.0.0-20210927222741-03fcf44c2211/go.mod h1:jbD1KX2456YbFQfuXm/mYQcufACuNUgVhRMnK/tPxf8= golang.org/x/term v0.5.0/go.mod h1:jMB1sMXY+tzblOD4FWmEbocvup2/aLOaQEp7JmGp78k= @@ -191,6 +201,8 @@ golang.org/x/term v0.29.0 h1:L6pJp37ocefwRRtYPKSWOWzOtWSxVajvz2ldH/xi3iU= golang.org/x/term v0.29.0/go.mod h1:6bl4lRlvVuDgSf3179VpIxBF0o10JUpXWOnI7nErv7s= golang.org/x/term v0.42.0 h1:UiKe+zDFmJobeJ5ggPwOshJIVt6/Ft0rcfrXZDLWAWY= golang.org/x/term v0.42.0/go.mod h1:Dq/D+snpsbazcBG5+F9Q1n2rXV8Ma+71xEjTRufARgY= +golang.org/x/term v0.44.0 h1:0rLvDRCtNj0gZkyIXhCyOb2OAzEhLVqc4B+hrsBhrmc= +golang.org/x/term v0.44.0/go.mod h1:7ze4MdzUzLXpSAoFP1H0bOI9aXDqveSvatT5vKcFh2Y= golang.org/x/text v0.3.0/go.mod h1:NqM8EUOU14njkJ3fqMW+pc6Ldnwhi/IjpwHt7yyuwOQ= golang.org/x/text v0.3.3/go.mod h1:5Zoc/QRtKVWzQhOtBMvqHzDpF6irO9z98xDceosuGiQ= golang.org/x/text v0.3.7/go.mod h1:u+2+/6zg+i71rQMx5EYifcz6MCKuco9NR6JIITiCfzQ= @@ -202,6 +214,8 @@ golang.org/x/text v0.22.0 h1:bofq7m3/HAFvbF51jz3Q9wLg3jkvSPuiZu/pD1XwgtM= golang.org/x/text v0.22.0/go.mod h1:YRoo4H8PVmsu+E3Ou7cqLVH8oXWIHVoX0jqUWALQhfY= golang.org/x/text v0.36.0 h1:JfKh3XmcRPqZPKevfXVpI1wXPTqbkE5f7JA92a55Yxg= golang.org/x/text v0.36.0/go.mod 
h1:NIdBknypM8iqVmPiuco0Dh6P5Jcdk8lJL0CUebqK164= +golang.org/x/text v0.38.0 h1:sXmwo9DwP3OK9EZ7PqAdaooSGozfl/3a6/xJcbzPRhE= +golang.org/x/text v0.38.0/go.mod h1:YXZt3QhHUKYT53r2lLKFIVi6Ao1jdzrTR/KQ09qyxF4= golang.org/x/time v0.5.0 h1:o7cqy6amK/52YcAKIPlM3a+Fpj35zvRj2TP+e1xFSfk= golang.org/x/time v0.5.0/go.mod h1:3BpzKBy/shNhVucY/MWOyx10tF3SFh9QdLuxbVysPQM= golang.org/x/tools v0.0.0-20180917221912-90fa682c2a6e/go.mod h1:n7NCudcB/nEzxVGmLbDWY5pfWTLqBcC2KZ6jyYvM4mQ= diff --git a/internal/api/endpoints.go b/internal/api/endpoints.go index dd62487a..3967afb8 100644 --- a/internal/api/endpoints.go +++ b/internal/api/endpoints.go @@ -23,6 +23,7 @@ type Endpoint struct { MinCudaVersion string `json:"minCudaVersion,omitempty"` Flashboot *bool `json:"flashboot,omitempty"` ExecutionTimeoutMs int `json:"executionTimeoutMs,omitempty"` + ModelReferences []string `json:"modelReferences,omitempty"` Template map[string]interface{} `json:"template,omitempty"` Workers []interface{} `json:"workers,omitempty"` } @@ -197,7 +198,7 @@ func (c *Client) DeleteEndpoint(endpointID string) error { // EndpointCreateGQLInput is the input for creating an endpoint via GraphQL // Used when hubReleaseId is needed (REST API doesn't support it) type EndpointCreateGQLInput struct { - Name string `json:"name"` + Name string `json:"name,omitempty"` HubReleaseID string `json:"hubReleaseId,omitempty"` TemplateID string `json:"templateId,omitempty"` Template *EndpointTemplateInput `json:"template,omitempty"` @@ -207,6 +208,7 @@ type EndpointCreateGQLInput struct { WorkersMax int `json:"workersMax,omitempty"` Locations string `json:"locations,omitempty"` NetworkVolumeID string `json:"networkVolumeId,omitempty"` + ModelReferences []string `json:"modelReferences,omitempty"` } // EndpointTemplateInput is the inline template for endpoint creation via GraphQL @@ -225,6 +227,7 @@ func (c *Client) CreateEndpointGQL(req *EndpointCreateGQLInput) (*Endpoint, erro saveEndpoint(input: $input) { id name + templateId 
gpuIds networkVolumeId locations @@ -234,6 +237,7 @@ func (c *Client) CreateEndpointGQL(req *EndpointCreateGQLInput) (*Endpoint, erro workersMin workersMax gpuCount + modelReferences } } ` diff --git a/internal/api/endpoints_test.go b/internal/api/endpoints_test.go index a279eff5..f7cadc1c 100644 --- a/internal/api/endpoints_test.go +++ b/internal/api/endpoints_test.go @@ -4,6 +4,7 @@ import ( "encoding/json" "net/http" "net/http/httptest" + "strings" "testing" "github.com/spf13/viper" @@ -95,6 +96,85 @@ func TestCreateEndpoint(t *testing.T) { } } +func TestEndpointCreateGQLInputOmitsEmptyName(t *testing.T) { + data, err := json.Marshal(EndpointCreateGQLInput{ + TemplateID: "tpl-123", + }) + if err != nil { + t.Fatalf("failed to marshal request: %v", err) + } + if strings.Contains(string(data), `"name"`) { + t.Fatalf("expected empty name to be omitted, got %s", data) + } +} + +func TestCreateEndpointGQLIncludesModelReferences(t *testing.T) { + modelReference := "https://local/user/model:hash" + oldAPIURL := viper.GetString("apiUrl") + t.Cleanup(func() { + viper.Set("apiUrl", oldAPIURL) + }) + + server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + if r.Method != http.MethodPost { + t.Errorf("expected POST, got %s", r.Method) + } + + var body struct { + Query string `json:"query"` + Variables struct { + Input EndpointCreateGQLInput `json:"input"` + } `json:"variables"` + } + if err := json.NewDecoder(r.Body).Decode(&body); err != nil { + t.Fatalf("failed to decode request: %v", err) + } + if !strings.Contains(body.Query, "modelReferences") { + t.Error("expected query to select modelReferences") + } + if !strings.Contains(body.Query, "templateId") { + t.Error("expected query to select templateId") + } + if len(body.Variables.Input.ModelReferences) != 1 || body.Variables.Input.ModelReferences[0] != modelReference { + t.Fatalf("unexpected model references: %#v", body.Variables.Input.ModelReferences) + } + + 
json.NewEncoder(w).Encode(map[string]interface{}{ + "data": map[string]interface{}{ + "saveEndpoint": map[string]interface{}{ + "id": "new-ep-id", + "name": body.Variables.Input.Name, + "templateId": body.Variables.Input.TemplateID, + "modelReferences": body.Variables.Input.ModelReferences, + }, + }, + }) + })) + defer server.Close() + + viper.Set("apiUrl", server.URL) + client := &Client{ + baseURL: "http://rest.example", + apiKey: "test-key", + httpClient: server.Client(), + } + + endpoint, err := client.CreateEndpointGQL(&EndpointCreateGQLInput{ + Name: "test-endpoint", + TemplateID: "tpl-123", + ModelReferences: []string{modelReference}, + }) + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + if endpoint.TemplateID != "tpl-123" { + t.Errorf("expected template id tpl-123, got %s", endpoint.TemplateID) + } + if len(endpoint.ModelReferences) != 1 || endpoint.ModelReferences[0] != modelReference { + t.Fatalf("unexpected response model references: %#v", endpoint.ModelReferences) + } +} + func TestUpdateEndpoint(t *testing.T) { server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { if r.Method != http.MethodPatch {