[Bug]: ONNX frontend uses a numerically instable ReduceLogSumExp implementation #32839

@cknd

OpenVINO Version

2025.3

Operating System

Ubuntu 20.04 (LTS)

Device used for inference

CPU

Framework

ONNX

Issue description

When executing via onnxruntime-openvino on a CPU device, OpenVINO's output for the ReduceLogSumExp op differs from the onnxruntime CPU reference: OpenVINO returns Inf for input values >= 88.72284 (== log(MAX_FLOAT32)), whereas the onnxruntime CPU execution provider returns finite results for inputs up to MAX_FLOAT [1].

The behavior of the onnxruntime CPU provider matches that of a numerically stable LogSumExp implementation: k = max(x1, ..., xn); lse = k + log(exp(x1-k) + ... + exp(xn-k)) [1].

The OpenVINO behavior matches a 'naive' LogSumExp implementation (lse = log(exp(x1) + ... + exp(xn))): since exp(88.7) approaches MAX_FLOAT32, larger inputs overflow in exp() and lead to the Inf output.
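The two formulas above can be compared directly in NumPy. This is a minimal sketch in float32, not the code of either runtime; the helper names `lse_naive` and `lse_stable` are made up for illustration.

```python
import numpy as np

def lse_naive(x):
    # Direct formula: exp() overflows once an element exceeds log(max of the dtype).
    with np.errstate(over="ignore"):
        return np.log(np.sum(np.exp(x)))

def lse_stable(x):
    # Shift by the maximum so the largest term becomes exp(0) == 1.
    k = np.max(x)
    return k + np.log(np.sum(np.exp(x - k)))

x_ok = np.array([88.0, 88.0], dtype=np.float32)   # exp(88) still fits in float32
x_big = np.array([89.0, 89.0], dtype=np.float32)  # exp(89) overflows float32

print("naive :", lse_naive(x_ok), lse_naive(x_big))
print("stable:", lse_stable(x_ok), lse_stable(x_big))
```

For finite inputs the two versions agree until exp() overflows; past that point only the shifted version stays finite.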

The behavior is the same for float16 and float32 models, probably because, as far as we understand, the OpenVINO CPU device always runs in float32 anyway.

It seems that this has already been addressed for the PyTorch frontend by #28887:

// for numerical stability and avoiding exponent explosion,

For the ONNX frontend, as far as we can tell, this is the region where a similar fix would be needed (we haven't attempted this, though).

Opset v1

ov::OutputVector reduce_log_sum_exp(const ov::frontend::onnx::Node& node) {

Opset v13
ov::OutputVector reduce_log_sum_exp(const ov::frontend::onnx::Node& node) {

Opset v18
ov::OutputVector reduce_log_sum_exp(const ov::frontend::onnx::Node& node) {
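As a rough illustration of what a stable decomposition could look like when built only from primitive reduce and elementwise steps (ReduceMax, Subtract, Exp, ReduceSum, Log, Add), here is a NumPy sketch. It is not the OpenVINO frontend code and does not use the OpenVINO API; the function name and signature are hypothetical and only mirror the ONNX `axes` / `keepdims` semantics.

```python
import numpy as np

def reduce_log_sum_exp(data, axes=None, keepdims=True):
    """Stable ReduceLogSumExp over `axes` (hypothetical helper, NumPy only)."""
    axes = None if axes is None else tuple(axes)
    # ReduceMax with kept dims so the result broadcasts against `data`.
    k = np.max(data, axis=axes, keepdims=True)
    # Subtract, Exp, ReduceSum: every exponent is now <= 0, so exp() cannot overflow.
    summed = np.sum(np.exp(data - k), axis=axes, keepdims=True)
    # Log, Add: undo the shift.
    out = np.log(summed) + k
    if not keepdims:
        out = np.squeeze(out, axis=axes)
    return out

x = np.array([[10.0, 90.0], [1000.0, 5.0]], dtype=np.float32)
print(reduce_log_sum_exp(x, axes=[1]))
print(reduce_log_sum_exp(x, axes=[1], keepdims=False).shape)
```

The sketch assumes finite inputs; if the maximum along an axis is itself infinite, the shift would need an extra guard.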


[1] Actually, it looks like the onnxruntime CPU provider only works correctly for an ONNX opset 12 model and behaves differently in opset 18, returning Inf for inputs > ~700. That is not an OpenVINO issue, of course. See the attached script.
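For reference, the input value at which a bare exp() overflows depends only on the floating-point type, and can be computed as below. The helper name `overflow_threshold` is made up for this sketch. The ~700 cutoff seen with opset 18 happens to sit near the float64 value, but whether that kernel actually computes a naive float64 LogSumExp is an assumption that nothing in this report confirms.

```python
import numpy as np

def overflow_threshold(dtype_name):
    # Largest x for which exp(x) is still representable in the given dtype.
    return float(np.log(np.float64(np.finfo(dtype_name).max)))

for name in ("float16", "float32", "float64"):
    print(f"{name}: exp(x) overflows for x > {overflow_threshold(name):.5f}")
```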

Step-by-step reproduction

Below is a self-contained script that runs the ReduceLogSumExp op on a range of inputs, using the OpenVINO EP, the CPU EP, and some NumPy implementations for comparison.

"""
Tested with the following PyPI packages:
onnx==1.19.1
onnxruntime-openvino==1.23.0   # == OpenVINO 2025.3
"""

import tempfile
import onnx
import onnxruntime as ort
import numpy as np
import warnings
warnings.filterwarnings('ignore')

np.set_printoptions(suppress=True, linewidth=9000, threshold=9000)

INPUT_DIM = 2
DTYPE = 'float16'
COMPARE_WITH_NUMPY = True

def build_model_opset12():
    """ Tiny opset 12 ONNX model that just applies ReduceLogSumExp"""
    if DTYPE == 'float16':
        tp_dtype = onnx.TensorProto.FLOAT16
    elif DTYPE == 'float32':
        tp_dtype = onnx.TensorProto.FLOAT
    else:
        raise ValueError(f"unsupported DTYPE: {DTYPE}")
    inp = onnx.helper.make_tensor_value_info("input", tp_dtype, [INPUT_DIM])
    out = onnx.helper.make_tensor_value_info("rlse_out", tp_dtype, [1])

    rls_node = onnx.helper.make_node(
        "ReduceLogSumExp",
        inputs=["input",],
        outputs=["rlse_out"],
        name="rlse",
        keepdims=1,
        axes=[0]
    )

    graph = onnx.helper.make_graph(
        nodes=[rls_node],
        name="rlse_graph",
        inputs=[inp],
        outputs=[out],
    )

    opset = onnx.helper.make_operatorsetid("", 12)

    model = onnx.helper.make_model(graph, opset_imports=[opset], ir_version=10)
    onnx.checker.check_model(model)

    output_names = ["rlse_out"]
    return model, output_names


def build_model_opset18():
    """ Tiny opset 18 ONNX model that just applies ReduceLogSumExp"""
    if DTYPE == 'float16':
        tp_dtype = onnx.TensorProto.FLOAT16
    elif DTYPE == 'float32':
        tp_dtype = onnx.TensorProto.FLOAT
    else:
        raise ValueError(f"unsupported DTYPE: {DTYPE}")
    inp = onnx.helper.make_tensor_value_info("input", tp_dtype, [INPUT_DIM])
    out = onnx.helper.make_tensor_value_info("rlse_out", tp_dtype, [1])

    axes_init = onnx.helper.make_tensor(
        name="axes",
        data_type=onnx.TensorProto.INT64,
        dims=[1],
        vals=[0],
    )

    rls_node = onnx.helper.make_node(
        "ReduceLogSumExp",
        inputs=["input", "axes"],  # opset 18: axes provided as an input instead of kwarg
        outputs=["rlse_out"],
        name="rlse",
        keepdims=1,
    )

    graph = onnx.helper.make_graph(
        nodes=[rls_node],
        name="rlse_graph",
        inputs=[inp],
        outputs=[out],
        initializer=[axes_init],
    )

    opset = onnx.helper.make_operatorsetid("", 18)

    model = onnx.helper.make_model(graph, opset_imports=[opset], ir_version=10)
    onnx.checker.check_model(model)

    output_names = ["rlse_out"]
    return model, output_names



for build_model in [
        build_model_opset18,
        build_model_opset12
    ]:
    with tempfile.TemporaryDirectory() as tempdir:
        model, output_names = build_model()
        print("model opset", model.opset_import[0].version)
        print("=" * 50)

        # CPU ep session
        opts = ort.SessionOptions()
        ort_session_cpu = ort.InferenceSession(model.SerializeToString(),
                                    sess_options=opts,
                                    providers=["CPUExecutionProvider"])

        # OpenVINO ep session
        opts_ov = ort.SessionOptions()
        opts_ov.graph_optimization_level = ort.GraphOptimizationLevel.ORT_DISABLE_ALL
        openvino_config = {
                "device_type": "CPU",
                "cache_dir" : tempdir
            }
        ort_session_openvino = ort.InferenceSession(model.SerializeToString(),
                                        sess_options=opts_ov,
                                        providers=["OpenVINOExecutionProvider"],
                                        provider_options=[openvino_config]
                                        )

        # Try increasingly large inputs
        for case in [10,11,12, "(...)",
                     87,88,89,90, "(...)",
                     708.74999, 708.75, "(...)",
                     9999, "(...)",
                     np.finfo('float16').max,
                    #  np.finfo('float32').max
                     ]:


            if isinstance(case, str):
                print(case, "\n\n")
                continue

            in_array = (np.ones((INPUT_DIM)) * case).astype(DTYPE)

            # Run it
            outputs_cpu = ort_session_cpu.run(output_names, {"input":in_array})[0]
            outputs_ov = ort_session_openvino.run(output_names, {"input":in_array})[0]
            print(f"ReduceLogSumExp({in_array}) = ")

            if DTYPE == 'float16' and COMPARE_WITH_NUMPY:
                # some comparisons with numpy:
                # - the naive (unsafe) numpy version goes Inf already at ~11, which
                #   makes sense considering 11.09 == np.log(np.finfo('float16').max),
                #   i.e. the elementwise exp() simply overflows.
                # - the behavior of onnxruntime CPU matches the safe float16 numpy version
                # - the behavior of onnxruntime OpenVINO matches the naive/unsafe float32
                #   version followed by a cast back to float16. It goes to inf for inputs
                #   around 89; consider 88.72284 == np.log(np.finfo('float32').max), so
                #   again it looks like the exp() overflows.
                # safe version: shift by the maximum before exponentiating
                maximum = np.max(in_array)
                rlse_safe_numpy = maximum + np.log(np.sum(np.exp(in_array - maximum)))
                assert rlse_safe_numpy.dtype == np.float16

                # naive version, computed entirely in float16
                rlse_unsafe_numpy = np.log(np.sum(np.exp(in_array)))
                assert rlse_unsafe_numpy.dtype == np.float16

                # naive version, computed in float32 and cast back to float16
                rlse_unsafe_numpy_32 = np.log(np.sum(np.exp(in_array.astype('float32'))))
                assert rlse_unsafe_numpy_32.dtype == np.float32
                rlse_unsafe_numpy_32_cast_to_f16 = rlse_unsafe_numpy_32.astype('float16')
                print(f"    safe numpy float16:    {rlse_safe_numpy}")
                print(f"    unsafe numpy float16:  {rlse_unsafe_numpy}")
                print(f"    unsafe numpy float32:  {rlse_unsafe_numpy_32_cast_to_f16}")
            print(f"    ORT CPU:               {outputs_cpu[0]}     {'!' if not np.isfinite(outputs_cpu[0]) else ''}")
            print(f"    ORT OpenVINO:          {outputs_ov[0]}     {'!' if not np.isfinite(outputs_ov[0]) else ''}")
            print('\n\n')

Relevant log output

# Output of the attached script


model opset 12
==================================================
ReduceLogSumExp([10. 10.]) = 
    safe numpy float16:    10.6953125
    unsafe numpy float16:  10.6953125
    unsafe numpy float32:  10.6953125
    ORT CPU:               10.6953125     
    ORT OpenVINO:          10.6953125     



ReduceLogSumExp([11. 11.]) = 
    safe numpy float16:    11.6953125
    unsafe numpy float16:  inf
    unsafe numpy float32:  11.6953125
    ORT CPU:               11.6953125     
    ORT OpenVINO:          11.6953125     



ReduceLogSumExp([12. 12.]) = 
    safe numpy float16:    12.6953125
    unsafe numpy float16:  inf
    unsafe numpy float32:  12.6953125
    ORT CPU:               12.6953125     
    ORT OpenVINO:          12.6953125     



(...) 


ReduceLogSumExp([87. 87.]) = 
    safe numpy float16:    87.6875
    unsafe numpy float16:  inf
    unsafe numpy float32:  87.6875
    ORT CPU:               87.6875     
    ORT OpenVINO:          87.6875     



ReduceLogSumExp([88. 88.]) = 
    safe numpy float16:    88.6875
    unsafe numpy float16:  inf
    unsafe numpy float32:  88.6875
    ORT CPU:               88.6875     
    ORT OpenVINO:          88.6875     



ReduceLogSumExp([89. 89.]) = 
    safe numpy float16:    89.6875
    unsafe numpy float16:  inf
    unsafe numpy float32:  inf
    ORT CPU:               89.6875     
    ORT OpenVINO:          inf     !



ReduceLogSumExp([90. 90.]) = 
    safe numpy float16:    90.6875
    unsafe numpy float16:  inf
    unsafe numpy float32:  inf
    ORT CPU:               90.6875     
    ORT OpenVINO:          inf     !



(...) 


ReduceLogSumExp([708.5 708.5]) = 
    safe numpy float16:    709.0
    unsafe numpy float16:  inf
    unsafe numpy float32:  inf
    ORT CPU:               709.0     
    ORT OpenVINO:          inf     !



ReduceLogSumExp([709. 709.]) = 
    safe numpy float16:    709.5
    unsafe numpy float16:  inf
    unsafe numpy float32:  inf
    ORT CPU:               709.5     
    ORT OpenVINO:          inf     !



(...) 


ReduceLogSumExp([10000. 10000.]) = 
    safe numpy float16:    10000.0
    unsafe numpy float16:  inf
    unsafe numpy float32:  inf
    ORT CPU:               10000.0     
    ORT OpenVINO:          inf     !



(...) 


ReduceLogSumExp([65504. 65504.]) = 
    safe numpy float16:    65504.0
    unsafe numpy float16:  inf
    unsafe numpy float32:  inf
    ORT CPU:               65504.0     
    ORT OpenVINO:          inf     !

Issue submission checklist

  • I'm reporting an issue. It's not a question.
  • I checked the problem with the documentation, FAQ, open issues, Stack Overflow, etc., and have not found a solution.
  • There is reproducer code and related data files such as images, videos, models, etc.
