Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
12 changes: 11 additions & 1 deletion compiler/cpp/src/thrift/generate/t_ocaml_generator.cc
Original file line number Diff line number Diff line change
Expand Up @@ -761,6 +761,9 @@ void t_ocaml_generator::generate_ocaml_struct_reader(ostream& out, t_struct* tst
indent_up();
indent(out) << "let " << str << " = new " << sname << " in" << '\n';
indent_up();
indent(out) << "iprot#increment_recursion_depth;" << '\n';
indent(out) << "(Fun.protect ~finally:(fun () -> iprot#decrement_recursion_depth) (fun () ->" << '\n';
indent_up();
indent(out) << "ignore(iprot#readStructBegin);" << '\n';

// Loop over reading in fields
Expand Down Expand Up @@ -803,7 +806,9 @@ void t_ocaml_generator::generate_ocaml_struct_reader(ostream& out, t_struct* tst
indent_down();
indent(out) << "with Break -> ());" << '\n';

indent(out) << "iprot#readStructEnd;" << '\n';
indent(out) << "iprot#readStructEnd" << '\n';
indent_down();
indent(out) << "));" << '\n';

indent(out) << str << '\n' << '\n';
indent_down();
Expand All @@ -819,6 +824,9 @@ void t_ocaml_generator::generate_ocaml_struct_writer(ostream& out, t_struct* tst

indent(out) << "method write (oprot : Protocol.t) =" << '\n';
indent_up();
indent(out) << "oprot#increment_recursion_depth;" << '\n';
indent(out) << "Fun.protect ~finally:(fun () -> oprot#decrement_recursion_depth) (fun () ->" << '\n';
indent_up();
indent(out) << "oprot#writeStructBegin \"" << name << "\";" << '\n';

for (f_iter = fields.begin(); f_iter != fields.end(); ++f_iter) {
Expand Down Expand Up @@ -878,6 +886,8 @@ void t_ocaml_generator::generate_ocaml_struct_writer(ostream& out, t_struct* tst

// Write the struct map
out << indent() << "oprot#writeFieldStop;" << '\n' << indent() << "oprot#writeStructEnd" << '\n';
indent_down();
indent(out) << ")" << '\n';

indent_down();
}
Expand Down
90 changes: 51 additions & 39 deletions compiler/cpp/tests/ocaml/snapshot_service_handle_ex.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -27,24 +27,30 @@ object (self)
let _new = Oo.copy self in
_new
method write (oprot : Protocol.t) =
oprot#writeStructBegin "ping_args";
oprot#writeFieldStop;
oprot#writeStructEnd
oprot#increment_recursion_depth;
Fun.protect ~finally:(fun () -> oprot#decrement_recursion_depth) (fun () ->
oprot#writeStructBegin "ping_args";
oprot#writeFieldStop;
oprot#writeStructEnd
)
end
let rec read_ping_args (iprot : Protocol.t) =
let _str2 = new ping_args in
ignore(iprot#readStructBegin);
(try while true do
let (_,_t3,_id4) = iprot#readFieldBegin in
if _t3 = Protocol.T_STOP then
raise Break
else ();
(match _id4 with
| _ -> iprot#skip _t3);
iprot#readFieldEnd;
done; ()
with Break -> ());
iprot#readStructEnd;
iprot#increment_recursion_depth;
(Fun.protect ~finally:(fun () -> iprot#decrement_recursion_depth) (fun () ->
ignore(iprot#readStructBegin);
(try while true do
let (_,_t3,_id4) = iprot#readFieldBegin in
if _t3 = Protocol.T_STOP then
raise Break
else ();
(match _id4 with
| _ -> iprot#skip _t3);
iprot#readFieldEnd;
done; ()
with Break -> ());
iprot#readStructEnd
));
_str2

class ping_result =
Expand All @@ -62,33 +68,39 @@ object (self)
_new#set_serverError self#grab_serverError#copy;
_new
method write (oprot : Protocol.t) =
oprot#writeStructBegin "ping_result";
(match _serverError with None -> () | Some _v ->
oprot#writeFieldBegin("serverError",Protocol.T_STRUCT,1);
_v#write(oprot);
oprot#writeFieldEnd
);
oprot#writeFieldStop;
oprot#writeStructEnd
oprot#increment_recursion_depth;
Fun.protect ~finally:(fun () -> oprot#decrement_recursion_depth) (fun () ->
oprot#writeStructBegin "ping_result";
(match _serverError with None -> () | Some _v ->
oprot#writeFieldBegin("serverError",Protocol.T_STRUCT,1);
_v#write(oprot);
oprot#writeFieldEnd
);
oprot#writeFieldStop;
oprot#writeStructEnd
)
end
let rec read_ping_result (iprot : Protocol.t) =
let _str8 = new ping_result in
ignore(iprot#readStructBegin);
(try while true do
let (_,_t9,_id10) = iprot#readFieldBegin in
if _t9 = Protocol.T_STOP then
raise Break
else ();
(match _id10 with
| 1 -> (if _t9 = Protocol.T_STRUCT then
_str8#set_serverError (Errors_types.read_serverError iprot)
else
iprot#skip _t9)
| _ -> iprot#skip _t9);
iprot#readFieldEnd;
done; ()
with Break -> ());
iprot#readStructEnd;
iprot#increment_recursion_depth;
(Fun.protect ~finally:(fun () -> iprot#decrement_recursion_depth) (fun () ->
ignore(iprot#readStructBegin);
(try while true do
let (_,_t9,_id10) = iprot#readFieldBegin in
if _t9 = Protocol.T_STOP then
raise Break
else ();
(match _id10 with
| 1 -> (if _t9 = Protocol.T_STRUCT then
_str8#set_serverError (Errors_types.read_serverError iprot)
else
iprot#skip _t9)
| _ -> iprot#skip _t9);
iprot#readFieldEnd;
done; ()
with Break -> ());
iprot#readStructEnd
));
_str8

class virtual iface =
Expand Down
33 changes: 21 additions & 12 deletions lib/ocaml/src/Thrift.ml
Original file line number Diff line number Diff line change
Expand Up @@ -156,10 +156,30 @@ struct
| 4 -> ONEWAY
| _ -> raise Thrift_error

type exn_type =
| UNKNOWN
| INVALID_DATA
| NEGATIVE_SIZE
| SIZE_LIMIT
| BAD_VERSION
| NOT_IMPLEMENTED
| DEPTH_LIMIT

exception E of exn_type * string;;

class virtual t (trans: Transport.t) =
object (self)
val mutable trans_ = trans
val mutable recursion_depth_ = 0
method getTransport = trans_
method increment_recursion_depth =
recursion_depth_ <- recursion_depth_ + 1;
if recursion_depth_ > 64 then begin
recursion_depth_ <- recursion_depth_ - 1;
raise (E (DEPTH_LIMIT, "Maximum recursion depth exceeded"))
end
method decrement_recursion_depth =
recursion_depth_ <- recursion_depth_ - 1
(* writing methods *)
method virtual writeMessageBegin : string * message_type * int -> unit
method virtual writeMessageEnd : unit
Expand Down Expand Up @@ -246,25 +266,14 @@ struct
self#readListEnd)
| T_UTF8 -> ()
| T_UTF16 -> ()
| _ -> raise (Protocol.E (Protocol.INVALID_DATA, "Invalid data"))
| _ -> raise (E (INVALID_DATA, "Invalid data"))
end

class virtual factory =
object
method virtual getProtocol : Transport.t -> t
end

type exn_type =
| UNKNOWN
| INVALID_DATA
| NEGATIVE_SIZE
| SIZE_LIMIT
| BAD_VERSION
| NOT_IMPLEMENTED
| DEPTH_LIMIT

exception E of exn_type * string;;

end;;


Expand Down
38 changes: 38 additions & 0 deletions lib/ocaml/test/Makefile
Original file line number Diff line number Diff line change
@@ -0,0 +1,38 @@
#
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
#

# Self-contained recursion-depth test (THRIFT-6051). It depends only on
# lib/ocaml/src/Thrift.ml (which holds the depth guard); it deliberately does
# not use TBinaryProtocol or generated code, neither of which compiles on
# modern OCaml. Build and run with: make
SRC = ../src
OCAMLC = ocamlc -thread unix.cma threads.cma -I $(SRC)

all: test

test: test_recursion_depth
./test_recursion_depth

test_recursion_depth: $(SRC)/Thrift.ml test_recursion_depth.ml
$(OCAMLC) $(SRC)/Thrift.ml test_recursion_depth.ml -o test_recursion_depth

clean:
rm -f *.cm* *.o test_recursion_depth $(SRC)/Thrift.cm* $(SRC)/Thrift.o

.PHONY: all test clean
Loading
Loading