Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
244 changes: 215 additions & 29 deletions rs/state_machine_tests/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,10 @@ use ic_crypto_test_utils_crypto_returning_ok::CryptoReturningOk;
use ic_crypto_test_utils_ni_dkg::{
SecretKeyBytes, dummy_initial_dkg_transcript_with_master_key, sign_message,
};
use ic_crypto_tree_hash::{Label, Path as LabeledTreePath, sparse_labeled_tree_from_paths};
use ic_crypto_tree_hash::{
Label, LabeledTree, MatchPatternPath, MixedHashTree, Path as LabeledTreePath,
sparse_labeled_tree_from_paths,
};
use ic_crypto_utils_threshold_sig_der::threshold_sig_public_key_to_der;
use ic_cycles_account_manager::{CyclesAccountManager, IngressInductionCost};
use ic_error_types::RejectCode;
Expand All @@ -46,9 +49,13 @@ use ic_interfaces::{
p2p::consensus::MutablePool,
validation::ValidationResult,
};
use ic_interfaces_certified_stream_store::{CertifiedStreamStore, EncodeStreamError};
use ic_interfaces_certified_stream_store::{
CertifiedStreamStore, DecodeStreamError, EncodeStreamError,
};
use ic_interfaces_registry::RegistryClient;
use ic_interfaces_state_manager::{CertificationScope, StateHashError, StateManager, StateReader};
use ic_interfaces_state_manager::{
CertificationScope, CertifiedStateSnapshot, Labeled, StateHashError, StateManager, StateReader,
};
use ic_limits::{MAX_INGRESS_TTL, PERMITTED_DRIFT, SMALL_APP_SUBNET_MAX_SIZE};
use ic_logger::replica_logger::no_op_logger;
use ic_logger::{ReplicaLogger, error};
Expand Down Expand Up @@ -166,8 +173,9 @@ use ic_types::{
},
nominal_cycles::NominalCycles,
signature::ThresholdSignature,
state_manager::StateManagerResult,
time::{GENESIS, Time},
xnet::{CertifiedStreamSlice, StreamIndex},
xnet::{CertifiedStreamSlice, StreamIndex, StreamSlice},
};
use ic_xnet_payload_builder::{
RefillTaskHandle, XNetPayloadBuilderImpl, XNetPayloadBuilderMetrics, XNetSlicePoolImpl,
Expand All @@ -185,6 +193,7 @@ use std::{
fmt,
io::{self, stderr},
net::Ipv6Addr,
ops::Deref,
path::{Path, PathBuf},
str::FromStr,
string::ToString,
Expand Down Expand Up @@ -896,6 +905,183 @@ enum SignatureSecretKey {
VetKD(ic_crypto_test_utils_vetkd::PrivateKey),
}

pub struct StateMachineStateManager {
inner: StateManagerImpl,
// This field must be the last one so that the temporary directory is deleted at the very end.
state_dir: Option<Box<dyn StateMachineStateDir>>,
// DO NOT PUT ANY FIELDS AFTER `state_dir`!!!
}

impl StateMachineStateManager {
fn state_dir_path(&self) -> PathBuf {
self.state_dir
.as_ref()
.expect("StateMachineStateManager uninitialized")
.path()
}

fn into_state_dir(mut self) -> Box<dyn StateMachineStateDir> {
self.state_dir
.take()
.expect("StateMachineStateManager uninitialized")
}
}

impl Drop for StateMachineStateManager {
fn drop(&mut self) {
// Finish any asynchronous state manager operations before dropping the state manager.
self.inner.flush_all();
}
}

impl Deref for StateMachineStateManager {
type Target = StateManagerImpl;

fn deref(&self) -> &Self::Target {
&self.inner
}
}

impl StateManager for StateMachineStateManager {
fn list_state_hashes_to_certify(&self) -> Vec<(Height, CryptoHashOfPartialState)> {
self.deref().list_state_hashes_to_certify()
}

fn deliver_state_certification(&self, certification: Certification) {
self.deref().deliver_state_certification(certification)
}

fn get_state_hash_at(&self, height: Height) -> Result<CryptoHashOfState, StateHashError> {
self.deref().get_state_hash_at(height)
}

fn fetch_state(
&self,
height: Height,
root_hash: CryptoHashOfState,
cup_interval_length: Height,
) {
self.deref()
.fetch_state(height, root_hash, cup_interval_length)
}

fn remove_states_below(&self, height: Height) {
self.deref().remove_states_below(height)
}

fn remove_inmemory_states_below(
&self,
height: Height,
extra_heights_to_keep: &BTreeSet<Height>,
) {
self.deref()
.remove_inmemory_states_below(height, extra_heights_to_keep)
}

fn commit_and_certify(
&self,
state: ReplicatedState,
height: Height,
scope: CertificationScope,
batch_summary: Option<BatchSummary>,
) {
self.deref()
.commit_and_certify(state, height, scope, batch_summary)
}

fn take_tip(&self) -> (Height, ReplicatedState) {
self.deref().take_tip()
}

fn take_tip_at(&self, height: Height) -> StateManagerResult<ReplicatedState> {
self.deref().take_tip_at(height)
}

fn report_diverged_checkpoint(&self, height: Height) {
self.deref().report_diverged_checkpoint(height)
}
}

impl CertifiedStreamStore for StateMachineStateManager {
fn encode_certified_stream_slice(
&self,
remote_subnet: SubnetId,
witness_begin: Option<StreamIndex>,
msg_begin: Option<StreamIndex>,
msg_limit: Option<usize>,
byte_limit: Option<usize>,
) -> Result<CertifiedStreamSlice, EncodeStreamError> {
self.deref().encode_certified_stream_slice(
remote_subnet,
witness_begin,
msg_begin,
msg_limit,
byte_limit,
)
}

fn decode_certified_stream_slice(
&self,
remote_subnet: SubnetId,
registry_version: RegistryVersion,
certified_slice: &CertifiedStreamSlice,
) -> Result<StreamSlice, DecodeStreamError> {
self.deref()
.decode_certified_stream_slice(remote_subnet, registry_version, certified_slice)
}

fn decode_valid_certified_stream_slice(
&self,
certified_slice: &CertifiedStreamSlice,
) -> Result<StreamSlice, DecodeStreamError> {
self.deref()
.decode_valid_certified_stream_slice(certified_slice)
}

fn subnets_with_certified_streams(&self) -> Vec<SubnetId> {
self.deref().subnets_with_certified_streams()
}
}

impl StateReader for StateMachineStateManager {
type State = ReplicatedState;

fn get_state_at(&self, height: Height) -> StateManagerResult<Labeled<Arc<ReplicatedState>>> {
self.deref().get_state_at(height)
}

fn get_latest_state(&self) -> Labeled<Arc<ReplicatedState>> {
self.deref().get_latest_state()
}

fn get_latest_certified_state(&self) -> Option<Labeled<Arc<ReplicatedState>>> {
self.deref().get_latest_certified_state()
}

fn latest_state_height(&self) -> Height {
self.deref().latest_state_height()
}

fn latest_certified_height(&self) -> Height {
self.deref().latest_certified_height()
}

fn read_certified_state_with_exclusion(
&self,
paths: &LabeledTree<()>,
exclusion: Option<&MatchPatternPath>,
) -> Option<(Arc<ReplicatedState>, MixedHashTree, Certification)> {
self.deref()
.read_certified_state_with_exclusion(paths, exclusion)
}

fn get_certified_state_snapshot(
&self,
) -> Option<Box<dyn CertifiedStateSnapshot<State = ReplicatedState> + 'static>> {
self.deref().get_certified_state_snapshot()
}
}

/// Represents a replicated state machine detached from the network layer that
/// can be used to test this part of the stack in isolation.
pub struct StateMachine {
Expand All @@ -909,7 +1095,7 @@ pub struct StateMachine {
is_vetkd_enabled: bool,
registry_data_provider: Arc<ProtoRegistryDataProvider>,
pub registry_client: Arc<FakeRegistryClient>,
pub state_manager: Arc<StateManagerImpl>,
pub state_manager: Arc<StateMachineStateManager>,
consensus_time: Arc<PocketConsensusTime>,
ingress_pool: Arc<RwLock<PocketIngressPool>>,
ingress_manager: Arc<IngressManager>,
Expand Down Expand Up @@ -952,9 +1138,6 @@ pub struct StateMachine {
remove_old_states: bool,
cycles_account_manager: Arc<CyclesAccountManager>,
cost_schedule: CanisterCyclesCostSchedule,
// This field must be the last one so that the temporary directory is deleted at the very end.
state_dir: Box<dyn StateMachineStateDir>,
// DO NOT PUT ANY FIELDS AFTER `state_dir`!!!
}

impl Default for StateMachine {
Expand All @@ -966,7 +1149,7 @@ impl Default for StateMachine {
impl fmt::Debug for StateMachine {
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
f.debug_struct("StateMachine")
.field("state_dir", &self.state_dir.path().display())
.field("state_dir", &self.state_manager.state_dir_path().display())
.field("nonce", &self.nonce.load(Ordering::Relaxed))
.finish()
}
Expand Down Expand Up @@ -1755,7 +1938,7 @@ impl StateMachine {
..Default::default()
};

let state_manager = Arc::new(StateManagerImpl::new(
let state_manager_impl = StateManagerImpl::new(
Arc::new(FakeVerifier),
subnet_id,
subnet_type,
Expand All @@ -1764,7 +1947,11 @@ impl StateMachine {
&sm_config,
None,
malicious_flags.clone(),
));
);
let state_manager = Arc::new(StateMachineStateManager {
inner: state_manager_impl,
state_dir: Some(state_dir),
});

if let Some(create_registry_version) = create_at_registry_version {
add_subnet_local_registry_records(
Expand Down Expand Up @@ -2054,7 +2241,6 @@ impl StateMachine {
_ingress_watcher_drop_guard: ingress_watcher_drop_guard,
certified_height_tx,
runtime,
state_dir,
// Note: state machine tests are commonly used for testing
// canisters, such tests usually don't rely on any persistence.
checkpoint_interval_length: checkpoint_interval_length.into(),
Expand All @@ -2080,9 +2266,8 @@ impl StateMachine {
}
}

fn into_components_inner(self) -> (Box<dyn StateMachineStateDir>, u64, Time, u64) {
fn into_components_inner(self) -> (u64, Time, u64) {
(
self.state_dir,
self.nonce.into_inner(),
Time::from_nanos_since_unix_epoch(self.time.into_inner()),
self.checkpoint_interval_length.load(Ordering::Relaxed),
Expand All @@ -2091,24 +2276,28 @@ impl StateMachine {

fn into_components(self) -> (Box<dyn StateMachineStateDir>, u64, Time, u64) {
// Finish any asynchronous state manager operations first.
self.state_manager.flush_tip_channel();
self.state_manager
.state_layout()
.flush_checkpoint_removal_channel();
self.state_manager.flush_all();

let state_manager = Arc::downgrade(&self.state_manager);
let result = self.into_components_inner();
let mut state_manager = self.state_manager.clone();
let (nonce, time, checkpoint_interval_length) = self.into_components_inner();
// StateManager is owned by an Arc, that is cloned into multiple components and different
// threads. If we return before all the asynchronous components release the Arc, we may
// end up with to StateManagers writing to the same directory, resulting in a crash.
let start = std::time::Instant::now();
while state_manager.upgrade().is_some() {
std::thread::sleep(std::time::Duration::from_millis(50));
let state_dir = loop {
match Arc::try_unwrap(state_manager) {
Ok(sm) => {
break sm.into_state_dir();
}
Err(sm) => {
state_manager = sm;
}
}
if start.elapsed() > std::time::Duration::from_secs(5 * 60) {
panic!("Timed out while dropping StateMachine.");
}
}
result
};
(state_dir, nonce, time, checkpoint_interval_length)
}

/// Safely drops this `StateMachine`. We cannot achieve this functionality by implementing `Drop`
Expand Down Expand Up @@ -2232,10 +2421,7 @@ impl StateMachine {
let checkpoint_interval_length = if enabled { 0 } else { u64::MAX };
self.set_checkpoint_interval_length(checkpoint_interval_length);
// Finish any asynchronous state manager operations.
self.state_manager.flush_tip_channel();
self.state_manager
.state_layout()
.flush_checkpoint_removal_channel();
self.state_manager.flush_all();
}

/// Set current interval length. The typical interval length
Expand Down Expand Up @@ -4710,7 +4896,7 @@ impl StateMachine {

/// Make sure the latest state is certified.
pub fn certify_latest_state_helper(
state_manager: Arc<StateManagerImpl>,
state_manager: Arc<StateMachineStateManager>,
secret_key: &SecretKeyBytes,
subnet_id: SubnetId,
) {
Expand Down
8 changes: 7 additions & 1 deletion rs/state_manager/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -885,7 +885,7 @@ impl Drop for StateManagerImpl {
// Make sure the tip thread didn't panic. Otherwise we may be blind to it in tests.
// If the tip thread panics after the latest communication with tip_channel the test returns
// success.
self.flush_tip_channel();
self.flush_all();
}
}

Expand Down Expand Up @@ -1224,6 +1224,12 @@ impl StateManagerImpl {
flush_tip_channel(&self.tip_channel)
}

/// Finish all asynchronous operations.
pub fn flush_all(&self) {
self.flush_tip_channel();
self.state_layout().flush_checkpoint_removal_channel();
}

/// Height for the initial default state.
const INITIAL_STATE_HEIGHT: Height = Height::new(0);

Expand Down