Description
OpenVINO Version
2025.3
Operating System
Ubuntu 20.04 (LTS)
Device used for inference
CPU
Framework
ONNX
Issue description
When executing via onnxruntime-openvino on a CPU device, OpenVINO's output for the ReduceLogSumExp op differs from the onnxruntime CPU reference: OpenVINO returns Inf for input values >= 88.72284 (== log(MAX_FLOAT32)), whereas the onnxruntime CPU execution provider returns finite results for inputs up to MAX_FLOAT [1].
The behavior of the onnxruntime CPU provider matches that of a numerically stable LogSumExp implementation [1]: k = max(x1, ..., xn); lse = k + log(exp(x1 - k) + ... + exp(xn - k)).
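To make the difference between the two formulas concrete, here is a minimal NumPy sketch (float32, two equal inputs of 89). The helper names `lse_naive` and `lse_stable` are hypothetical and do not correspond to any onnxruntime or OpenVINO function.

```python
import numpy as np

def lse_naive(x):
    # Direct formula: exp() overflows as soon as any element exceeds log(max float).
    return np.log(np.sum(np.exp(x)))

def lse_stable(x):
    # Shift by the maximum so every exponent is <= 0 and exp() stays in (0, 1].
    k = np.max(x)
    return k + np.log(np.sum(np.exp(x - k)))

x = np.array([89.0, 89.0], dtype=np.float32)
with np.errstate(over="ignore"):
    print(lse_naive(x))  # inf, because exp(89) > float32 max
print(lse_stable(x))     # approximately 89 + log(2), i.e. about 89.693
```

Both functions are mathematically identical; only the order of operations changes, so the stable version costs one extra reduction (the max) per output.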
The OpenVINO behavior matches a 'naive' LogSumExp implementation, lse = log(exp(x1) + ... + exp(xn)). Since exp(88.72) is already close to MAX_FLOAT32, any larger input makes the elementwise exp() overflow, which leads to the Inf output.
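The overflow thresholds quoted above follow directly from the largest representable value of each type. A short NumPy check (a sketch, not taken from either runtime) computes log(max) for float16 and float32 and probes exp() just below and just above the float32 limit:

```python
import numpy as np

# Largest input for which exp() is still finite: log(max representable value).
limits = {np.dtype(t).name: float(np.log(float(np.finfo(t).max)))
          for t in (np.float16, np.float32)}
print(limits)  # roughly {'float16': 11.09, 'float32': 88.72}

# Just below and just above the float32 limit of about 88.72284.
with np.errstate(over="ignore"):
    below = np.exp(np.float32(88.72))
    above = np.exp(np.float32(88.73))
print(np.isfinite(below), np.isinf(above))  # True True
```

The float16 limit of about 11.09 is the reason the "unsafe numpy float16" column in the log below already goes to inf at 11.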
The behavior is the same for float16 and float32 models, probably because, as far as we understand, the OpenVINO CPU device always computes in float32 anyway.
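The float32 hypothesis can be checked with a small emulation, assuming float16 inputs are upcast to float32, reduced, and the result is cast back to float16. The helper `lse_fp16_via_fp32` is hypothetical and is not OpenVINO's actual kernel; it only reproduces the arithmetic that the hypothesis implies.

```python
import numpy as np

def lse_fp16_via_fp32(x_fp16, stable):
    # Hypothetical emulation: upcast to float32, reduce, cast the result back to float16.
    x32 = x_fp16.astype(np.float32)
    # stable=False uses k = 0, which turns the expression into the naive formula.
    k = np.max(x32) if stable else np.float32(0)
    return (k + np.log(np.sum(np.exp(x32 - k)))).astype(np.float16)

for v in (88.0, 89.0):
    x = np.full(2, v, dtype=np.float16)
    with np.errstate(over="ignore"):
        naive = lse_fp16_via_fp32(x, stable=False)
        stable = lse_fp16_via_fp32(x, stable=True)
    print(v, float(naive), float(stable))
# 88.0 88.6875 88.6875
# 89.0 inf 89.6875
```

The naive column matches the "ORT OpenVINO" values in the log below (finite at 88, inf at 89), while the stable column matches "ORT CPU", which suggests that computing in float32 alone does not avoid the overflow; the max shift does.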
It seems that this has already been addressed for the PyTorch frontend by #28887:
// for numerical stability and avoiding exponent explosion,
For the ONNX frontend, as far as we can tell, these are the places where a similar fix would be needed (we haven't attempted it, though):
Opset v1
ov::OutputVector reduce_log_sum_exp(const ov::frontend::onnx::Node& node) {
Opset v13
ov::OutputVector reduce_log_sum_exp(const ov::frontend::onnx::Node& node) {
Opset v18
ov::OutputVector reduce_log_sum_exp(const ov::frontend::onnx::Node& node) {
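For reference, the stable formula maps onto a chain of standard reductions: ReduceMax, Subtract, Exp, ReduceSum, Log, Add. The NumPy function below is a hypothetical sketch of that chain with ONNX-style `axes` and `keepdims`; it is not the PyTorch-frontend fix from #28887 and not OpenVINO code, and it ignores ONNX details such as `noop_with_empty_axes`.

```python
import numpy as np

def reduce_log_sum_exp_stable(data, axes=None, keepdims=True):
    """Hypothetical reference: ReduceMax -> Sub -> Exp -> ReduceSum -> Log -> Add."""
    axes = None if axes is None else tuple(axes)
    # Max over the reduced axes, kept as size-1 dims so it broadcasts back onto data.
    k = np.max(data, axis=axes, keepdims=True)
    # If a slice's max is +/-inf, shift by 0 instead so that inf - inf does not yield NaN.
    k = np.where(np.isfinite(k), k, np.zeros_like(k))
    shifted_sum = np.sum(np.exp(data - k), axis=axes, keepdims=True)
    out = k + np.log(shifted_sum)
    if not keepdims:
        out = np.squeeze(out, axis=axes)
    return out

x = np.array([[1000.0, 1000.0], [0.0, 0.0]], dtype=np.float32)
print(reduce_log_sum_exp_stable(x, axes=[1]))  # about [[1000.693], [0.693]], no inf
```

The `np.where` guard keeps an all -inf slice at -inf and an input containing +inf at +inf, which the plain max shift would otherwise turn into NaN.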
[1] Actually, it looks like the onnxruntime CPU provider only works correctly for an opset 12 ONNX model and does something different in opset 18, returning Inf for inputs above ~700. That is not an OpenVINO issue, of course. See the attached script.
Step-by-step reproduction
Below is a self-contained script that runs the ReduceLogSumExp op on a range of inputs using the OpenVINO EP, the CPU EP, and a few NumPy implementations for comparison.
"""
Tested with the following PyPI packages:
onnx==1.19.1
onnxruntime-openvino==1.23.0 # == OpenVINO 2025.3
"""
import tempfile
import onnx
import onnxruntime as ort
import numpy as np
import warnings
warnings.filterwarnings('ignore')
np.set_printoptions(suppress=True, linewidth=9000, threshold=9000)
INPUT_DIM = 2
DTYPE = 'float16'
COMPARE_WITH_NUMPY = True
def build_model_opset12():
    """ Tiny opset 12 ONNX model that just applies ReduceLogSumExp"""
    if DTYPE == 'float16':
        tp_dtype = onnx.TensorProto.FLOAT16
    elif DTYPE == 'float32':
        tp_dtype = onnx.TensorProto.FLOAT
    inp = onnx.helper.make_tensor_value_info("input", tp_dtype, [INPUT_DIM])
    out = onnx.helper.make_tensor_value_info("rlse_out", tp_dtype, [1])
    rls_node = onnx.helper.make_node(
        "ReduceLogSumExp",
        inputs=["input",],
        outputs=["rlse_out"],
        name="rlse",
        keepdims=1,
        axes=[0]
    )
    graph = onnx.helper.make_graph(
        nodes=[rls_node],
        name="rlse_graph",
        inputs=[inp],
        outputs=[out],
    )
    opset = onnx.helper.make_operatorsetid("", 12)
    model = onnx.helper.make_model(graph, opset_imports=[opset], ir_version=10)
    onnx.checker.check_model(model)
    output_names = ["rlse_out"]
    return model, output_names

def build_model_opset18():
    """ Tiny opset 18 ONNX model that just applies ReduceLogSumExp"""
    if DTYPE == 'float16':
        tp_dtype = onnx.TensorProto.FLOAT16
    elif DTYPE == 'float32':
        tp_dtype = onnx.TensorProto.FLOAT
    inp = onnx.helper.make_tensor_value_info("input", tp_dtype, [INPUT_DIM])
    out = onnx.helper.make_tensor_value_info("rlse_out", tp_dtype, [1])
    axes_init = onnx.helper.make_tensor(
        name="axes",
        data_type=onnx.TensorProto.INT64,
        dims=[1],
        vals=[0],
    )
    rls_node = onnx.helper.make_node(
        "ReduceLogSumExp",
        inputs=["input", "axes"],  # opset 18: axes provided as an input instead of kwarg
        outputs=["rlse_out"],
        name="rlse",
        keepdims=1,
    )
    graph = onnx.helper.make_graph(
        nodes=[rls_node],
        name="rlse_graph",
        inputs=[inp],
        outputs=[out],
        initializer=[axes_init],
    )
    opset = onnx.helper.make_operatorsetid("", 18)
    model = onnx.helper.make_model(graph, opset_imports=[opset], ir_version=10)
    onnx.checker.check_model(model)
    output_names = ["rlse_out"]
    return model, output_names

for build_model in [
    build_model_opset18,
    build_model_opset12
]:
    with tempfile.TemporaryDirectory() as tempdir:
        model, output_names = build_model()
        print("model opset", model.opset_import[0].version)
        print("=" * 50)
        # CPU ep session
        opts = ort.SessionOptions()
        ort_session_cpu = ort.InferenceSession(model.SerializeToString(),
                                               sess_options=opts,
                                               providers=["CPUExecutionProvider"])
        # OpenVINO ep session
        opts_ov = ort.SessionOptions()
        opts_ov.graph_optimization_level = ort.GraphOptimizationLevel.ORT_DISABLE_ALL
        openvino_config = {
            "device_type": "CPU",
            "cache_dir": tempdir
        }
        ort_session_openvino = ort.InferenceSession(model.SerializeToString(),
                                                    sess_options=opts_ov,
                                                    providers=["OpenVINOExecutionProvider"],
                                                    provider_options=[openvino_config]
                                                    )
        # Try increasingly large inputs
        for case in [10, 11, 12, "(...)",
                     87, 88, 89, 90, "(...)",
                     708.74999, 708.75, "(...)",
                     9999, "(...)",
                     np.finfo('float16').max,
                     # np.finfo('float32').max
                     ]:
            if isinstance(case, str):
                print(case, "\n\n")
                continue
            in_array = (np.ones((INPUT_DIM)) * case).astype(DTYPE)
            # Run it
            outputs_cpu = ort_session_cpu.run(output_names, {"input": in_array})[0]
            outputs_ov = ort_session_openvino.run(output_names, {"input": in_array})[0]
            print(f"ReduceLogSumExp({in_array}) = ")
            if DTYPE == 'float16' and COMPARE_WITH_NUMPY:
                # some comparisons with numpy:
                # - the naive (unsafe) numpy version goes Inf already at ~11, which
                #   makes sense considering 11.09 == np.log(np.finfo('float16').max),
                #   i.e. the elementwise exp() simply overflows.
                # - the behavior of onnxruntime CPU matches the safe float16 numpy version
                # - the behavior of onnxruntime OpenVINO matches the naive/unsafe float32
                #   version followed by a cast back to float16. It goes to inf for inputs
                #   around 89; consider 88.72284 == np.log(np.finfo('float32').max), so
                #   again it looks like the exp() overflows.
                maximum = np.max(in_array)
                rlse_safe_numpy = maximum + np.log(np.sum(np.exp(in_array - maximum)))
                assert rlse_safe_numpy.dtype == np.dtypes.Float16DType.type
                rlse_unsafe_numpy = np.log(np.sum(np.exp(in_array)))
                assert rlse_unsafe_numpy.dtype == np.dtypes.Float16DType.type
                rlse_unsafe_numpy_32 = np.log(np.sum(np.exp(in_array.astype('float32'))))
                assert rlse_unsafe_numpy_32.dtype == np.dtypes.Float32DType.type
                rlse_unsafe_numpy_32_cast_to_f16 = rlse_unsafe_numpy_32.astype('float16')
                print(f"  safe numpy float16: {rlse_safe_numpy}")
                print(f"  unsafe numpy float16: {rlse_unsafe_numpy}")
                print(f"  unsafe numpy float32: {rlse_unsafe_numpy_32_cast_to_f16}")
            print(f"  ORT CPU: {outputs_cpu[0]} {'!' if not np.isfinite(outputs_cpu[0]) else ''}")
            print(f"  ORT OpenVINO: {outputs_ov[0]} {'!' if not np.isfinite(outputs_ov[0]) else ''}")
            print('\n\n')
Relevant log output
# Output of the attached script
model opset 12
==================================================
ReduceLogSumExp([10. 10.]) =
safe numpy float16: 10.6953125
unsafe numpy float16: 10.6953125
unsafe numpy float32: 10.6953125
ORT CPU: 10.6953125
ORT OpenVINO: 10.6953125
ReduceLogSumExp([11. 11.]) =
safe numpy float16: 11.6953125
unsafe numpy float16: inf
unsafe numpy float32: 11.6953125
ORT CPU: 11.6953125
ORT OpenVINO: 11.6953125
ReduceLogSumExp([12. 12.]) =
safe numpy float16: 12.6953125
unsafe numpy float16: inf
unsafe numpy float32: 12.6953125
ORT CPU: 12.6953125
ORT OpenVINO: 12.6953125
(...)
ReduceLogSumExp([87. 87.]) =
safe numpy float16: 87.6875
unsafe numpy float16: inf
unsafe numpy float32: 87.6875
ORT CPU: 87.6875
ORT OpenVINO: 87.6875
ReduceLogSumExp([88. 88.]) =
safe numpy float16: 88.6875
unsafe numpy float16: inf
unsafe numpy float32: 88.6875
ORT CPU: 88.6875
ORT OpenVINO: 88.6875
ReduceLogSumExp([89. 89.]) =
safe numpy float16: 89.6875
unsafe numpy float16: inf
unsafe numpy float32: inf
ORT CPU: 89.6875
ORT OpenVINO: inf !
ReduceLogSumExp([90. 90.]) =
safe numpy float16: 90.6875
unsafe numpy float16: inf
unsafe numpy float32: inf
ORT CPU: 90.6875
ORT OpenVINO: inf !
(...)
ReduceLogSumExp([708.5 708.5]) =
safe numpy float16: 709.0
unsafe numpy float16: inf
unsafe numpy float32: inf
ORT CPU: 709.0
ORT OpenVINO: inf !
ReduceLogSumExp([709. 709.]) =
safe numpy float16: 709.5
unsafe numpy float16: inf
unsafe numpy float32: inf
ORT CPU: 709.5
ORT OpenVINO: inf !
(...)
ReduceLogSumExp([10000. 10000.]) =
safe numpy float16: 10000.0
unsafe numpy float16: inf
unsafe numpy float32: inf
ORT CPU: 10000.0
ORT OpenVINO: inf !
(...)
ReduceLogSumExp([65504. 65504.]) =
safe numpy float16: 65504.0
unsafe numpy float16: inf
unsafe numpy float32: inf
ORT CPU: 65504.0
ORT OpenVINO: inf !
Issue submission checklist
- I'm reporting an issue. It's not a question.
- I checked the problem with the documentation, FAQ, open issues, Stack Overflow, etc., and have not found a solution.
- There is reproducer code and related data files such as images, videos, models, etc.