From 07e7c382a570a9e80abb83cfa9899e4d27cdd138 Mon Sep 17 00:00:00 2001 From: Jens Geyer Date: Thu, 28 May 2026 01:10:31 +0200 Subject: [PATCH] THRIFT-6049: Limit struct read/write recursion depth in Lua library Client: lua The Lua generator wraps each generated read()/write() body with incrementRecursionDepth()/pcall/decrementRecursionDepth() on TProtocolBase (limit DEFAULT_RECURSION_DEPTH = 64, TProtocolException.DEPTH_LIMIT on excess). Previously only skip() was bounded, so the generated read/write path recursed without limit. Unions and exceptions are generated through the same path, so they are bounded too. Replace the isolated counter test with a generated-code round-trip (lib/lua/test/test_recursion_depth.lua) over a recursive struct (RecTree), union (RecUnion) and exception (RecError) from a new RecursionDepth.thrift, driven through the generated read()/write() over a real TBinaryProtocol: a chain at the limit round-trips (also proving the guard does not double-count), a chain past it is rejected on write, and a hand-serialized over-limit payload is rejected on read (crafted with the real recursive field, id 1 = list, so the reader recurses through the guarded path rather than the separate, unbounded skip()). 
Co-Authored-By: Claude Sonnet 4.6 Co-Authored-By: Claude Opus 4.8 --- .../src/thrift/generate/t_lua_generator.cc | 14 ++ lib/lua/TProtocol.lua | 18 ++ lib/lua/test/RecursionDepth.thrift | 37 ++++ lib/lua/test/test_recursion_depth.lua | 174 ++++++++++++++++++ 4 files changed, 243 insertions(+) create mode 100644 lib/lua/test/RecursionDepth.thrift create mode 100644 lib/lua/test/test_recursion_depth.lua diff --git a/compiler/cpp/src/thrift/generate/t_lua_generator.cc b/compiler/cpp/src/thrift/generate/t_lua_generator.cc index f7f8f054e08..54c1fd7097f 100644 --- a/compiler/cpp/src/thrift/generate/t_lua_generator.cc +++ b/compiler/cpp/src/thrift/generate/t_lua_generator.cc @@ -421,6 +421,9 @@ void t_lua_generator::generate_lua_struct_reader(ostream& out, t_struct* tstruct indent(out) << '\n' << '\n' << "function " << tstruct->get_name() << ":read(iprot)" << '\n'; indent_up(); + indent(out) << "iprot:incrementRecursionDepth()" << '\n'; + indent(out) << "local ok, err = pcall(function()" << '\n'; + indent_up(); indent(out) << "iprot:readStructBegin()" << '\n'; // while: Read in fields @@ -460,6 +463,10 @@ void t_lua_generator::generate_lua_struct_reader(ostream& out, t_struct* tstruct indent_down(); indent(out) << "end" << '\n'; indent(out) << "iprot:readStructEnd()" << '\n'; + indent_down(); + indent(out) << "end)" << '\n'; + indent(out) << "iprot:decrementRecursionDepth()" << '\n'; + indent(out) << "if not ok then error(err, 0) end" << '\n'; // end function indent_down(); @@ -478,6 +485,9 @@ void t_lua_generator::generate_lua_struct_writer(ostream& out, t_struct* tstruct indent(out) << '\n' << '\n' << "function " << tstruct->get_name() << ":write(oprot)" << '\n'; indent_up(); + indent(out) << "oprot:incrementRecursionDepth()" << '\n'; + indent(out) << "local ok, err = pcall(function()" << '\n'; + indent_up(); indent(out) << "oprot:writeStructBegin('" << tstruct->get_name() << "')" << '\n'; for (f_iter = fields.begin(); f_iter != fields.end(); ++f_iter) { // To check 
element of self whether nil or not. @@ -497,6 +507,10 @@ void t_lua_generator::generate_lua_struct_writer(ostream& out, t_struct* tstruct } indent(out) << "oprot:writeFieldStop()" << '\n'; indent(out) << "oprot:writeStructEnd()" << '\n'; + indent_down(); + indent(out) << "end)" << '\n'; + indent(out) << "oprot:decrementRecursionDepth()" << '\n'; + indent(out) << "if not ok then error(err, 0) end" << '\n'; // end function indent_down(); diff --git a/lib/lua/TProtocol.lua b/lib/lua/TProtocol.lua index f7a993f0b50..8a86e756970 100644 --- a/lib/lua/TProtocol.lua +++ b/lib/lua/TProtocol.lua @@ -48,6 +48,8 @@ function TProtocolException:__errorCodeToString() end end +DEFAULT_RECURSION_DEPTH = 64 + TProtocolBase = __TObject:new{ __type = 'TProtocolBase', trans @@ -63,9 +65,25 @@ function TProtocolBase:new(obj) error('You must provide ' .. ttype(self) .. ' with a trans') end + obj.recursionDepth = 0 return __TObject.new(self, obj) end +function TProtocolBase:incrementRecursionDepth() + self.recursionDepth = self.recursionDepth + 1 + if self.recursionDepth > DEFAULT_RECURSION_DEPTH then + self.recursionDepth = self.recursionDepth - 1 + terror(TProtocolException:new{ + message = 'Maximum recursion depth exceeded', + errorCode = TProtocolException.DEPTH_LIMIT + }) + end +end + +function TProtocolBase:decrementRecursionDepth() + self.recursionDepth = self.recursionDepth - 1 +end + function TProtocolBase:writeMessageBegin(name, ttype, seqid) end function TProtocolBase:writeMessageEnd() end function TProtocolBase:writeStructBegin(name) end diff --git a/lib/lua/test/RecursionDepth.thrift b/lib/lua/test/RecursionDepth.thrift new file mode 100644 index 00000000000..bbeec3940bf --- /dev/null +++ b/lib/lua/test/RecursionDepth.thrift @@ -0,0 +1,37 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. 
The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +// A self-referential struct, union and exception used to exercise the +// recursion-depth guard in the generated read()/write() code. Recursion runs +// through a list so the types are expressible (no by-value cycle). + +struct RecTree { + 1: list<RecTree> children + 2: i16 item +} + +union RecUnion { + 1: list<RecUnion> children + 2: i32 leaf +} + +exception RecError { + 1: list<RecError> children + 2: i32 leaf +} diff --git a/lib/lua/test/test_recursion_depth.lua b/lib/lua/test/test_recursion_depth.lua new file mode 100644 index 00000000000..8e8c0712557 --- /dev/null +++ b/lib/lua/test/test_recursion_depth.lua @@ -0,0 +1,174 @@ +-- +-- Licensed to the Apache Software Foundation (ASF) under one +-- or more contributor license agreements. See the NOTICE file +-- distributed with this work for additional information +-- regarding copyright ownership. The ASF licenses this file +-- to you under the Apache License, Version 2.0 (the +-- "License"); you may not use this file except in compliance +-- with the License. You may obtain a copy of the License at +-- +-- http://www.apache.org/licenses/LICENSE-2.0 +-- +-- Unless required by applicable law or agreed to in writing, +-- software distributed under the License is distributed on an +-- "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +-- KIND, either express or implied. 
See the License for the +-- specific language governing permissions and limitations +-- under the License. +-- + +-- Drives the recursion-depth guard through the *generated* read()/write() +-- code over a recursive struct (RecTree), union (RecUnion) and exception +-- (RecError), not the protocol counter in isolation. +-- +-- Run after generating the types next to this file: +-- thrift -o lib/lua/test --gen lua lib/lua/test/RecursionDepth.thrift +-- lua lib/lua/test/test_recursion_depth.lua + +-- Locate the library (../) and the generated code (gen-lua/) relative to this +-- script so it can be run from anywhere. +local script_dir = arg[0]:match('(.*[/\\])') or './' +package.path = script_dir .. '../?.lua;' .. + script_dir .. 'gen-lua/?.lua;' .. package.path + +-- The recursion guard under test is pure Lua (the generated read/write and the +-- TProtocolBase counter). The C extensions only do byte packing, which Lua 5.3+ +-- provides natively, so functional pure-Lua stand-ins let the test run without +-- building them while still exercising the real generated serialization path. +package.preload['libluabitwise'] = function() + return { + bor = function(a, b) return a | b end, + band = function(a, b) return a & b end, + bxor = function(a, b) return a ~ b end, + shiftl = function(a, n) return (a << n) & 0xFFFFFFFF end, + shiftr = function(a, n) return a >> n end, + } +end +package.preload['libluabpack'] = function() + local fmt = {c = '>i1', C = '>I1', s = '>i2', S = '>I2', + i = '>i4', I = '>I4', l = '>i8', d = '>d'} + return { + bpack = function(code, val) return string.pack(fmt[code], val) end, + bunpack = function(code, data) return (string.unpack(fmt[code], data)) end, + } +end +package.preload['liblualongnumber'] = function() + -- i64 is never exercised here (the test types use i16/i32 only). 
+ return {new = function(_, v) return v or 0 end, tonumber = function(v) return v end} +end + +require('Thrift') +require('TTransport') +require('TMemoryBuffer') +require('TProtocol') +require('TBinaryProtocol') +require('RecursionDepth_ttypes') + +local LIMIT = DEFAULT_RECURSION_DEPTH + +local passed, failed = 0, 0 +local function ok(cond, name, detail) + if cond then + print('ok - ' .. name) + passed = passed + 1 + else + print('not ok - ' .. name .. (detail and (' # ' .. tostring(detail)) or '')) + failed = failed + 1 + end +end + +local function new_proto() + return TBinaryProtocol:new{trans = TMemoryBuffer:new{}} +end + +-- Build a chain of `depth` nested nodes; each inner node carries a single +-- child, the deepest carries the scalar leaf -- a valid shape for both the +-- struct and the union. +local function make_chain(class, leaf_field, depth) + local node = class:new{} + if depth > 1 then + node.children = {make_chain(class, leaf_field, depth - 1)} + else + node[leaf_field] = 1 + end + return node +end + +local function chain_depth(node) + local depth = 1 + if node.children and #node.children > 0 then + depth = depth + chain_depth(node.children[1]) + end + return depth +end + +-- Serialize an over-limit payload with raw protocol primitives so the reader +-- recurses through the guarded struct path (field id 1 = list), not the +-- separate (unbounded) skip() path. +local function write_deep(oprot, depth) + oprot:writeStructBegin('Rec') + if depth > 1 then + oprot:writeFieldBegin('children', TType.LIST, 1) + oprot:writeListBegin(TType.STRUCT, 1) + write_deep(oprot, depth - 1) + oprot:writeListEnd() + oprot:writeFieldEnd() + end + oprot:writeFieldStop() + oprot:writeStructEnd() +end + +-- True only for the depth-limit rejection -- an EOF/structural error would +-- carry a different type/message. 
+local function is_depth_error(err) + return type(err) == 'string' + and err:find('TProtocolException') ~= nil + and err:find('recursion') ~= nil +end + +local cases = { + {kind = 'struct', class = RecTree, leaf = 'item'}, + {kind = 'union', class = RecUnion, leaf = 'leaf'}, + {kind = 'exception', class = RecError, leaf = 'leaf'}, +} + +for _, case in ipairs(cases) do + local kind, class, leaf = case.kind, case.class, case.leaf + + -- 1. A chain exactly at the limit round-trips. Also proves the guard does + -- not double-count (a chain of 64 would be rejected at 32 otherwise). + do + local proto = new_proto() + local chain = make_chain(class, leaf, LIMIT) + local wok, werr = pcall(function() chain:write(proto) end) + ok(wok, kind .. ': writing a chain at the depth limit succeeds', werr) + + local decoded = class:new{} + local rok, rerr = pcall(function() decoded:read(proto) end) + ok(rok, kind .. ': reading a chain at the depth limit succeeds', rerr) + ok(chain_depth(decoded) == LIMIT, + kind .. ': round-trips to the original depth (' .. LIMIT .. ')') + end + + -- 2. Writing past the limit is rejected with a depth-limit error. + do + local proto = new_proto() + local chain = make_chain(class, leaf, LIMIT + 5) + local wok, werr = pcall(function() chain:write(proto) end) + ok(not wok, kind .. ': writing past the limit throws') + ok(is_depth_error(werr), kind .. ': ... with a recursion-depth error', werr) + end + + -- 3. Reading an over-limit payload is rejected with a depth-limit error. + do + local proto = new_proto() + write_deep(proto, LIMIT + 5) + local decoded = class:new{} + local rok, rerr = pcall(function() decoded:read(proto) end) + ok(not rok, kind .. ': reading past the limit throws') + ok(is_depth_error(rerr), kind .. ': ... with a recursion-depth error', rerr) + end +end + +print(string.format('\n%d passed, %d failed', passed, failed)) +if failed > 0 then os.exit(1) end