Skip to content

Commit 8ae13dc

Browse files
authored
Process: Raise exception on timeout. (#4077)
1 parent 030485c commit 8ae13dc

File tree

2 files changed

+191
-54
lines changed

2 files changed

+191
-54
lines changed

lisa/util/process.py

Lines changed: 80 additions & 54 deletions
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,7 @@
88
import shlex
99
import signal
1010
import subprocess
11+
import threading
1112
import time
1213
from dataclasses import dataclass
1314
from pathlib import Path
@@ -211,6 +212,7 @@ def __init__(
211212
self._result: Optional[ExecutableResult] = None
212213
self._sudo: bool = False
213214
self._nohup: bool = False
215+
self._result_lock = threading.Lock()
214216

215217
# add a string stream handler to the logger
216218
self.log_buffer = io.StringIO()
@@ -362,7 +364,27 @@ def wait_result(
362364
timeout: float = 600,
363365
expected_exit_code: Optional[int] = None,
364366
expected_exit_code_failure_message: str = "",
367+
raise_on_timeout: bool = True,
365368
) -> ExecutableResult:
369+
with self._result_lock:
370+
return self._wait_result(
371+
timeout,
372+
expected_exit_code,
373+
expected_exit_code_failure_message,
374+
raise_on_timeout,
375+
)
376+
377+
def _wait_result(
378+
self,
379+
timeout: float,
380+
expected_exit_code: Optional[int],
381+
expected_exit_code_failure_message: str,
382+
raise_on_timeout: bool,
383+
) -> ExecutableResult:
384+
if self._result is not None:
385+
if self._result.is_timeout and raise_on_timeout:
386+
self._raise_timeout_exception(self._cmd, timeout)
387+
return self._result
366388
timer = create_timer()
367389
is_timeout = False
368390

@@ -375,77 +397,78 @@ def wait_result(
375397
self.kill()
376398
is_timeout = True
377399

378-
if self._result is None:
379-
assert self._process
380-
if is_timeout:
381-
# LogWriter only flushes if "\n" is written, so we need to flush
382-
# manually.
383-
self._stdout_writer.flush()
384-
self._stderr_writer.flush()
385-
process_result = spur.results.result(
386-
return_code=1,
387-
allow_error=True,
388-
output=self.log_buffer.getvalue(),
389-
stderr_output="",
390-
)
391-
else:
392-
process_result = self._process.wait_for_result()
393-
if not self._is_posix and self._shell.is_remote:
394-
# special handle remote windows. There are extra control chars
395-
# and on extra line at the end.
396-
397-
# remove extra controls in remote Windows
398-
process_result.output = filter_ansi_escape(process_result.output)
399-
process_result.stderr_output = filter_ansi_escape(
400-
process_result.stderr_output
401-
)
400+
assert self._process
401+
if is_timeout:
402+
# LogWriter only flushes if "\n" is written, so we need to flush
403+
# manually.
404+
self._stdout_writer.flush()
405+
self._stderr_writer.flush()
406+
process_result = spur.results.result(
407+
return_code=1,
408+
allow_error=True,
409+
output=self.log_buffer.getvalue(),
410+
stderr_output="",
411+
)
412+
else:
413+
process_result = self._process.wait_for_result()
402414

403-
self._stdout_writer.close()
404-
self._stderr_writer.close()
405-
# cache for future queries, in case it's queried twice.
406-
self._result = ExecutableResult(
407-
process_result.output.strip(),
408-
process_result.stderr_output.strip(),
409-
process_result.return_code,
410-
self._cmd,
411-
self._timer.elapsed(),
412-
is_timeout,
415+
if not self._is_posix and self._shell.is_remote:
416+
# special handle remote windows. There are extra control chars
417+
# and on extra line at the end.
418+
419+
# remove extra controls in remote Windows
420+
process_result.output = filter_ansi_escape(process_result.output)
421+
process_result.stderr_output = filter_ansi_escape(
422+
process_result.stderr_output
413423
)
414424

415-
self._recycle_resource()
425+
self._stdout_writer.close()
426+
self._stderr_writer.close()
427+
428+
# cache for future queries, in case it's queried twice.
429+
result = ExecutableResult(
430+
process_result.output.strip(),
431+
process_result.stderr_output.strip(),
432+
process_result.return_code,
433+
self._cmd,
434+
self._timer.elapsed(),
435+
is_timeout,
436+
)
416437

417-
if not self._is_posix:
418-
# convert windows error code to int4, so it's more friendly.
419-
assert self._result.exit_code is not None
420-
exit_code = self._result.exit_code
421-
if exit_code > 2**31:
422-
self._result.exit_code = exit_code - 2**32
438+
self._recycle_resource()
423439

424-
self._log.debug(
425-
f"execution time: {self._timer}, exit code: {self._result.exit_code}"
426-
)
440+
if not self._is_posix:
441+
# convert windows error code to int4, so it's more friendly.
442+
assert result.exit_code is not None
443+
exit_code = result.exit_code
444+
if exit_code > 2**31:
445+
result.exit_code = exit_code - 2**32
446+
447+
self._log.debug(f"execution time: {self._timer}, exit code: {result.exit_code}")
427448

428449
if expected_exit_code is not None:
429-
self._result.assert_exit_code(
450+
result.assert_exit_code(
430451
expected_exit_code=expected_exit_code,
431452
message=expected_exit_code_failure_message,
432453
)
433454

434455
if self._is_posix and self._sudo:
435-
self._result.stdout = self._filter_sudo_result(self._result.stdout)
456+
result.stdout = self._filter_sudo_result(result.stdout)
436457

437-
self._result.stdout = self._filter_profile_error(self._result.stdout)
438-
self._result.stdout = self._filter_bash_prompt(self._result.stdout)
439-
self._check_if_need_input_password(self._result.stdout)
440-
self._result.stdout = self._filter_sudo_required_password_info(
441-
self._result.stdout
442-
)
458+
result.stdout = self._filter_profile_error(result.stdout)
459+
result.stdout = self._filter_bash_prompt(result.stdout)
460+
self._check_if_need_input_password(result.stdout)
461+
result.stdout = self._filter_sudo_required_password_info(result.stdout)
443462

444463
if not self._is_posix:
445464
# fix windows ending with " by some unknown reason.
446-
self._result.stdout = self._remove_ending_quote(self._result.stdout)
447-
self._result.stderr = self._remove_ending_quote(self._result.stderr)
465+
result.stdout = self._remove_ending_quote(result.stdout)
466+
result.stderr = self._remove_ending_quote(result.stderr)
467+
468+
if is_timeout and raise_on_timeout:
469+
self._raise_timeout_exception(self._cmd, timeout)
448470

471+
self._result = result
449472
return self._result
450473

451474
def kill(self) -> None:
@@ -547,6 +570,9 @@ def _recycle_resource(self) -> None:
547570
self._process._stderr.close()
548571
self._process = None
549572

573+
def _raise_timeout_exception(self, cmdlet: List[str], timeout: float) -> None:
574+
raise LisaException(f"command '{cmdlet}' timeout after {timeout} seconds.")
575+
550576
def _filter_sudo_result(self, raw_input: str) -> str:
551577
# this warning message may break commands, so remove it from the first line
552578
# of standard output.

selftests/test_process.py

Lines changed: 111 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,111 @@
1+
# Copyright (c) Microsoft Corporation.
2+
# Licensed under the MIT license.
3+
4+
from unittest import TestCase
5+
from unittest.mock import Mock
6+
7+
from assertpy import assert_that
8+
9+
from lisa.util import LisaException
10+
from lisa.util.process import Process
11+
12+
13+
class ProcessTestCase(TestCase):
14+
def test_wait_result_caching_prevents_reprocessing(self) -> None:
15+
from unittest.mock import MagicMock, patch
16+
17+
shell = Mock()
18+
shell.is_posix = True
19+
shell.is_remote = False
20+
21+
process = Process("test", shell)
22+
process._cmd = ["echo", "test"]
23+
process._timer = Mock()
24+
process._stdout_writer = Mock()
25+
process._stderr_writer = Mock()
26+
process.log_buffer = MagicMock()
27+
28+
with (
29+
patch.object(process, "_process") as mock_process,
30+
patch.object(process, "is_running", side_effect=[True, False]),
31+
patch.object(process, "_recycle_resource"),
32+
):
33+
process._timer.elapsed.return_value = 1.0
34+
mock_process.wait_for_result.return_value = Mock(
35+
output="test output",
36+
stderr_output="",
37+
return_code=0,
38+
)
39+
40+
result1 = process.wait_result(timeout=10)
41+
result2 = process.wait_result(timeout=10)
42+
43+
assert_that(result1).is_same_as(result2)
44+
assert_that(mock_process.wait_for_result.call_count).is_equal_to(1)
45+
46+
def test_wait_result_timeout_with_raise_on_timeout_true(self) -> None:
47+
from unittest.mock import MagicMock, patch
48+
49+
shell = Mock()
50+
shell.is_posix = True
51+
shell.is_remote = False
52+
53+
process = Process("test", shell)
54+
process._cmd = ["sleep", "1"]
55+
process._stdout_writer = Mock()
56+
process._stderr_writer = Mock()
57+
process.log_buffer = MagicMock()
58+
59+
# Mock both the local timer and the instance timer
60+
mock_timer = Mock()
61+
mock_timer.elapsed.side_effect = [5.0, 11.0, 11.0]
62+
process._timer = mock_timer
63+
64+
with (
65+
patch.object(process, "_process"),
66+
patch.object(process, "is_running", side_effect=[True, False]),
67+
patch.object(process, "kill"),
68+
patch.object(process, "_recycle_resource"),
69+
patch("lisa.util.process.time.sleep"), # Patch at import location
70+
patch("lisa.util.process.create_timer", return_value=mock_timer),
71+
):
72+
process.log_buffer.getvalue.return_value = "partial output"
73+
74+
with self.assertRaises(LisaException) as context:
75+
process.wait_result(timeout=10, raise_on_timeout=True)
76+
77+
assert_that(str(context.exception)).contains("timeout after 10 seconds")
78+
79+
def test_wait_result_timeout_with_raise_on_timeout_false(self) -> None:
80+
from unittest.mock import MagicMock, patch
81+
82+
shell = Mock()
83+
shell.is_posix = True
84+
shell.is_remote = False
85+
86+
process = Process("test", shell)
87+
process._cmd = ["sleep", "1"]
88+
process._stdout_writer = Mock()
89+
process._stderr_writer = Mock()
90+
process.log_buffer = MagicMock()
91+
92+
# Mock both the local timer and the instance timer
93+
mock_timer = Mock()
94+
mock_timer.elapsed.side_effect = [5.0, 11.0, 11.0, 11.0]
95+
process._timer = mock_timer
96+
97+
with (
98+
patch.object(process, "_process"),
99+
patch.object(process, "is_running", side_effect=[True, False]),
100+
patch.object(process, "kill"),
101+
patch.object(process, "_recycle_resource"),
102+
patch("lisa.util.process.time.sleep"), # Patch at import location
103+
patch("lisa.util.process.create_timer", return_value=mock_timer),
104+
):
105+
process.log_buffer.getvalue.return_value = "partial output"
106+
107+
result = process.wait_result(timeout=10, raise_on_timeout=False)
108+
109+
assert_that(result.is_timeout).is_true()
110+
assert_that(result.stdout).is_equal_to("partial output")
111+
assert_that(result.exit_code).is_equal_to(1)

0 commit comments

Comments
 (0)