diff --git a/acceptance/experimental/air/get/out.test.toml b/acceptance/experimental/air/get/out.test.toml new file mode 100644 index 0000000000..d6187dcb04 --- /dev/null +++ b/acceptance/experimental/air/get/out.test.toml @@ -0,0 +1,3 @@ +Local = true +Cloud = false +EnvMatrix.DATABRICKS_BUNDLE_ENGINE = [] diff --git a/acceptance/experimental/air/get/output.txt b/acceptance/experimental/air/get/output.txt new file mode 100644 index 0000000000..6ce803659b --- /dev/null +++ b/acceptance/experimental/air/get/output.txt @@ -0,0 +1,36 @@ + +=== get (text) +>>> [CLI] experimental air get 123 +Run ID: 123 +Status: SUCCESS +Submitted: [TIMESTAMP] +Duration: 12s +Retries: 0 +Experiment: my-exp +User: user@example.com +Accelerators: 8x H100 +MLflow: [DATABRICKS_URL]/ml/experiments/exp1/runs/run1/artifacts/logs/node_0 +Dashboard: https://my-workspace.cloud.databricks.test/jobs/runs/123 + +=== get (json) +>>> [CLI] experimental air get 123 -o json +{ + "v": 1, + "ts": "[TIMESTAMP]", + "data": { + "run_id": "123", + "status": "SUCCESS", + "started_at": "[TIMESTAMP]", + "duration_seconds": 12, + "attempt_number": 0, + "experiment_name": "my-exp", + "dashboard_url": "https://my-workspace.cloud.databricks.test/jobs/runs/123", + "mlflow_url": "[DATABRICKS_URL]/ml/experiments/exp1/runs/run1/artifacts/logs/node_0" + } +} + +=== invalid run id +>>> [CLI] experimental air get notanumber +Error: invalid RUN_ID "notanumber": must be a positive integer + +Exit code: 1 diff --git a/acceptance/experimental/air/get/script b/acceptance/experimental/air/get/script new file mode 100644 index 0000000000..e0ea8d10f8 --- /dev/null +++ b/acceptance/experimental/air/get/script @@ -0,0 +1,8 @@ +title "get (text)" +trace $CLI experimental air get 123 + +title "get (json)" +trace $CLI experimental air get 123 -o json + +title "invalid run id" +errcode trace $CLI experimental air get notanumber diff --git a/acceptance/experimental/air/get/test.toml b/acceptance/experimental/air/get/test.toml new file 
mode 100644 index 0000000000..b6219b87f0 --- /dev/null +++ b/acceptance/experimental/air/get/test.toml @@ -0,0 +1,40 @@ +# This command does not deploy a bundle, so no engine matrix is needed. +[EnvMatrix] +DATABRICKS_BUNDLE_ENGINE = [] + +# The SDK occasionally probes host reachability with a HEAD request; stub it so +# the test is deterministic. +[[Server]] +Pattern = "HEAD /" +Response.Body = '' + +# A single GenAI-compute run with an experiment, GPUs, and a creator. +[[Server]] +Pattern = "GET /api/2.2/jobs/runs/get" +Response.Body = ''' +{ + "run_id": 123, + "run_page_url": "https://my-workspace.cloud.databricks.test/jobs/runs/123", + "creator_user_name": "user@example.com", + "start_time": 1700000000000, + "end_time": 1700000012000, + "state": {"life_cycle_state": "TERMINATED", "result_state": "SUCCESS"}, + "tasks": [ + { + "task_key": "train", + "attempt_number": 0, + "gen_ai_compute_task": { + "mlflow_experiment_name": "/Users/user@example.com/my-exp", + "compute": {"gpu_type": "GPU_8xH100", "num_gpus": 8} + } + } + ] +} +''' + +# MLflow identifiers for the deep-link (runs/get-output is not modeled by the typed SDK). 
+[[Server]] +Pattern = "GET /api/2.2/jobs/runs/get-output" +Response.Body = ''' +{"gen_ai_compute_output": {"run_info": {"mlflow_experiment_id": "exp1", "mlflow_run_id": "run1"}}} +''' diff --git a/acceptance/experimental/air/unimplemented/output.txt b/acceptance/experimental/air/unimplemented/output.txt index 4a07a38a37..0a86360c78 100644 --- a/acceptance/experimental/air/unimplemented/output.txt +++ b/acceptance/experimental/air/unimplemented/output.txt @@ -5,12 +5,6 @@ Error: `air run` is not implemented yet Exit code: 1 -=== get ->>> [CLI] experimental air get 123 -Error: `air get` is not implemented yet - -Exit code: 1 - === list >>> [CLI] experimental air list Error: `air list` is not implemented yet diff --git a/acceptance/experimental/air/unimplemented/script b/acceptance/experimental/air/unimplemented/script index 2ed885c0e6..e6e8d33ef9 100644 --- a/acceptance/experimental/air/unimplemented/script +++ b/acceptance/experimental/air/unimplemented/script @@ -3,9 +3,6 @@ title "run" errcode trace $CLI experimental air run -title "get" -errcode trace $CLI experimental air get 123 - title "list" errcode trace $CLI experimental air list diff --git a/experimental/air/cmd/format.go b/experimental/air/cmd/format.go new file mode 100644 index 0000000000..88f620ee7c --- /dev/null +++ b/experimental/air/cmd/format.go @@ -0,0 +1,154 @@ +package aircmd + +import ( + "fmt" + "strings" + "time" + + "github.com/databricks/databricks-sdk-go/service/jobs" +) + +// gpuDisplayNames maps the GPU identifiers returned by the backend to the short +// names we show to users. Unknown identifiers are shown unchanged. +var gpuDisplayNames = map[string]string{ + "h100_80gb": "H100", + "a10": "A10", + "GPU_1xA10": "A10", + "GPU_8xH100": "H100", + "GPU_1xH100": "H100", +} + +// runStatus returns the single status word to show for a run. The backend +// reports two values: a lifecycle state (e.g. PENDING, RUNNING) and, once the +// run has finished, a result state (e.g. SUCCESS, FAILED). 
The result state is +// the more meaningful one, so we prefer it when it is set. +func runStatus(state *jobs.RunState) string { + if state == nil { + return "UNKNOWN" + } + if state.ResultState != "" { + return string(state.ResultState) + } + if state.LifeCycleState != "" { + return string(state.LifeCycleState) + } + return "UNKNOWN" +} + +// startedAt converts the run's start time (epoch milliseconds) to an RFC 3339 +// UTC string, or returns nil if the run has not started yet. +func startedAt(run *jobs.Run) *string { + if run.StartTime == 0 { + return nil + } + s := time.UnixMilli(run.StartTime).UTC().Format(time.RFC3339) + return &s +} + +// durationSeconds returns how long the run has taken, in whole seconds, or nil +// if it has not started. For a finished run this is the elapsed time; for a +// still-running run it is the time since it started. +func durationSeconds(run *jobs.Run) *int64 { + if run.StartTime == 0 { + return nil + } + + var endMillis int64 + switch { + case run.RunDuration > 0: + // The backend already computed the duration for us. + d := run.RunDuration / 1000 + return &d + case run.EndTime > 0: + endMillis = run.EndTime + default: + // Still running: measure against the current time. + endMillis = time.Now().UnixMilli() + } + + d := (endMillis - run.StartTime) / 1000 + return &d +} + +// formatDuration turns a number of seconds into a compact human string such as +// "1h 2m 3s". Trailing zero units are dropped, but a lone "0s" is kept so the +// result is never empty. 
+func formatDuration(totalSeconds int64) string { + hours := totalSeconds / 3600 + minutes := (totalSeconds % 3600) / 60 + seconds := totalSeconds % 60 + + var parts []string + if hours > 0 { + parts = append(parts, fmt.Sprintf("%dh", hours)) + } + if minutes > 0 { + parts = append(parts, fmt.Sprintf("%dm", minutes)) + } + if seconds > 0 || len(parts) == 0 { + parts = append(parts, fmt.Sprintf("%ds", seconds)) + } + return strings.Join(parts, " ") +} + +// latestAttemptNumber returns the retry count of the run's most recent task. +// Tasks start at attempt 0, so a value of 0 means the run has not been retried. +func latestAttemptNumber(run *jobs.Run) int { + if len(run.Tasks) == 0 { + return 0 + } + return run.Tasks[len(run.Tasks)-1].AttemptNumber +} + +// experimentName returns the MLflow experiment name for the run, or nil if there +// isn't one. Experiment names are often stored under a user's home folder (e.g. +// "/Users/me@example.com/my-experiment"); we strip that prefix so users see just +// the experiment name they chose. +func experimentName(run *jobs.Run) *string { + if len(run.Tasks) == 0 { + return nil + } + task := run.Tasks[0].GenAiComputeTask + if task == nil || task.MlflowExperimentName == "" { + return nil + } + name := stripExperimentUserPrefix(task.MlflowExperimentName) + return &name +} + +// stripExperimentUserPrefix removes a leading "/Users//" from an +// experiment name, leaving the remainder. Names without that prefix are returned +// unchanged. +func stripExperimentUserPrefix(name string) string { + if !strings.HasPrefix(name, "/Users/") { + return name + } + // Split into ["", "Users", "", ""]; keep "". + parts := strings.SplitN(name, "/", 4) + if len(parts) == 4 { + return parts[3] + } + return name +} + +// accelerators returns a short description of the GPUs the run uses, such as +// "8x H100", or an empty string if the run has no GPU compute attached. 
+func accelerators(run *jobs.Run) string { + if len(run.Tasks) == 0 { + return "" + } + task := run.Tasks[0].GenAiComputeTask + if task == nil || task.Compute == nil || task.Compute.NumGpus == 0 { + return "" + } + return fmt.Sprintf("%dx %s", task.Compute.NumGpus, gpuDisplayName(task.Compute.GpuType)) +} + +// gpuDisplayName returns the friendly name for a GPU identifier, falling back to +// the identifier itself when it is not one we recognize. +func gpuDisplayName(gpuType string) string { + if name, ok := gpuDisplayNames[gpuType]; ok { + return name + } + return gpuType +} diff --git a/experimental/air/cmd/format_test.go b/experimental/air/cmd/format_test.go new file mode 100644 index 0000000000..c3e2e865b8 --- /dev/null +++ b/experimental/air/cmd/format_test.go @@ -0,0 +1,131 @@ +package aircmd + +import ( + "testing" + + "github.com/databricks/databricks-sdk-go/service/jobs" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +func TestFormatDuration(t *testing.T) { + cases := []struct { + seconds int64 + want string + }{ + {0, "0s"}, + {45, "45s"}, + {60, "1m"}, + {63, "1m 3s"}, + {3600, "1h"}, + {3723, "1h 2m 3s"}, + {7260, "2h 1m"}, + } + for _, c := range cases { + assert.Equal(t, c.want, formatDuration(c.seconds)) + } +} + +func TestStripExperimentUserPrefix(t *testing.T) { + cases := []struct { + name string + want string + }{ + {"/Users/me@example.com/my-experiment", "my-experiment"}, + {"/Users/me@example.com/nested/path", "nested/path"}, + {"my-experiment", "my-experiment"}, + {"/Shared/team-experiment", "/Shared/team-experiment"}, + {"/Users/me@example.com", "/Users/me@example.com"}, + } + for _, c := range cases { + assert.Equal(t, c.want, stripExperimentUserPrefix(c.name)) + } +} + +func TestGpuDisplayName(t *testing.T) { + assert.Equal(t, "H100", gpuDisplayName("h100_80gb")) + assert.Equal(t, "A10", gpuDisplayName("GPU_1xA10")) + assert.Equal(t, "A10", gpuDisplayName("a10")) + assert.Equal(t, "H100", 
gpuDisplayName("GPU_8xH100")) + assert.Equal(t, "H100", gpuDisplayName("GPU_1xH100")) + // Unknown identifiers pass through unchanged. + assert.Equal(t, "b200", gpuDisplayName("b200")) + assert.Equal(t, "", gpuDisplayName("")) +} + +func TestRunStatusPrefersResultState(t *testing.T) { + // Result state wins once the run has finished. + assert.Equal(t, "SUCCESS", runStatus(&jobs.RunState{ + LifeCycleState: jobs.RunLifeCycleStateTerminated, + ResultState: jobs.RunResultStateSuccess, + })) + // Before completion only the lifecycle state is set. + assert.Equal(t, "RUNNING", runStatus(&jobs.RunState{ + LifeCycleState: jobs.RunLifeCycleStateRunning, + })) + // Non-nil state with neither field set, and nil state. + assert.Equal(t, "UNKNOWN", runStatus(&jobs.RunState{})) + assert.Equal(t, "UNKNOWN", runStatus(nil)) +} + +func TestStartedAt(t *testing.T) { + // Not started yet. + assert.Nil(t, startedAt(&jobs.Run{})) + // 1700000000000 ms == 2023-11-14T22:13:20Z. + got := startedAt(&jobs.Run{StartTime: 1700000000000}) + require.NotNil(t, got) + assert.Equal(t, "2023-11-14T22:13:20Z", *got) +} + +func TestDurationSeconds(t *testing.T) { + // Not started yet. + assert.Nil(t, durationSeconds(&jobs.Run{})) + + // Backend-provided duration wins (milliseconds → seconds). + d := durationSeconds(&jobs.Run{StartTime: 1700000000000, RunDuration: 5000}) + require.NotNil(t, d) + assert.Equal(t, int64(5), *d) + + // Finished run with no RunDuration: end - start. + d = durationSeconds(&jobs.Run{StartTime: 1700000000000, EndTime: 1700000012000}) + require.NotNil(t, d) + assert.Equal(t, int64(12), *d) + + // Still running: measured against the current time, so positive. 
+ d = durationSeconds(&jobs.Run{StartTime: 1700000000000}) + require.NotNil(t, d) + assert.Positive(t, *d) +} + +func TestLatestAttemptNumber(t *testing.T) { + assert.Equal(t, 0, latestAttemptNumber(&jobs.Run{})) + run := &jobs.Run{Tasks: []jobs.RunTask{{AttemptNumber: 0}, {AttemptNumber: 2}}} + assert.Equal(t, 2, latestAttemptNumber(run)) +} + +func TestExperimentName(t *testing.T) { + assert.Nil(t, experimentName(&jobs.Run{})) + assert.Nil(t, experimentName(&jobs.Run{Tasks: []jobs.RunTask{{}}})) + assert.Nil(t, experimentName(&jobs.Run{Tasks: []jobs.RunTask{{ + GenAiComputeTask: &jobs.GenAiComputeTask{MlflowExperimentName: ""}, + }}})) + got := experimentName(&jobs.Run{Tasks: []jobs.RunTask{{ + GenAiComputeTask: &jobs.GenAiComputeTask{MlflowExperimentName: "/Users/me@example.com/exp"}, + }}}) + require.NotNil(t, got) + assert.Equal(t, "exp", *got) +} + +func TestAccelerators(t *testing.T) { + assert.Equal(t, "", accelerators(&jobs.Run{})) + assert.Equal(t, "", accelerators(&jobs.Run{Tasks: []jobs.RunTask{{}}})) + assert.Equal(t, "", accelerators(&jobs.Run{Tasks: []jobs.RunTask{{ + GenAiComputeTask: &jobs.GenAiComputeTask{}, + }}})) + assert.Equal(t, "", accelerators(&jobs.Run{Tasks: []jobs.RunTask{{ + GenAiComputeTask: &jobs.GenAiComputeTask{Compute: &jobs.ComputeConfig{NumGpus: 0}}, + }}})) + assert.Equal(t, "8x H100", accelerators(&jobs.Run{Tasks: []jobs.RunTask{{ + GenAiComputeTask: &jobs.GenAiComputeTask{Compute: &jobs.ComputeConfig{NumGpus: 8, GpuType: "GPU_8xH100"}}, + }}})) +} diff --git a/experimental/air/cmd/get.go b/experimental/air/cmd/get.go index 0ab0b8226b..cc486b722f 100644 --- a/experimental/air/cmd/get.go +++ b/experimental/air/cmd/get.go @@ -1,19 +1,187 @@ package aircmd import ( + "context" + "errors" + "fmt" + "io" + "strconv" + "github.com/databricks/cli/cmd/root" + "github.com/databricks/cli/libs/cmdctx" + "github.com/databricks/cli/libs/cmdio" + "github.com/databricks/cli/libs/flags" + "github.com/databricks/cli/libs/log" + 
"github.com/databricks/databricks-sdk-go" + "github.com/databricks/databricks-sdk-go/apierr" + "github.com/databricks/databricks-sdk-go/service/jobs" "github.com/spf13/cobra" ) +// getData is the payload printed by `air get`. The json-tagged fields form +// the machine-readable output; fields tagged `json:"-"` are shown only in the +// human-readable text view. +type getData struct { + RunID string `json:"run_id"` + Status string `json:"status"` + StartedAt *string `json:"started_at"` + DurationSeconds *int64 `json:"duration_seconds"` + AttemptNumber int `json:"attempt_number"` + ExperimentName *string `json:"experiment_name"` + DashboardURL string `json:"dashboard_url"` + MLflowURL *string `json:"mlflow_url"` + + // Duration is the human-readable form of DurationSeconds, e.g. "12m 3s". + Duration string `json:"-"` + // Accelerators describes the run's GPUs, e.g. "8x H100". + Accelerators string `json:"-"` + // User is the run's creator. Text-only; JSON omits it, matching `air get --json`. + User string `json:"-"` + // Sweep replaces the single-run view for foreach runs. Text-only; JSON omits it. + Sweep *sweepInfo `json:"-"` +} + +// getTemplate is the text-mode layout. It reads from the JSON envelope, so +// every field is reached through ".Data". Optional rows are hidden when empty. 
+const getTemplate = `{{- if .Data.Sweep -}} +Sweep Run ID: {{.Data.RunID}} +Status: {{.Data.Status}} +Total: {{.Data.Sweep.Total}} +Completed: {{.Data.Sweep.Completed}} +Succeeded: {{.Data.Sweep.Succeeded}} +Failed: {{.Data.Sweep.Failed}} +Active: {{.Data.Sweep.Active}} +{{- if .Data.Sweep.Tasks}} + +Sweep Tasks: +{{printf " %-24s %-14s %-12s %s" "TASK" "RUN ID" "STATUS" "EXPERIMENT"}} +{{- range .Data.Sweep.Tasks}} +{{printf " %-24s %-14s %-12s %s" .TaskKey .RunID .Status .Experiment}} +{{- end}} +{{- end}} +{{- else -}} +Run ID: {{.Data.RunID}} +Status: {{.Data.Status}} +{{- if .Data.StartedAt}} +Submitted: {{.Data.StartedAt}} +{{- end}} +{{- if .Data.Duration}} +Duration: {{.Data.Duration}} +{{- end}} +Retries: {{.Data.AttemptNumber}} +{{- if .Data.ExperimentName}} +Experiment: {{.Data.ExperimentName}} +{{- end}} +{{- if .Data.User}} +User: {{.Data.User}} +{{- end}} +{{- if .Data.Accelerators}} +Accelerators: {{.Data.Accelerators}} +{{- end}} +{{- if .Data.MLflowURL}} +MLflow: {{.Data.MLflowURL}} +{{- end}} +Dashboard: {{.Data.DashboardURL}} +{{- end}} +` + func newGetCommand() *cobra.Command { cmd := &cobra.Command{ Use: "get RUN_ID", Args: root.ExactArgs(1), Short: "Show details for a run", - RunE: func(cmd *cobra.Command, args []string) error { - return notImplemented("get") + Annotations: map[string]string{ + "template": getTemplate, }, } + cmd.PreRunE = root.MustWorkspaceClient + + cmd.RunE = func(cmd *cobra.Command, args []string) error { + ctx := cmd.Context() + w := cmdctx.WorkspaceClient(ctx) + + runID, err := strconv.ParseInt(args[0], 10, 64) + if err != nil || runID <= 0 { + return fmt.Errorf("invalid RUN_ID %q: must be a positive integer", args[0]) + } + + run, err := w.Jobs.GetRun(ctx, jobs.GetRunRequest{RunId: runID}) + if err != nil { + // The backend returns this when the run ID is unknown to the user. 
+ if errors.Is(err, apierr.ErrResourceDoesNotExist) { + return fmt.Errorf("run %d not found: check the run ID and that it is a job run ID", runID) + } + return fmt.Errorf("failed to get status for run %d: %w", runID, err) + } + + data := buildGetData(run) + data.MLflowURL = mlflowURL(ctx, w, run) + if task := findForEachTask(run); task != nil { + data.Sweep = buildSweepInfo(ctx, w, task) + } + + // Text mode shows the training-config YAML before the status, mirroring + // `air get`. JSON output omits it, matching `air get --json`. + if root.OutputType(cmd) == flags.OutputText { + if path := yamlConfigPath(run); path != "" { + printConfigYAML(ctx, w, path) + } + } + return renderEnvelope(ctx, data) + } + return cmd } + +// buildGetData extracts the fields we display from a run. +func buildGetData(run *jobs.Run) getData { + data := getData{ + RunID: strconv.FormatInt(run.RunId, 10), + Status: runStatus(run.State), + StartedAt: startedAt(run), + DurationSeconds: durationSeconds(run), + AttemptNumber: latestAttemptNumber(run), + ExperimentName: experimentName(run), + DashboardURL: run.RunPageUrl, + Accelerators: accelerators(run), + User: run.CreatorUserName, + } + if data.DurationSeconds != nil { + data.Duration = formatDuration(*data.DurationSeconds) + } + return data +} + +// yamlConfigPath returns the run's training-config YAML path, or "" if none. +func yamlConfigPath(run *jobs.Run) string { + if len(run.Tasks) == 0 { + return "" + } + task := run.Tasks[0].GenAiComputeTask + if task == nil { + return "" + } + return task.YamlParametersFilePath +} + +// printConfigYAML downloads the run's training-config YAML and prints it. It is +// best-effort: a failure is surfaced as a warning but does not fail status. 
+func printConfigYAML(ctx context.Context, w *databricks.WorkspaceClient, path string) { + r, err := w.Workspace.Download(ctx, path) + if err != nil { + log.Warnf(ctx, "air get: could not download training config %s: %v", path, err) + return + } + defer r.Close() + + content, err := io.ReadAll(r) + if err != nil { + log.Warnf(ctx, "air get: could not read training config %s: %v", path, err) + return + } + + cmdio.LogString(ctx, "Training Configuration:") + cmdio.LogString(ctx, string(content)) + cmdio.LogString(ctx, "") +} diff --git a/experimental/air/cmd/get_test.go b/experimental/air/cmd/get_test.go new file mode 100644 index 0000000000..6dfdc54db7 --- /dev/null +++ b/experimental/air/cmd/get_test.go @@ -0,0 +1,211 @@ +package aircmd + +import ( + "bytes" + "io" + "strings" + "testing" + "text/template" + + "github.com/databricks/cli/libs/cmdctx" + "github.com/databricks/cli/libs/cmdio" + "github.com/databricks/databricks-sdk-go/apierr" + "github.com/databricks/databricks-sdk-go/experimental/mocks" + "github.com/databricks/databricks-sdk-go/service/jobs" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/mock" + "github.com/stretchr/testify/require" +) + +// renderGet renders the status template against the JSON envelope, exactly as +// the command does, so the test covers the real template branches. 
+func renderGet(t *testing.T, data getData) string { + t.Helper() + tmpl, err := template.New("status").Parse(getTemplate) + require.NoError(t, err) + var buf bytes.Buffer + require.NoError(t, tmpl.Execute(&buf, envelope{V: envelopeVersion, Data: data})) + return buf.String() +} + +func TestGetTemplateSingleRun(t *testing.T) { + out := renderGet(t, getData{ + RunID: "123", + Status: "RUNNING", + User: "me@example.com", + DashboardURL: "https://example.test/run/123", + }) + assert.Contains(t, out, "Run ID: 123") + assert.Contains(t, out, "Status: RUNNING") + assert.Contains(t, out, "User:") + assert.Contains(t, out, "me@example.com") + assert.Contains(t, out, "Dashboard: https://example.test/run/123") + assert.NotContains(t, out, "Sweep") +} + +func TestGetRunInvalidID(t *testing.T) { + m := mocks.NewMockWorkspaceClient(t) + ctx := cmdctx.SetWorkspaceClient(cmdio.MockDiscard(t.Context()), m.WorkspaceClient) + cmd := newGetCommand() + cmd.SetContext(ctx) + + err := cmd.RunE(cmd, []string{"abc"}) + require.Error(t, err) + assert.Contains(t, err.Error(), "invalid RUN_ID") +} + +func TestGetRunNotFound(t *testing.T) { + m := mocks.NewMockWorkspaceClient(t) + m.GetMockJobsAPI().EXPECT().GetRun(mock.Anything, jobs.GetRunRequest{RunId: 5}).Return( + nil, apierr.ErrResourceDoesNotExist) + ctx := cmdctx.SetWorkspaceClient(cmdio.MockDiscard(t.Context()), m.WorkspaceClient) + cmd := newGetCommand() + cmd.SetContext(ctx) + + err := cmd.RunE(cmd, []string{"5"}) + require.Error(t, err) + assert.Contains(t, err.Error(), "run 5 not found") +} + +func TestPrintConfigYAML(t *testing.T) { + t.Run("downloads and prints", func(t *testing.T) { + ctx := cmdio.MockDiscard(t.Context()) + m := mocks.NewMockWorkspaceClient(t) + // The mock asserts Download is called with the resolved path. + m.GetMockWorkspaceAPI().EXPECT(). + Download(mock.Anything, "/Workspace/cfg.yaml"). 
+ Return(io.NopCloser(strings.NewReader("epochs: 3\n")), nil) + + printConfigYAML(ctx, m.WorkspaceClient, "/Workspace/cfg.yaml") + }) + + t.Run("download failure is non-fatal", func(t *testing.T) { + ctx := cmdio.MockDiscard(t.Context()) + m := mocks.NewMockWorkspaceClient(t) + m.GetMockWorkspaceAPI().EXPECT(). + Download(mock.Anything, "/Workspace/missing.yaml"). + Return(nil, apierr.ErrResourceDoesNotExist) + + // Must not panic: a failed config fetch is best-effort. + printConfigYAML(ctx, m.WorkspaceClient, "/Workspace/missing.yaml") + }) +} + +func TestYAMLConfigPath(t *testing.T) { + // No tasks, or a task without GenAiComputeTask, yields no path. + assert.Equal(t, "", yamlConfigPath(&jobs.Run{})) + assert.Equal(t, "", yamlConfigPath(&jobs.Run{Tasks: []jobs.RunTask{{}}})) + + run := &jobs.Run{Tasks: []jobs.RunTask{{ + GenAiComputeTask: &jobs.GenAiComputeTask{YamlParametersFilePath: "/Workspace/cfg.yaml"}, + }}} + assert.Equal(t, "/Workspace/cfg.yaml", yamlConfigPath(run)) +} + +func TestGetTemplateSweep(t *testing.T) { + out := renderGet(t, getData{ + RunID: "456", + Status: "RUNNING", + Sweep: &sweepInfo{ + Total: 4, Completed: 2, Succeeded: 1, Failed: 1, Active: 2, + Tasks: []sweepTask{ + {TaskKey: "iter_0", RunID: "789", Status: "SUCCESS", Experiment: "my-exp"}, + {TaskKey: "iter_1", RunID: "790", Status: "FAILED", Experiment: "my-exp"}, + }, + }, + }) + assert.Contains(t, out, "Sweep Run ID: 456") + assert.Contains(t, out, "Total: 4") + assert.Contains(t, out, "Sweep Tasks:") + assert.Contains(t, out, "iter_0") + assert.Contains(t, out, "iter_1") + assert.Contains(t, out, "FAILED") + assert.Contains(t, out, "my-exp") + // The single-run rows must not appear in the sweep view. + assert.NotContains(t, out, "Dashboard:") +} + +func TestGetTemplateSweepNoTasks(t *testing.T) { + // A sweep whose iterations haven't materialized yet: counts show, but the + // task table header is hidden. 
+ out := renderGet(t, getData{ + RunID: "456", + Status: "RUNNING", + Sweep: &sweepInfo{Total: 4, Active: 4}, + }) + assert.Contains(t, out, "Sweep Run ID: 456") + assert.Contains(t, out, "Total: 4") + assert.NotContains(t, out, "Sweep Tasks:") +} + +func TestGetTemplateMinimal(t *testing.T) { + // Only the always-present rows render; optional rows are hidden when empty. + out := renderGet(t, getData{RunID: "1", Status: "PENDING", DashboardURL: "https://example.test/1"}) + assert.Contains(t, out, "Run ID: 1") + assert.Contains(t, out, "Status: PENDING") + assert.Contains(t, out, "Retries: 0") + assert.Contains(t, out, "Dashboard: https://example.test/1") + for _, hidden := range []string{"Submitted:", "Duration:", "Experiment:", "User:", "Accelerators:", "MLflow:"} { + assert.NotContains(t, out, hidden) + } +} + +func TestGetTemplateAllFields(t *testing.T) { + started := "2023-11-14T22:13:20Z" + exp := "exp" + mlflow := "https://example.test/ml/exp/1" + out := renderGet(t, getData{ + RunID: "1", + Status: "SUCCESS", + StartedAt: &started, + Duration: "12s", + AttemptNumber: 2, + ExperimentName: &exp, + User: "me@example.com", + Accelerators: "8x H100", + MLflowURL: &mlflow, + DashboardURL: "https://example.test/1", + }) + for _, want := range []string{ + "Submitted: 2023-11-14T22:13:20Z", + "Duration: 12s", + "Retries: 2", + "Experiment: exp", + "User: me@example.com", + "Accelerators: 8x H100", + "MLflow: https://example.test/ml/exp/1", + "Dashboard: https://example.test/1", + } { + assert.Contains(t, out, want) + } +} + +func TestBuildStatusData(t *testing.T) { + run := &jobs.Run{ + RunId: 123, + RunPageUrl: "https://example.test/run/123", + CreatorUserName: "me@example.com", + StartTime: 1700000000000, + EndTime: 1700000012000, + State: &jobs.RunState{ResultState: jobs.RunResultStateSuccess}, + Tasks: []jobs.RunTask{{ + AttemptNumber: 1, + GenAiComputeTask: &jobs.GenAiComputeTask{ + MlflowExperimentName: "/Users/me@example.com/exp", + Compute: 
&jobs.ComputeConfig{NumGpus: 8, GpuType: "GPU_8xH100"}, + }, + }}, + } + d := buildGetData(run) + assert.Equal(t, "123", d.RunID) + assert.Equal(t, "SUCCESS", d.Status) + assert.Equal(t, 1, d.AttemptNumber) + assert.Equal(t, "https://example.test/run/123", d.DashboardURL) + assert.Equal(t, "me@example.com", d.User) + assert.Equal(t, "8x H100", d.Accelerators) + assert.Equal(t, "12s", d.Duration) + require.NotNil(t, d.ExperimentName) + assert.Equal(t, "exp", *d.ExperimentName) + require.NotNil(t, d.DurationSeconds) + assert.Equal(t, int64(12), *d.DurationSeconds) +} diff --git a/experimental/air/cmd/mlflow.go b/experimental/air/cmd/mlflow.go new file mode 100644 index 0000000000..97d085b012 --- /dev/null +++ b/experimental/air/cmd/mlflow.go @@ -0,0 +1,65 @@ +package aircmd + +import ( + "context" + "fmt" + "net/http" + "strings" + + "github.com/databricks/cli/libs/log" + "github.com/databricks/databricks-sdk-go" + "github.com/databricks/databricks-sdk-go/client" + "github.com/databricks/databricks-sdk-go/service/jobs" +) + +// getRunOutputResponse is the slice of the jobs runs/get-output response we care +// about. The MLflow identifiers live under a gen_ai_compute_output field that +// the typed SDK does not model, so we call the endpoint directly and parse just +// these fields. +type getRunOutputResponse struct { + GenAiComputeOutput *struct { + RunInfo *struct { + MlflowExperimentID string `json:"mlflow_experiment_id"` + MlflowRunID string `json:"mlflow_run_id"` + } `json:"run_info"` + } `json:"gen_ai_compute_output"` +} + +// mlflowURL returns a link to the run's MLflow logs, or nil if it can't be +// built. The link is a convenience, so any failure here (missing task, endpoint +// error, run not yet started) is logged and treated as "no link" rather than +// failing the whole command. 
+func mlflowURL(ctx context.Context, w *databricks.WorkspaceClient, run *jobs.Run) *string { + if len(run.Tasks) == 0 { + return nil + } + // The MLflow output is attached to the task run, not the parent job run. + taskRunID := run.Tasks[0].RunId + + apiClient, err := client.New(w.Config) + if err != nil { + log.Debugf(ctx, "air get: could not build API client for MLflow link: %v", err) + return nil + } + + var out getRunOutputResponse + err = apiClient.Do(ctx, http.MethodGet, "/api/2.2/jobs/runs/get-output", + nil, map[string]any{"run_id": taskRunID}, nil, &out) + if err != nil { + log.Debugf(ctx, "air get: could not fetch run output for MLflow link: %v", err) + return nil + } + + if out.GenAiComputeOutput == nil || out.GenAiComputeOutput.RunInfo == nil { + return nil + } + info := out.GenAiComputeOutput.RunInfo + if info.MlflowExperimentID == "" || info.MlflowRunID == "" { + return nil + } + + host := strings.TrimRight(w.Config.Host, "/") + url := fmt.Sprintf("%s/ml/experiments/%s/runs/%s/artifacts/logs/node_0", + host, info.MlflowExperimentID, info.MlflowRunID) + return &url +} diff --git a/experimental/air/cmd/mlflow_test.go b/experimental/air/cmd/mlflow_test.go new file mode 100644 index 0000000000..bbc4fef982 --- /dev/null +++ b/experimental/air/cmd/mlflow_test.go @@ -0,0 +1,64 @@ +package aircmd + +import ( + "net/http" + "net/http/httptest" + "testing" + + "github.com/databricks/databricks-sdk-go" + "github.com/databricks/databricks-sdk-go/service/jobs" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +// newTestWorkspaceClient builds a WorkspaceClient pointed at a mock HTTP server. +// mlflowURL calls the runs/get-output REST endpoint directly (the field it needs +// is not modeled by the typed SDK), so it must be exercised over HTTP. 
+func newTestWorkspaceClient(t *testing.T, host string) *databricks.WorkspaceClient { + t.Helper() + w, err := databricks.NewWorkspaceClient(&databricks.Config{Host: host, Token: "token"}) + require.NoError(t, err) + return w +} + +// runOutputServer serves the given runs/get-output body and a stub for the SDK's +// well-known config discovery request. *hit is set when get-output is called. +func runOutputServer(t *testing.T, body string, hit *bool) *httptest.Server { + t.Helper() + srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + if r.URL.Path == "/api/2.2/jobs/runs/get-output" { + *hit = true + _, _ = w.Write([]byte(body)) + return + } + _, _ = w.Write([]byte(`{}`)) + })) + t.Cleanup(srv.Close) + return srv +} + +func TestMLflowURL(t *testing.T) { + ctx := t.Context() + run := &jobs.Run{Tasks: []jobs.RunTask{{RunId: 99}}} + + t.Run("builds the deep-link on success", func(t *testing.T) { + var hit bool + srv := runOutputServer(t, `{"gen_ai_compute_output":{"run_info":{"mlflow_experiment_id":"E1","mlflow_run_id":"R1"}}}`, &hit) + + got := mlflowURL(ctx, newTestWorkspaceClient(t, srv.URL), run) + require.NotNil(t, got) + assert.True(t, hit, "runs/get-output should have been called") + assert.Equal(t, srv.URL+"/ml/experiments/E1/runs/R1/artifacts/logs/node_0", *got) + }) + + t.Run("nil when the run has no MLflow info", func(t *testing.T) { + var hit bool + srv := runOutputServer(t, `{}`, &hit) + assert.Nil(t, mlflowURL(ctx, newTestWorkspaceClient(t, srv.URL), run)) + }) + + t.Run("nil when the run has no tasks", func(t *testing.T) { + // Returns before any HTTP call, so the host is never contacted. 
+ assert.Nil(t, mlflowURL(ctx, newTestWorkspaceClient(t, "https://unused.invalid"), &jobs.Run{})) + }) +} diff --git a/experimental/air/cmd/output.go b/experimental/air/cmd/output.go new file mode 100644 index 0000000000..3da766a7d4 --- /dev/null +++ b/experimental/air/cmd/output.go @@ -0,0 +1,39 @@ +package aircmd + +import ( + "context" + "time" + + "github.com/databricks/cli/libs/cmdio" +) + +// envelopeVersion is the envelope's format-version marker. The Python `air` CLI +// hardcodes it to 1; it lets consumers detect a future incompatible change to +// the envelope shape. +const envelopeVersion = 1 + +// envelope is the JSON shape that the AI runtime CLI prints: +// +// { "v": 1, "ts": "2024-01-15T14:30:45Z", "data": { ... } } +// +// It mirrors the envelope used by the original Python `air` CLI so existing +// consumers keep working after the port to Go. +type envelope struct { + // V is the envelope format-version marker (always 1). + V int `json:"v"` + // TS is the wall-clock time the response was produced, in RFC 3339 UTC. + // It is an absolute timestamp, not an elapsed duration. + TS string `json:"ts"` + // Data is the command-specific payload. + Data any `json:"data"` +} + +// renderEnvelope wraps data in the JSON envelope and prints it. +// Fields that should appear only in text output are tagged `json:"-"` on the payload struct. 
+func renderEnvelope(ctx context.Context, data any) error {
+	return cmdio.Render(ctx, envelope{
+		V:    envelopeVersion,
+		TS:   time.Now().UTC().Format(time.RFC3339),
+		Data: data,
+	})
+}
diff --git a/experimental/air/cmd/output_test.go b/experimental/air/cmd/output_test.go
new file mode 100644
index 0000000000..73a5572c3f
--- /dev/null
+++ b/experimental/air/cmd/output_test.go
@@ -0,0 +1,13 @@
+package aircmd
+
+import (
+	"testing"
+
+	"github.com/databricks/cli/libs/cmdio"
+	"github.com/stretchr/testify/require"
+)
+
+func TestRenderEnvelope(t *testing.T) {
+	ctx := cmdio.MockDiscard(t.Context())
+	require.NoError(t, renderEnvelope(ctx, getData{RunID: "1", Status: "RUNNING"}))
+}
diff --git a/experimental/air/cmd/stubs_test.go b/experimental/air/cmd/stubs_test.go
index a6e24177f3..5e35bcdcd1 100644
--- a/experimental/air/cmd/stubs_test.go
+++ b/experimental/air/cmd/stubs_test.go
@@ -14,7 +14,6 @@ import (
 func TestStubCommandsReturnNotImplemented(t *testing.T) {
 	stubs := map[string]*cobra.Command{
 		"run":    newRunCommand(),
-		"get":    newGetCommand(),
 		"list":   newListCommand(),
 		"logs":   newLogsCommand(),
 		"cancel": newCancelCommand(),
diff --git a/experimental/air/cmd/sweep.go b/experimental/air/cmd/sweep.go
new file mode 100644
index 0000000000..b346f43f1b
--- /dev/null
+++ b/experimental/air/cmd/sweep.go
@@ -0,0 +1,87 @@
+package aircmd
+
+import (
+	"context"
+	"strconv"
+
+	"github.com/databricks/cli/libs/log"
+	"github.com/databricks/databricks-sdk-go"
+	"github.com/databricks/databricks-sdk-go/service/jobs"
+)
+
+// sweepInfo summarizes a "foreach" run, which fans a single config out into many
+// iterations (a hyperparameter sweep). It is shown only in text output.
+type sweepInfo struct {
+	Total     int
+	Succeeded int
+	Failed    int
+	Active    int
+	Completed int
+	Tasks     []sweepTask
+}
+
+// sweepTask is one iteration of a sweep.
+type sweepTask struct {
+	TaskKey    string
+	RunID      string
+	Status     string
+	Experiment string
+}
+
+// findForEachTask returns the run's foreach task if it has one, or nil. A run is
+// a sweep when one of its tasks fans out into iterations.
+func findForEachTask(run *jobs.Run) *jobs.RunTask {
+	for i := range run.Tasks {
+		if run.Tasks[i].ForEachTask != nil {
+			return &run.Tasks[i]
+		}
+	}
+	return nil
+}
+
+// buildSweepInfo gathers the iteration counts and per-iteration rows for a
+// sweep. The counts come from the task we already have; the individual
+// iterations require a second lookup. If that lookup fails we still return the
+// counts (logging the failure) so the user sees the summary.
+//
+// task is expected to be the result of findForEachTask. A nil task, or one that
+// is not a foreach, yields an empty summary without contacting the workspace
+// instead of panicking on the nil ForEachTask.
+func buildSweepInfo(ctx context.Context, w *databricks.WorkspaceClient, task *jobs.RunTask) *sweepInfo {
+	info := &sweepInfo{}
+	if task == nil || task.ForEachTask == nil {
+		return info
+	}
+
+	if task.ForEachTask.Stats != nil && task.ForEachTask.Stats.TaskRunStats != nil {
+		stats := task.ForEachTask.Stats.TaskRunStats
+		info.Total = stats.TotalIterations
+		info.Succeeded = stats.SucceededIterations
+		info.Failed = stats.FailedIterations
+		info.Active = stats.ActiveIterations
+		info.Completed = stats.CompletedIterations
+	}
+
+	// The iterations are returned as part of a run lookup on the foreach task.
+	iterated, err := w.Jobs.GetRun(ctx, jobs.GetRunRequest{RunId: task.RunId})
+	if err != nil {
+		log.Debugf(ctx, "air get: could not fetch sweep iterations: %v", err)
+		return info
+	}
+
+	// Index instead of ranging by value: jobs.RunTask is a wide SDK struct and
+	// a value loop would copy every iteration.
+	for i := range iterated.Iterations {
+		it := &iterated.Iterations[i]
+		row := sweepTask{
+			TaskKey: it.TaskKey,
+			RunID:   strconv.FormatInt(it.RunId, 10),
+			Status:  runStatus(it.State),
+		}
+		if it.GenAiComputeTask != nil && it.GenAiComputeTask.MlflowExperimentName != "" {
+			row.Experiment = stripExperimentUserPrefix(it.GenAiComputeTask.MlflowExperimentName)
+		}
+		info.Tasks = append(info.Tasks, row)
+	}
+	return info
+}
diff --git a/experimental/air/cmd/sweep_test.go b/experimental/air/cmd/sweep_test.go
new file mode 100644
index 0000000000..10134c0df4
--- /dev/null
+++ b/experimental/air/cmd/sweep_test.go
@@ -0,0 +1,92 @@
+package aircmd
+
+import (
+	"testing"
+
+	"github.com/databricks/databricks-sdk-go/apierr"
+	"github.com/databricks/databricks-sdk-go/experimental/mocks"
+	"github.com/databricks/databricks-sdk-go/service/jobs"
+	"github.com/stretchr/testify/assert"
+	"github.com/stretchr/testify/mock"
+	"github.com/stretchr/testify/require"
+)
+
+func TestFindForEachTask(t *testing.T) {
+	// No tasks at all.
+	assert.Nil(t, findForEachTask(&jobs.Run{}))
+
+	// A task that is not a foreach.
+	assert.Nil(t, findForEachTask(&jobs.Run{Tasks: []jobs.RunTask{{TaskKey: "a"}}}))
+
+	// The foreach task is found even when it isn't first.
+	run := &jobs.Run{Tasks: []jobs.RunTask{
+		{TaskKey: "a"},
+		{TaskKey: "sweep", ForEachTask: &jobs.RunForEachTask{}},
+	}}
+	got := findForEachTask(run)
+	require.NotNil(t, got)
+	assert.Equal(t, "sweep", got.TaskKey)
+}
+
+func sweepTaskFixture() *jobs.RunTask {
+	return &jobs.RunTask{
+		RunId: 99,
+		ForEachTask: &jobs.RunForEachTask{
+			Stats: &jobs.ForEachStats{TaskRunStats: &jobs.ForEachTaskTaskRunStats{
+				TotalIterations:     4,
+				SucceededIterations: 1,
+				FailedIterations:    1,
+				ActiveIterations:    2,
+				CompletedIterations: 2,
+			}},
+		},
+	}
+}
+
+func TestBuildSweepInfo(t *testing.T) {
+	ctx := t.Context()
+
+	t.Run("counts and iteration rows", func(t *testing.T) {
+		m := mocks.NewMockWorkspaceClient(t)
+		m.GetMockJobsAPI().EXPECT().GetRun(mock.Anything, jobs.GetRunRequest{RunId: 99}).Return(
+			&jobs.Run{Iterations: []jobs.RunTask{{
+				TaskKey:          "iter_0",
+				RunId:            100,
+				State:            &jobs.RunState{ResultState: jobs.RunResultStateSuccess},
+				GenAiComputeTask: &jobs.GenAiComputeTask{MlflowExperimentName: "/Users/me@example.com/exp"},
+			}}}, nil)
+
+		info := buildSweepInfo(ctx, m.WorkspaceClient, sweepTaskFixture())
+		assert.Equal(t, 4, info.Total)
+		assert.Equal(t, 2, info.Completed)
+		assert.Equal(t, 1, info.Succeeded)
+		assert.Equal(t, 1, info.Failed)
+		assert.Equal(t, 2, info.Active)
+		require.Len(t, info.Tasks, 1)
+		assert.Equal(t, "iter_0", info.Tasks[0].TaskKey)
+		assert.Equal(t, "100", info.Tasks[0].RunID)
+		assert.Equal(t, "SUCCESS", info.Tasks[0].Status)
+		assert.Equal(t, "exp", info.Tasks[0].Experiment)
+	})
+
+	t.Run("iteration lookup failure still returns counts", func(t *testing.T) {
+		m := mocks.NewMockWorkspaceClient(t)
+		m.GetMockJobsAPI().EXPECT().GetRun(mock.Anything, jobs.GetRunRequest{RunId: 99}).Return(
+			nil, apierr.ErrResourceDoesNotExist)
+
+		info := buildSweepInfo(ctx, m.WorkspaceClient, sweepTaskFixture())
+		assert.Equal(t, 4, info.Total)
+		assert.Empty(t, info.Tasks)
+	})
+
+	t.Run("task without a foreach yields an empty summary", func(t *testing.T) {
+		// No expectations are registered, so the mock fails the test if
+		// buildSweepInfo attempts the iteration lookup.
+		m := mocks.NewMockWorkspaceClient(t)
+
+		for _, task := range []*jobs.RunTask{nil, {RunId: 99}} {
+			info := buildSweepInfo(ctx, m.WorkspaceClient, task)
+			assert.Equal(t, &sweepInfo{}, info)
+		}
+	})
+}