diff --git a/internal/sql_workbench/service/sql_workbench_service.go b/internal/sql_workbench/service/sql_workbench_service.go index ef8fb3ac..33af90b4 100644 --- a/internal/sql_workbench/service/sql_workbench_service.go +++ b/internal/sql_workbench/service/sql_workbench_service.go @@ -1441,6 +1441,19 @@ func (sqlWorkbenchService *SqlWorkbenchService) isEnableSQLAudit(dbService *biz. return dbService.SQLEConfig.AuditEnabled && dbService.SQLEConfig.SQLQueryConfig.AuditEnabled } +// normalizeSQLEAuditSchemaName 将 ODC SQL Server 的 database.schema 组合归一化为 SQLE 连接库名。 +// ODC 将 SQL Server 的 database.schema 平铺为 schema 名(如 TestDB.dbo), +// 而 SQLE 直连审核会把 schema_name 当作连接的数据库名。 +func normalizeSQLEAuditSchemaName(dbType, schemaName string) string { + if dbType != string(pkgConst.DBTypeSQLServer) { + return schemaName + } + if idx := strings.Index(schemaName, "."); idx > 0 { + return schemaName[:idx] + } + return schemaName +} + // callSQLEAudit 调用 SQLE 直接审核接口 func (sqlWorkbenchService *SqlWorkbenchService) callSQLEAudit(ctx context.Context, sql string, dbService *biz.DBService, schemaName string) (*cloudbeaver.AuditSQLReply, error) { // 获取 SQLE 服务地址 @@ -1450,6 +1463,7 @@ func (sqlWorkbenchService *SqlWorkbenchService) callSQLEAudit(ctx context.Contex } sqleAddr := fmt.Sprintf("%s/v2/sql_audit", target.URL.String()) + schemaName = normalizeSQLEAuditSchemaName(dbService.DBType, schemaName) auditReq := cloudbeaver.AuditSQLReq{ InstanceType: dbService.DBType, diff --git a/internal/sql_workbench/service/sql_workbench_service_test.go b/internal/sql_workbench/service/sql_workbench_service_test.go index 090ad66c..b73860ba 100644 --- a/internal/sql_workbench/service/sql_workbench_service_test.go +++ b/internal/sql_workbench/service/sql_workbench_service_test.go @@ -367,3 +367,45 @@ func Test_buildOdcCreateAndUpdateRequests_setPasswordSaved(t *testing.T) { t.Fatalf("expected passwordSaved in update JSON: %s", updateJSON) } } + +func Test_normalizeSQLEAuditSchemaName(t *testing.T) { + cases := map[string]struct { + dbType string + schema string + expected string + }{ + "SQL Server database.schema": { + dbType: string(pkgConst.DBTypeSQLServer), + schema: "TestDB.dbo", + expected: "TestDB", + }, + "SQL Server catalog only": { + dbType: string(pkgConst.DBTypeSQLServer), + schema: "TestDB", + expected: "TestDB", + }, + "SQL Server non-dbo schema": { + dbType: string(pkgConst.DBTypeSQLServer), + schema: "TestDB.sales", + expected: "TestDB", + }, + "MySQL schema unchanged": { + dbType: string(pkgConst.DBTypeMySQL), + schema: "app.db", + expected: "app.db", + }, + "Oracle schema unchanged": { + dbType: string(pkgConst.DBTypeOracle), + schema: "HR", + expected: "HR", + }, + } + for name, tc := range cases { + t.Run(name, func(t *testing.T) { + got := normalizeSQLEAuditSchemaName(tc.dbType, tc.schema) + if got != tc.expected { + t.Fatalf("normalizeSQLEAuditSchemaName(%q, %q) = %q, want %q", tc.dbType, tc.schema, got, tc.expected) + } + }) + } +}