diff --git a/src/Asn1Decode.sol b/src/Asn1Decode.sol index 97cd4e4..a6780a0 100644 --- a/src/Asn1Decode.sol +++ b/src/Asn1Decode.sol @@ -165,10 +165,8 @@ library Asn1Decode { * @return Uint value of node */ function uintAt(bytes memory der, Asn1Ptr ptr) internal pure returns (uint256) { - require(der[ptr.header()] == 0x02, "Not type INTEGER"); - require(der[ptr.content()] & 0x80 == 0, "Not positive"); - uint256 len = ptr.length(); - return uint256(readBytesN(der, ptr.content(), len) >> (32 - len) * 8); + (uint256 start, uint256 len) = positiveIntegerContent(der, ptr, 32); + return uint256(readBytesN(der, start, len) >> (32 - len) * 8); } /* @@ -178,19 +176,40 @@ library Asn1Decode { * @return 384-bit uint encoded in uint128 and uint256 */ function uint384At(bytes memory der, Asn1Ptr ptr) internal pure returns (uint128, uint256) { + (uint256 start, uint256 valueLength) = positiveIntegerContent(der, ptr, 48); + + if (valueLength > 32) { + uint256 hiLen = valueLength - 32; + return ( + uint128(uint256(readBytesN(der, start, hiLen) >> (32 - hiLen) * 8)), + uint256(readBytesN(der, start + hiLen, 32)) + ); + } + + return (0, uint256(readBytesN(der, start, valueLength) >> (32 - valueLength) * 8)); + } + + function positiveIntegerContent(bytes memory der, Asn1Ptr ptr, uint256 maxValueLength) + private + pure + returns (uint256 start, uint256 valueLength) + { require(der[ptr.header()] == 0x02, "Not type INTEGER"); - require(der[ptr.content()] & 0x80 == 0, "Not positive"); - uint256 valueLength = ptr.length(); - uint256 start = ptr.content(); + valueLength = ptr.length(); + require(valueLength > 0, "invalid INTEGER length"); + start = ptr.content(); + if (der[start] == 0) { - start++; - valueLength--; + if (valueLength > 1) { + require(der[start + 1] & 0x80 == 0x80, "non-canonical INTEGER"); + start++; + valueLength--; + } + } else { + require(der[start] & 0x80 == 0, "Not positive"); } - uint256 shift = 48 - valueLength; - return ( - uint128(uint256(readBytesN(der, start, 16 - shift) >> (128 + shift * 8))), - uint256(readBytesN(der, start + 16 - shift, 32)) - ); + + require(valueLength <= maxValueLength, "invalid INTEGER length"); } /* diff --git a/test/Asn1Decode.t.sol b/test/Asn1Decode.t.sol index ae44910..ccae135 100644 --- a/test/Asn1Decode.t.sol +++ b/test/Asn1Decode.t.sol @@ -20,6 +20,10 @@ contract Asn1Harness { return der.uintAt(der.root()); } + function uint384AtRoot(bytes memory der) external pure returns (uint128 hi, uint256 lo) { + return der.uint384At(der.root()); + } + function timestampAtRoot(bytes memory der) external pure returns (uint256) { return der.timestampAt(der.root()); } @@ -68,6 +72,20 @@ contract Asn1DecodeTest is Test { assertEq(h.uintAtRoot(hex"0203012345"), 0x012345); // INTEGER 0x012345 } + function test_uintAt_requiredLeadingZero() public view { + assertEq(h.uintAtRoot(hex"02020080"), 0x80); + } + + function test_uintAt_unnecessaryLeadingZero_reverts() public { + vm.expectRevert("non-canonical INTEGER"); + h.uintAtRoot(hex"0202007f"); + } + + function test_uintAt_empty_reverts() public { + vm.expectRevert("invalid INTEGER length"); + h.uintAtRoot(hex"0200"); + } + function test_uintAt_notInteger_reverts() public { vm.expectRevert("Not type INTEGER"); h.uintAtRoot(hex"0401ff"); // OCTET STRING, not INTEGER @@ -84,6 +102,22 @@ contract Asn1DecodeTest is Test { h.uintAtRoot(hex"02050000"); // claims 5 content bytes, only 2 present } + function test_uint384At_requiredLeadingZero() public view { + (uint128 hi, uint256 lo) = h.uint384AtRoot(hex"02020080"); + assertEq(hi, 0); + assertEq(lo, 0x80); + } + + function test_uint384At_unnecessaryLeadingZero_reverts() public { + vm.expectRevert("non-canonical INTEGER"); + h.uint384AtRoot(hex"0202007f"); + } + + function test_uint384At_empty_reverts() public { + vm.expectRevert("invalid INTEGER length"); + h.uint384AtRoot(hex"0200"); + } + // --- timestampAt --- function test_timestamp_utcEpoch() public view { @@ -187,11 +221,36 @@ contract Asn1DecodeTest is Test { } function testFuzz_uintAt_positive(uint64 v) public view { - // INTEGER with an explicit 0x00 sign byte so the value is always positive - bytes memory der = abi.encodePacked(bytes1(0x02), bytes1(0x09), bytes1(0x00), bytes8(v)); + bytes memory der = _derEncodeUint64(v); assertEq(h.uintAtRoot(der), v); } + function _derEncodeUint64(uint64 v) internal pure returns (bytes memory) { + if (v == 0) { + return hex"020100"; + } + + bytes8 raw = bytes8(v); + uint256 offset; + while (offset < 8 && raw[offset] == 0) { + offset++; + } + + uint256 len = 8 - offset; + bool needsPad = uint8(raw[offset]) >= 0x80; + bytes memory der = new bytes(2 + len + (needsPad ? 1 : 0)); + der[0] = 0x02; + der[1] = bytes1(uint8(der.length - 2)); + uint256 dst = 2; + if (needsPad) { + der[dst++] = 0x00; + } + for (uint256 i = offset; i < 8; ++i) { + der[dst++] = raw[i]; + } + return der; + } + function _utcTime(string memory s) internal pure returns (bytes memory) { bytes memory b = bytes(s); require(b.length == 13, "test: UTCTime must be 13 chars");