Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
134 changes: 80 additions & 54 deletions lisa/util/process.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
import shlex
import signal
import subprocess
import threading
import time
from dataclasses import dataclass
from pathlib import Path
Expand Down Expand Up @@ -211,6 +212,7 @@ def __init__(
self._result: Optional[ExecutableResult] = None
self._sudo: bool = False
self._nohup: bool = False
self._result_lock = threading.Lock()

# add a string stream handler to the logger
self.log_buffer = io.StringIO()
Expand Down Expand Up @@ -362,7 +364,27 @@ def wait_result(
timeout: float = 600,
expected_exit_code: Optional[int] = None,
expected_exit_code_failure_message: str = "",
raise_on_timeout: bool = True,
) -> ExecutableResult:
with self._result_lock:
return self._wait_result(
timeout,
expected_exit_code,
expected_exit_code_failure_message,
raise_on_timeout,
)

def _wait_result(
self,
timeout: float,
expected_exit_code: Optional[int],
expected_exit_code_failure_message: str,
raise_on_timeout: bool,
) -> ExecutableResult:
if self._result is not None:
if self._result.is_timeout and raise_on_timeout:
self._raise_timeout_exception(self._cmd, timeout)
return self._result
timer = create_timer()
is_timeout = False

Expand All @@ -375,77 +397,78 @@ def wait_result(
self.kill()
is_timeout = True

if self._result is None:
assert self._process
if is_timeout:
# LogWriter only flushes if "\n" is written, so we need to flush
# manually.
self._stdout_writer.flush()
self._stderr_writer.flush()
process_result = spur.results.result(
return_code=1,
allow_error=True,
output=self.log_buffer.getvalue(),
stderr_output="",
)
else:
process_result = self._process.wait_for_result()
if not self._is_posix and self._shell.is_remote:
# special handle remote windows. There are extra control chars
# and on extra line at the end.

# remove extra controls in remote Windows
process_result.output = filter_ansi_escape(process_result.output)
process_result.stderr_output = filter_ansi_escape(
process_result.stderr_output
)
assert self._process
if is_timeout:
# LogWriter only flushes if "\n" is written, so we need to flush
# manually.
self._stdout_writer.flush()
self._stderr_writer.flush()
process_result = spur.results.result(
return_code=1,
allow_error=True,
output=self.log_buffer.getvalue(),
stderr_output="",
)
else:
process_result = self._process.wait_for_result()

self._stdout_writer.close()
self._stderr_writer.close()
# cache for future queries, in case it's queried twice.
self._result = ExecutableResult(
process_result.output.strip(),
process_result.stderr_output.strip(),
process_result.return_code,
self._cmd,
self._timer.elapsed(),
is_timeout,
if not self._is_posix and self._shell.is_remote:
# special handle remote windows. There are extra control chars
# and on extra line at the end.

# remove extra controls in remote Windows
process_result.output = filter_ansi_escape(process_result.output)
process_result.stderr_output = filter_ansi_escape(
process_result.stderr_output
)

self._recycle_resource()
self._stdout_writer.close()
self._stderr_writer.close()

# cache for future queries, in case it's queried twice.
result = ExecutableResult(
process_result.output.strip(),
process_result.stderr_output.strip(),
process_result.return_code,
self._cmd,
self._timer.elapsed(),
is_timeout,
)

if not self._is_posix:
# convert windows error code to int4, so it's more friendly.
assert self._result.exit_code is not None
exit_code = self._result.exit_code
if exit_code > 2**31:
self._result.exit_code = exit_code - 2**32
self._recycle_resource()

self._log.debug(
f"execution time: {self._timer}, exit code: {self._result.exit_code}"
)
if not self._is_posix:
# convert windows error code to int4, so it's more friendly.
assert result.exit_code is not None
exit_code = result.exit_code
if exit_code > 2**31:
result.exit_code = exit_code - 2**32

self._log.debug(f"execution time: {self._timer}, exit code: {result.exit_code}")

if expected_exit_code is not None:
self._result.assert_exit_code(
result.assert_exit_code(
expected_exit_code=expected_exit_code,
message=expected_exit_code_failure_message,
)

if self._is_posix and self._sudo:
self._result.stdout = self._filter_sudo_result(self._result.stdout)
result.stdout = self._filter_sudo_result(result.stdout)

self._result.stdout = self._filter_profile_error(self._result.stdout)
self._result.stdout = self._filter_bash_prompt(self._result.stdout)
self._check_if_need_input_password(self._result.stdout)
self._result.stdout = self._filter_sudo_required_password_info(
self._result.stdout
)
result.stdout = self._filter_profile_error(result.stdout)
result.stdout = self._filter_bash_prompt(result.stdout)
self._check_if_need_input_password(result.stdout)
result.stdout = self._filter_sudo_required_password_info(result.stdout)

if not self._is_posix:
# fix windows ending with " by some unknown reason.
self._result.stdout = self._remove_ending_quote(self._result.stdout)
self._result.stderr = self._remove_ending_quote(self._result.stderr)
result.stdout = self._remove_ending_quote(result.stdout)
result.stderr = self._remove_ending_quote(result.stderr)

if is_timeout and raise_on_timeout:
self._raise_timeout_exception(self._cmd, timeout)

self._result = result
return self._result

def kill(self) -> None:
Expand Down Expand Up @@ -547,6 +570,9 @@ def _recycle_resource(self) -> None:
self._process._stderr.close()
self._process = None

def _raise_timeout_exception(self, cmdlet: List[str], timeout: float) -> None:
raise LisaException(f"command '{cmdlet}' timeout after {timeout} seconds.")

def _filter_sudo_result(self, raw_input: str) -> str:
# this warning message may break commands, so remove it from the first line
# of standard output.
Expand Down
111 changes: 111 additions & 0 deletions selftests/test_process.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,111 @@
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.

from unittest import TestCase
from unittest.mock import Mock

from assertpy import assert_that

from lisa.util import LisaException
from lisa.util.process import Process


class ProcessTestCase(TestCase):
def test_wait_result_caching_prevents_reprocessing(self) -> None:
from unittest.mock import MagicMock, patch

shell = Mock()
shell.is_posix = True
shell.is_remote = False

process = Process("test", shell)
process._cmd = ["echo", "test"]
process._timer = Mock()
process._stdout_writer = Mock()
process._stderr_writer = Mock()
process.log_buffer = MagicMock()

with (
patch.object(process, "_process") as mock_process,
patch.object(process, "is_running", side_effect=[True, False]),
patch.object(process, "_recycle_resource"),
):
process._timer.elapsed.return_value = 1.0
mock_process.wait_for_result.return_value = Mock(
output="test output",
stderr_output="",
return_code=0,
)

result1 = process.wait_result(timeout=10)
result2 = process.wait_result(timeout=10)

assert_that(result1).is_same_as(result2)
assert_that(mock_process.wait_for_result.call_count).is_equal_to(1)

def test_wait_result_timeout_with_raise_on_timeout_true(self) -> None:
from unittest.mock import MagicMock, patch

shell = Mock()
shell.is_posix = True
shell.is_remote = False

process = Process("test", shell)
process._cmd = ["sleep", "1"]
process._stdout_writer = Mock()
process._stderr_writer = Mock()
process.log_buffer = MagicMock()

# Mock both the local timer and the instance timer
mock_timer = Mock()
mock_timer.elapsed.side_effect = [5.0, 11.0, 11.0]
process._timer = mock_timer

with (
patch.object(process, "_process"),
patch.object(process, "is_running", side_effect=[True, False]),
patch.object(process, "kill"),
patch.object(process, "_recycle_resource"),
patch("lisa.util.process.time.sleep"), # Patch at import location
patch("lisa.util.process.create_timer", return_value=mock_timer),
):
process.log_buffer.getvalue.return_value = "partial output"

with self.assertRaises(LisaException) as context:
process.wait_result(timeout=10, raise_on_timeout=True)

assert_that(str(context.exception)).contains("timeout after 10 seconds")

def test_wait_result_timeout_with_raise_on_timeout_false(self) -> None:
from unittest.mock import MagicMock, patch

shell = Mock()
shell.is_posix = True
shell.is_remote = False

process = Process("test", shell)
process._cmd = ["sleep", "1"]
process._stdout_writer = Mock()
process._stderr_writer = Mock()
process.log_buffer = MagicMock()

# Mock both the local timer and the instance timer
mock_timer = Mock()
mock_timer.elapsed.side_effect = [5.0, 11.0, 11.0, 11.0]
process._timer = mock_timer

with (
patch.object(process, "_process"),
patch.object(process, "is_running", side_effect=[True, False]),
patch.object(process, "kill"),
patch.object(process, "_recycle_resource"),
patch("lisa.util.process.time.sleep"), # Patch at import location
patch("lisa.util.process.create_timer", return_value=mock_timer),
):
process.log_buffer.getvalue.return_value = "partial output"

result = process.wait_result(timeout=10, raise_on_timeout=False)

assert_that(result.is_timeout).is_true()
assert_that(result.stdout).is_equal_to("partial output")
assert_that(result.exit_code).is_equal_to(1)
Loading