Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
36 changes: 28 additions & 8 deletions src/Shared/ProviderConfiguration.cs
Original file line number Diff line number Diff line change
Expand Up @@ -36,11 +36,11 @@ internal ProviderConfiguration()
internal static ProviderConfiguration ProviderConfigurationForSessionState(NameValueCollection config)
{
ProviderConfiguration configuration = new ProviderConfiguration(config);

configuration.ThrowOnError = GetBoolSettings(config, "throwOnError", true);
int retryTimeoutInMilliSec = GetIntSettings(config, "retryTimeoutInMilliseconds", 5000);
configuration.RetryTimeout = new TimeSpan(0, 0, 0, 0, retryTimeoutInMilliSec);

// Get request timeout from config
HttpRuntimeSection httpRuntimeSection = ConfigurationManager.GetSection("system.web/httpRuntime") as HttpRuntimeSection;
configuration.RequestTimeout = httpRuntimeSection.ExecutionTimeout;
Expand All @@ -57,10 +57,10 @@ internal static ProviderConfiguration ProviderConfigurationForSessionState(NameV
internal static ProviderConfiguration ProviderConfigurationForOutputCache(NameValueCollection config)
{
ProviderConfiguration configuration = new ProviderConfiguration(config);

// No retry login for output cache provider
configuration.RetryTimeout = TimeSpan.Zero;

// Session state specific attribute which are not applicable to output cache
configuration.ThrowOnError = true;
configuration.RequestTimeout = TimeSpan.Zero;
Expand All @@ -81,7 +81,7 @@ private ProviderConfiguration(NameValueCollection config)
Port = GetIntSettings(config, "port", 0);
AccessKey = GetStringSettings(config, "accessKey", null);
UseSsl = GetBoolSettings(config, "ssl", true);

// All below parameters are only fetched from web.config
DatabaseId = GetIntSettings(config, "databaseId", 0);
ApplicationName = GetStringSettings(config, "applicationName", null);
Expand Down Expand Up @@ -129,6 +129,12 @@ private static string GetStringSettings(NameValueCollection config, string attrN
return defaultVal;
}

string connectionStringValue = GetFromConnectionString(literalValue);
if (!string.IsNullOrEmpty(connectionStringValue))
{
return connectionStringValue;
}

string appSettingsValue = GetFromAppSetting(literalValue);
if (!string.IsNullOrEmpty(appSettingsValue))
{
Expand Down Expand Up @@ -205,6 +211,20 @@ private static string GetFromAppSetting(string attrName)
return null;
}

private static string GetFromConnectionString(string connectionStringName)
{
if (!string.IsNullOrEmpty(connectionStringName))
{
var connectionString = ConfigurationManager.ConnectionStrings[connectionStringName];

if (connectionString != null)
{
return connectionString.ConnectionString;
}
}
return null;
}

// Reads string value from web.config session state section
private static string GetFromConfig(NameValueCollection config, string attrName)
{
Expand All @@ -220,12 +240,12 @@ internal static void EnableLoggingIfParametersAvailable(NameValueCollection conf
{
string LoggingClassName = GetStringSettings(config, "loggingClassName", null);
string LoggingMethodName = GetStringSettings(config, "loggingMethodName", null);

if( !string.IsNullOrEmpty(LoggingClassName) && !string.IsNullOrEmpty(LoggingMethodName) )
{
// Find 'Type' that is same as fully qualified class name if not found than also don't throw error and ignore case while searching
Type LoggingClass = Type.GetType(LoggingClassName, throwOnError: false, ignoreCase: true);

if (LoggingClass == null)
{
// If class name is not assembly qualified name than look for class in all assemblies one by one
Expand Down Expand Up @@ -264,7 +284,7 @@ internal static Type GetLoggingClass(string LoggingClassName)
if (LoggingClass == null)
{
// If class name is not assembly qualified name and it also doesn't contain namespace (it is just class name) than
// try to use assembly name as namespace and try to load class from all assemblies one by one
// try to use assembly name as namespace and try to load class from all assemblies one by one
LoggingClass = a.GetType(a.GetName().Name + "." + LoggingClassName, throwOnError: false, ignoreCase: true);
}
if (LoggingClass != null)
Expand Down
91 changes: 59 additions & 32 deletions src/Shared/StackExchangeClientConnection.cs
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@

using System;
using System.Diagnostics;
using System.Net;
using System.Web.SessionState;
using StackExchange.Redis;

Expand All @@ -13,19 +14,24 @@ namespace Microsoft.Web.Redis
internal class StackExchangeClientConnection : IRedisClientConnection
{

ConnectionMultiplexer redisMultiplexer;
IDatabase connection;
ProviderConfiguration configuration;
ConnectionMultiplexer _redisMultiplexer;
IDatabase _connection;
ProviderConfiguration _configuration;

public StackExchangeClientConnection(ProviderConfiguration configuration)
{
this.configuration = configuration;
_configuration = configuration;
ConfigurationOptions configOption;

// If connection string is given then use it otherwise use individual options
if (!string.IsNullOrEmpty(configuration.ConnectionString))
{
configOption = ConfigurationOptions.Parse(configuration.ConnectionString);

if (!string.IsNullOrEmpty(configOption.ServiceName))
{
ModifyEndpointsForSentinelConfiguration(configOption);
}
}
else
{
Expand All @@ -52,42 +58,65 @@ public StackExchangeClientConnection(ProviderConfiguration configuration)
configOption.SyncTimeout = configuration.OperationTimeoutInMilliSec;
}
}
if (LogUtility.logger == null)

_redisMultiplexer = LogUtility.logger == null ? ConnectionMultiplexer.Connect(configOption) : ConnectionMultiplexer.Connect(configOption, LogUtility.logger);

_connection = _redisMultiplexer.GetDatabase(configuration.DatabaseId);
}

private static void ModifyEndpointsForSentinelConfiguration(ConfigurationOptions configOption)
{
var sentinelConfiguration = new ConfigurationOptions
{
redisMultiplexer = ConnectionMultiplexer.Connect(configOption);
}
else
CommandMap = CommandMap.Sentinel,
TieBreaker = "",
ServiceName = configOption.ServiceName,
SyncTimeout = configOption.SyncTimeout
};

EndPoint masterEndPoint = null;

foreach (var endpoint in configOption.EndPoints)
{
redisMultiplexer = ConnectionMultiplexer.Connect(configOption, LogUtility.logger);
sentinelConfiguration.EndPoints.Add(endpoint);
var sentinelConnection = ConnectionMultiplexer.Connect(sentinelConfiguration);
masterEndPoint = sentinelConnection.GetServer(endpoint).SentinelGetMasterAddressByName(sentinelConfiguration.ServiceName);

if (masterEndPoint != null)
{
break;
}
}
this.connection = redisMultiplexer.GetDatabase(configuration.DatabaseId);

configOption.EndPoints.Clear();
configOption.EndPoints.Add(masterEndPoint);
}

public IDatabase RealConnection
{
get { return connection; }
get { return _connection; }
}

public void Open()
{ }

public void Close()
{
redisMultiplexer.Close();
_redisMultiplexer.Close();
}

public bool Expiry(string key, int timeInSeconds)
{
TimeSpan timeSpan = new TimeSpan(0, 0, timeInSeconds);
RedisKey redisKey = key;
return (bool)RetryLogic(() => connection.KeyExpire(redisKey,timeSpan));
return (bool)RetryLogic(() => _connection.KeyExpire(redisKey,timeSpan));
}

public object Eval(string script, string[] keyArgs, object[] valueArgs)
{
RedisKey[] redisKeyArgs = new RedisKey[keyArgs.Length];
RedisValue[] redisValueArgs = new RedisValue[valueArgs.Length];

int i = 0;
foreach (string key in keyArgs)
{
Expand All @@ -110,7 +139,7 @@ public object Eval(string script, string[] keyArgs, object[] valueArgs)
}
i++;
}
return RetryLogic(() => connection.ScriptEvaluate(script, redisKeyArgs, redisValueArgs));
return RetryLogic(() => _connection.ScriptEvaluate(script, redisKeyArgs, redisValueArgs));
}

private object RetryForScriptNotFound(Func<object> redisOperation)
Expand Down Expand Up @@ -146,18 +175,16 @@ private object RetryLogic(Func<object> redisOperation)
catch (Exception)
{
TimeSpan passedTime = DateTime.Now - startTime;
if (configuration.RetryTimeout < passedTime)
if (_configuration.RetryTimeout < passedTime)
{
throw;
}
else

var remainingTimeout = (int)(_configuration.RetryTimeout.TotalMilliseconds - passedTime.TotalMilliseconds);
// if remaining time is less than 1 sec than wait only for that much time and than give a last try
if (remainingTimeout < timeToSleepBeforeRetryInMiliseconds)
{
int remainingTimeout = (int)(configuration.RetryTimeout.TotalMilliseconds - passedTime.TotalMilliseconds);
// if remaining time is less than 1 sec than wait only for that much time and than give a last try
if (remainingTimeout < timeToSleepBeforeRetryInMiliseconds)
{
timeToSleepBeforeRetryInMiliseconds = remainingTimeout;
}
timeToSleepBeforeRetryInMiliseconds = remainingTimeout;
}

// First time try after 20 msec after that try after 1 second
Expand All @@ -176,7 +203,7 @@ public int GetSessionTimeout(object rowDataFromRedis)
int sessionTimeout = (int)lockScriptReturnValueArray[2];
if (sessionTimeout == -1)
{
sessionTimeout = (int) configuration.SessionTimeout.TotalSeconds;
sessionTimeout = (int) _configuration.SessionTimeout.TotalSeconds;
}
// converting seconds to minutes
sessionTimeout = sessionTimeout / 60;
Expand All @@ -194,7 +221,7 @@ public bool IsLocked(object rowDataFromRedis)

public string GetLockId(object rowDataFromRedis)
{
return StackExchangeClientConnection.GetLockIdStatic(rowDataFromRedis);
return GetLockIdStatic(rowDataFromRedis);
}

internal static string GetLockIdStatic(object rowDataFromRedis)
Expand All @@ -207,7 +234,7 @@ internal static string GetLockIdStatic(object rowDataFromRedis)

public ISessionStateItemCollection GetSessionData(object rowDataFromRedis)
{
return StackExchangeClientConnection.GetSessionDataStatic(rowDataFromRedis);
return GetSessionDataStatic(rowDataFromRedis);
}

internal static ISessionStateItemCollection GetSessionDataStatic(object rowDataFromRedis)
Expand All @@ -220,7 +247,7 @@ internal static ISessionStateItemCollection GetSessionDataStatic(object rowDataF
if (lockScriptReturnValueArray.Length > 1 && lockScriptReturnValueArray[1] != null)
{
RedisResult[] data = (RedisResult[])lockScriptReturnValueArray[1];

// LUA script returns data as object array so keys and values are store one after another
// This list has to be even because it contains pair of <key, value> as {key, value, key, value}
if (data != null && data.Length != 0 && data.Length % 2 == 0)
Expand All @@ -247,23 +274,23 @@ public void Set(string key, byte[] data, DateTime utcExpiry)
RedisKey redisKey = key;
RedisValue redisValue = data;
TimeSpan timeSpanForExpiry = utcExpiry - DateTime.UtcNow;
connection.StringSet(redisKey, redisValue, timeSpanForExpiry);
_connection.StringSet(redisKey, redisValue, timeSpanForExpiry);
}

public byte[] Get(string key)
{
RedisKey redisKey = key;
RedisValue redisValue = connection.StringGet(redisKey);
return (byte[]) redisValue;
RedisValue redisValue = _connection.StringGet(redisKey);
return redisValue;
}

public void Remove(string key)
{
RedisKey redisKey = key;
connection.KeyDelete(redisKey);
_connection.KeyDelete(redisKey);
}

public byte[] GetOutputCacheDataFromResult(object rowDataFromRedis)
public byte[] GetOutputCacheDataFromResult(object rowDataFromRedis)
{
RedisResult rowDataAsRedisResult = (RedisResult)rowDataFromRedis;
return (byte[]) rowDataAsRedisResult;
Expand Down