diff --git a/snuba/web/rpc/common/common.py b/snuba/web/rpc/common/common.py index 37583c051c..0e4c2a14b2 100644 --- a/snuba/web/rpc/common/common.py +++ b/snuba/web/rpc/common/common.py @@ -15,6 +15,7 @@ from snuba.protos.common import ( MalformedAttributeException, type_array_to_membership_array_expression, + type_array_to_stored_array_json_path, ) from snuba.protos.common import ( attribute_key_to_expression as _attribute_key_to_expression, @@ -35,6 +36,7 @@ Argument, Expression, FunctionCall, + JsonPath, Lambda, SubscriptableReference, ) @@ -348,6 +350,15 @@ def _attribute_value_to_expression(v: AttributeValue) -> Expression: "val_double_array", } +_NUMBER_VALUE_TYPES = {"val_int", "val_float", "val_double"} + +_NUMERIC_COMPARISON_OP_TO_FN: dict[int, Callable[..., FunctionCall]] = { + ComparisonFilter.OP_LESS_THAN: f.less, + ComparisonFilter.OP_LESS_THAN_OR_EQUALS: f.lessOrEquals, + ComparisonFilter.OP_GREATER_THAN: f.greater, + ComparisonFilter.OP_GREATER_THAN_OR_EQUALS: f.greaterOrEquals, +} + def _validate_comparison_filter_type_array( op: ComparisonFilter.Op.ValueType, v: AttributeValue @@ -358,12 +369,17 @@ def _validate_comparison_filter_type_array( "LIKE/NOT_LIKE on array keys requires a string pattern" ) return + if op in _NUMERIC_COMPARISON_OP_TO_FN: + if v.WhichOneof("value") not in _NUMBER_VALUE_TYPES: + raise BadSnubaRPCRequestException( + f"{ComparisonFilter.Op.Name(op)} on array keys requires a numeric value" + "(val_int, val_float, val_double)" + ) + return if op in (ComparisonFilter.OP_EQUALS, ComparisonFilter.OP_NOT_EQUALS): - # Array can be empty or non-empty. It can never be null, or can never have null elements. vt = v.WhichOneof("value") + # Nested Arrays are not Supported, Array value in SET not implemented yet. if vt in ( - None, - "val_null", "val_array", "val_str_array", "val_int_array", @@ -377,7 +393,8 @@ def _validate_comparison_filter_type_array( return raise BadSnubaRPCRequestException( f"{ComparisonFilter.Op.Name(op)} is not supported on array keys " - "(supported: LIKE, NOT_LIKE, OP_EQUALS, OP_NOT_EQUALS)" + "(supported: OP_EQUALS, OP_NOT_EQUALS, OP_LIKE, OP_NOT_LIKE, " + "OP_LESS_THAN, OP_LESS_THAN_OR_EQUALS, OP_GREATER_THAN, OP_GREATER_THAN_OR_EQUALS)" ) @@ -419,6 +436,27 @@ def _type_array_includes_scalar_expression( return f.arrayExists(Lambda(None, ("x",), f.equals(x, rhs)), array_expr) +def _type_array_numeric_comparison_expression( + attr_key: AttributeKey, + op: ComparisonFilter.Op.ValueType, + v: AttributeValue, +) -> Expression: + """Per-element numeric comparison (>, <, >=, <=) for TYPE_ARRAY keys. + Non-numeric comparison yields NULL through coalesce and silently fail through predicate + """ + array_expr = type_array_to_stored_array_json_path(attr_key) + x = Argument(None, "x") + if v.WhichOneof("value") == "val_int": + element = JsonPath(None, x, "Int", "Nullable(Int64)") + else: + element = JsonPath(None, x, "Double", "Nullable(Float64)") + rhs = _attribute_value_to_expression(v) + return f.arrayExists( + Lambda(None, ("x",), _NUMERIC_COMPARISON_OP_TO_FN[op](element, rhs)), + array_expr, + ) + + def _any_attribute_filter_to_expression( filt: AnyAttributeFilter, ) -> Expression: @@ -677,12 +715,20 @@ def trace_item_filters_to_expression( expr_with_null = or_cond(expr, f.isNull(k_expression)) return expr_with_null if op == ComparisonFilter.OP_LESS_THAN: + if k.type == AttributeKey.Type.TYPE_ARRAY: + return _type_array_numeric_comparison_expression(k, op, v) return f.less(k_expression, v_expression) if op == ComparisonFilter.OP_LESS_THAN_OR_EQUALS: + if k.type == AttributeKey.Type.TYPE_ARRAY: + return _type_array_numeric_comparison_expression(k, op, v) return f.lessOrEquals(k_expression, v_expression) if op == ComparisonFilter.OP_GREATER_THAN: + if k.type == AttributeKey.Type.TYPE_ARRAY: + return _type_array_numeric_comparison_expression(k, op, v) return f.greater(k_expression, v_expression) if op == ComparisonFilter.OP_GREATER_THAN_OR_EQUALS: + if k.type == AttributeKey.Type.TYPE_ARRAY: + return _type_array_numeric_comparison_expression(k, op, v) return f.greaterOrEquals(k_expression, v_expression) if op == ComparisonFilter.OP_IN: _check_non_string_values_cannot_ignore_case(item_filter.comparison_filter) diff --git a/tests/web/rpc/test_common.py b/tests/web/rpc/test_common.py index ae18b6a5be..49af370af8 100644 --- a/tests/web/rpc/test_common.py +++ b/tests/web/rpc/test_common.py @@ -31,7 +31,7 @@ from snuba.datasets.storages.factory import get_storage from snuba.datasets.storages.storage_key import StorageKey from snuba.protos.common import ATTRIBUTES_TO_COALESCE -from snuba.query.expressions import FunctionCall, Lambda, Literal +from snuba.query.expressions import Argument, Expression, FunctionCall, JsonPath, Lambda, Literal from snuba.web.rpc.common.common import ( _any_attribute_filter_to_expression, attribute_key_to_expression, @@ -246,6 +246,152 @@ def test_not_like_on_int_key_raises(self) -> None: trace_item_filters_to_expression(item_filter, attribute_key_to_expression) +class TestTraceItemFiltersArrayNumericComparison: + """Per-element numeric comparison (>, <, >=, <=) for TYPE_ARRAY keys""" + + @staticmethod + def _make_comparison_filter( + attr_name: str, + attr_type: AttributeKey.Type.ValueType, + op: ComparisonFilter.Op.ValueType, + value: AttributeValue, + ) -> TraceItemFilter: + return TraceItemFilter( + comparison_filter=ComparisonFilter( + key=AttributeKey(type=attr_type, name=attr_name), + op=op, + value=value, + ) + ) + + @staticmethod + def _assert_array_numeric_comparison_shape( + expr: Expression, + expected_ch_fn: str, + expected_literal_value: float | int, + expected_attr_name: str, + ) -> None: + """arrayExists(x -> (x., ), attributes_array[name]) + + Element tag is derived from the literal's Python type: + int (not bool) -> x.Int (Nullable(Int64)); float -> x.Double (Nullable(Float64)). + """ + if isinstance(expected_literal_value, bool) or not isinstance(expected_literal_value, int): + expected_tag, expected_return_type = "Double", "Nullable(Float64)" + else: + expected_tag, expected_return_type = "Int", "Nullable(Int64)" + + assert isinstance(expr, FunctionCall) and expr.function_name == "arrayExists" + lam, array_expr = expr.parameters + assert isinstance(lam, Lambda) and lam.parameters == ("x",) + assert isinstance(array_expr, JsonPath) + assert (array_expr.path, array_expr.return_type) == (expected_attr_name, "Array(JSON)") + + comparison = lam.transformation + assert isinstance(comparison, FunctionCall) + assert comparison.function_name == expected_ch_fn + element, rhs = comparison.parameters + assert isinstance(element, JsonPath) + assert isinstance(element.base, Argument) and element.base.name == "x" + assert (element.path, element.return_type) == (expected_tag, expected_return_type) + assert isinstance(rhs, Literal) and rhs.value == expected_literal_value + + def test_greater_than_on_array_key_with_val_int(self) -> None: + item_filter = self._make_comparison_filter( + "my_nums", + AttributeKey.Type.TYPE_ARRAY, + ComparisonFilter.OP_GREATER_THAN, + AttributeValue(val_int=5), + ) + result = trace_item_filters_to_expression(item_filter, attribute_key_to_expression) + self._assert_array_numeric_comparison_shape(result, "greater", 5, "my_nums") + + def test_less_than_on_array_key_with_val_double(self) -> None: + item_filter = self._make_comparison_filter( + "my_nums", + AttributeKey.Type.TYPE_ARRAY, + ComparisonFilter.OP_LESS_THAN, + AttributeValue(val_double=1.5), + ) + result = trace_item_filters_to_expression(item_filter, attribute_key_to_expression) + self._assert_array_numeric_comparison_shape(result, "less", 1.5, "my_nums") + + def test_greater_or_equals_on_array_key_with_val_float(self) -> None: + item_filter = self._make_comparison_filter( + "my_nums", + AttributeKey.Type.TYPE_ARRAY, + ComparisonFilter.OP_GREATER_THAN_OR_EQUALS, + AttributeValue(val_float=2.0), + ) + result = trace_item_filters_to_expression(item_filter, attribute_key_to_expression) + self._assert_array_numeric_comparison_shape(result, "greaterOrEquals", 2.0, "my_nums") + + def test_less_or_equals_on_array_key_with_val_int(self) -> None: + item_filter = self._make_comparison_filter( + "my_nums", + AttributeKey.Type.TYPE_ARRAY, + ComparisonFilter.OP_LESS_THAN_OR_EQUALS, + AttributeValue(val_int=10), + ) + result = trace_item_filters_to_expression(item_filter, attribute_key_to_expression) + self._assert_array_numeric_comparison_shape(result, "lessOrEquals", 10, "my_nums") + + @pytest.mark.parametrize( + "op", + [ + ComparisonFilter.OP_LESS_THAN, + ComparisonFilter.OP_LESS_THAN_OR_EQUALS, + ComparisonFilter.OP_GREATER_THAN, + ComparisonFilter.OP_GREATER_THAN_OR_EQUALS, + ], + ) + def test_numeric_comparison_on_array_key_with_val_str_raises( + self, op: ComparisonFilter.Op.ValueType + ) -> None: + item_filter = self._make_comparison_filter( + "my_nums", + AttributeKey.Type.TYPE_ARRAY, + op, + AttributeValue(val_str="5"), + ) + with pytest.raises( + BadSnubaRPCRequestException, + match="on array keys requires a numeric value", + ): + trace_item_filters_to_expression(item_filter, attribute_key_to_expression) + + def test_numeric_comparison_on_array_key_with_val_bool_raises(self) -> None: + item_filter = self._make_comparison_filter( + "my_nums", + AttributeKey.Type.TYPE_ARRAY, + ComparisonFilter.OP_GREATER_THAN, + AttributeValue(val_bool=True), + ) + with pytest.raises( + BadSnubaRPCRequestException, + match="on array keys requires a numeric value", + ): + trace_item_filters_to_expression(item_filter, attribute_key_to_expression) + + def test_negated_greater_than_on_array_key_wraps_in_not(self) -> None: + """!array[*]:>5 composes as not_filter wrapping the comparison filter.""" + inner = self._make_comparison_filter( + "my_nums", + AttributeKey.Type.TYPE_ARRAY, + ComparisonFilter.OP_GREATER_THAN, + AttributeValue(val_int=5), + ) + from sentry_protos.snuba.v1.trace_item_filter_pb2 import NotFilter + + item_filter = TraceItemFilter(not_filter=NotFilter(filters=[inner])) + result = trace_item_filters_to_expression(item_filter, attribute_key_to_expression) + assert isinstance(result, FunctionCall) + assert result.function_name == "not" + wrapped = result.parameters[0] + assert isinstance(wrapped, FunctionCall) + self._assert_array_numeric_comparison_shape(wrapped, "greater", 5, "my_nums") + + class TestExistsFilterCoalesced: """exists_filter on coalesced attributes must check all deprecated keys.""" diff --git a/tests/web/rpc/v1/test_endpoint_trace_item_table/test_endpoint_trace_item_table.py b/tests/web/rpc/v1/test_endpoint_trace_item_table/test_endpoint_trace_item_table.py index db37dda848..082a631da4 100644 --- a/tests/web/rpc/v1/test_endpoint_trace_item_table/test_endpoint_trace_item_table.py +++ b/tests/web/rpc/v1/test_endpoint_trace_item_table/test_endpoint_trace_item_table.py @@ -3931,6 +3931,102 @@ def test_trace_item_table_array_op_equals_all_scalar_rhs_types( assert len(by_name[attr_name].results) == 1 assert check_row(by_name[attr_name].results[0]) + @pytest.mark.clickhouse_db + @pytest.mark.redis_db + def test_trace_item_table_array_op_greater_than_int(self) -> None: + """OP_GREATER_THAN on TYPE_ARRAY returns rows where some int element > 50.""" + span_ts = BASE_TIME - timedelta(minutes=1) + items_storage = get_storage(StorageKey("eap_items")) + write_raw_unprocessed_events( + items_storage, # type: ignore + [ + # matches: contains 200 (> 50) + gen_item_message(span_ts, attributes={"frame_linenos": _int_array(1, 9, 200)}), + # no match: all elements <= 50 (the lex-compare gotcha: "9" > "50" is true) + gen_item_message(span_ts, attributes={"frame_linenos": _int_array(9, 20)}), + # matches: 99 > 50 + gen_item_message(span_ts, attributes={"frame_linenos": _int_array(99)}), + ], + ) + message = TraceItemTableRequest( + meta=RequestMeta( + project_ids=[1, 2, 3], + organization_id=1, + cogs_category="something", + referrer="something", + start_timestamp=START_TIMESTAMP, + end_timestamp=END_TIMESTAMP, + trace_item_type=TraceItemType.TRACE_ITEM_TYPE_SPAN, + ), + filter=TraceItemFilter( + comparison_filter=ComparisonFilter( + key=AttributeKey(type=AttributeKey.TYPE_ARRAY, name="frame_linenos"), + op=ComparisonFilter.OP_GREATER_THAN, + value=AttributeValue(val_int=50), + ) + ), + columns=[ + Column(key=AttributeKey(type=AttributeKey.TYPE_STRING, name="sentry.item_id")), + Column(key=AttributeKey(type=AttributeKey.TYPE_ARRAY, name="frame_linenos")), + ], + ) + response = EndpointTraceItemTable().execute(message) + by_name = {cv.attribute_name: cv for cv in response.column_values} + assert len(by_name["frame_linenos"].results) == 2 + for row in by_name["frame_linenos"].results: + int_vals = [ + e.val_int for e in row.val_array.values if e.WhichOneof("value") == "val_int" + ] + assert any(v > 50 for v in int_vals) + + @pytest.mark.clickhouse_db + @pytest.mark.redis_db + def test_trace_item_table_array_op_less_than_or_equals_double(self) -> None: + """OP_LESS_THAN_OR_EQUALS on TYPE_ARRAY with a double RHS.""" + span_ts = BASE_TIME - timedelta(minutes=1) + items_storage = get_storage(StorageKey("eap_items")) + write_raw_unprocessed_events( + items_storage, # type: ignore + [ + # matches: 1.0 <= 1.5 + gen_item_message(span_ts, attributes={"measurements": _double_array(1.0, 9.9)}), + # no match: smallest element 2.0 > 1.5 + gen_item_message(span_ts, attributes={"measurements": _double_array(2.0, 3.0)}), + # matches: 0.5 <= 1.5 + gen_item_message(span_ts, attributes={"measurements": _double_array(0.5)}), + ], + ) + message = TraceItemTableRequest( + meta=RequestMeta( + project_ids=[1, 2, 3], + organization_id=1, + cogs_category="something", + referrer="something", + start_timestamp=START_TIMESTAMP, + end_timestamp=END_TIMESTAMP, + trace_item_type=TraceItemType.TRACE_ITEM_TYPE_SPAN, + ), + filter=TraceItemFilter( + comparison_filter=ComparisonFilter( + key=AttributeKey(type=AttributeKey.TYPE_ARRAY, name="measurements"), + op=ComparisonFilter.OP_LESS_THAN_OR_EQUALS, + value=AttributeValue(val_double=1.5), + ) + ), + columns=[ + Column(key=AttributeKey(type=AttributeKey.TYPE_STRING, name="sentry.item_id")), + Column(key=AttributeKey(type=AttributeKey.TYPE_ARRAY, name="measurements")), + ], + ) + response = EndpointTraceItemTable().execute(message) + by_name = {cv.attribute_name: cv for cv in response.column_values} + assert len(by_name["measurements"].results) == 2 + for row in by_name["measurements"].results: + double_vals = [ + e.val_double for e in row.val_array.values if e.WhichOneof("value") == "val_double" + ] + assert any(v <= 1.5 for v in double_vals) + class TestTraceItemTableArrayColumn(BaseApiTest): @pytest.mark.clickhouse_db