diff --git a/drivers/accel/amdxdna/amdxdna_ctx.c b/drivers/accel/amdxdna/amdxdna_ctx.c index ec8953f28..094da3bac 100644 --- a/drivers/accel/amdxdna/amdxdna_ctx.c +++ b/drivers/accel/amdxdna/amdxdna_ctx.c @@ -142,6 +142,9 @@ void *amdxdna_cmd_get_payload(struct amdxdna_gem_obj *abo, u32 *size) else num_masks = 1 + FIELD_GET(AMDXDNA_CMD_EXTRA_CU_MASK, cmd->header); + if (abo->mem.size < offsetof(struct amdxdna_cmd, data[num_masks])) + return NULL; + if (size) { count = FIELD_GET(AMDXDNA_CMD_COUNT, cmd->header); if (unlikely(count <= num_masks || @@ -195,6 +198,8 @@ int amdxdna_cmd_set_error(struct amdxdna_gem_obj *abo, if (amdxdna_cmd_get_op(abo) == ERT_CMD_CHAIN) { cc = amdxdna_cmd_get_payload(abo, NULL); + if (!cc) + return -ENOMEM; cc->error_index = (cmd_idx < cc->command_count) ? cmd_idx : 0; abo = amdxdna_gem_get_obj(client, cc->data[0], AMDXDNA_BO_SHARE); if (!abo) diff --git a/src/driver/amdxdna/aie4_hwctx.c b/src/driver/amdxdna/aie4_hwctx.c index f750cc7c4..7180bb6ad 100644 --- a/src/driver/amdxdna/aie4_hwctx.c +++ b/src/driver/amdxdna/aie4_hwctx.c @@ -715,6 +715,8 @@ static int submit_one_cmd(struct amdxdna_ctx *ctx, } dpu = amdxdna_cmd_get_payload(cmd_abo, NULL); + if (!dpu) + return -ENOMEM; chained = dpu->chained; if (chained >= HSA_MAX_LEVEL1_INDIRECT_ENTRIES) { XDNA_ERR(xdna, "Invalid DPU data"); diff --git a/src/driver/amdxdna/amdxdna_ctx.h b/src/driver/amdxdna/amdxdna_ctx.h index 727d25897..fc0c8d913 100644 --- a/src/driver/amdxdna/amdxdna_ctx.h +++ b/src/driver/amdxdna/amdxdna_ctx.h @@ -347,6 +347,9 @@ amdxdna_cmd_get_payload(struct amdxdna_gem_obj *abo, u32 *size) else num_masks = 1 + FIELD_GET(AMDXDNA_CMD_EXTRA_CU_MASK, cmd->header); + if (abo->mem.size < offsetof(struct amdxdna_cmd, data[num_masks])) + return NULL; + if (size) { count = FIELD_GET(AMDXDNA_CMD_COUNT, cmd->header); if (unlikely(count <= num_masks || diff --git a/src/driver/amdxdna/ve2_hwctx.c b/src/driver/amdxdna/ve2_hwctx.c index 9cdbcabc0..10d4da90c 100644 --- a/src/driver/amdxdna/ve2_hwctx.c +++ b/src/driver/amdxdna/ve2_hwctx.c @@ -336,7 +336,8 @@ static inline void ve2_hwctx_job_release_locked(struct amdxdna_ctx *hwctx, op = amdxdna_cmd_get_op(cmd_bo); if (op == ERT_CMD_CHAIN) { cmd_chain = amdxdna_cmd_get_payload(cmd_bo, NULL); - cmd_cnt = cmd_chain->command_count; + if (cmd_chain) + cmd_cnt = cmd_chain->command_count; } hwctx->completed += cmd_cnt;