Memoization
I spent a couple of hours writing some helper classes for memoizing function calls. Since I couldn't find anything as elegant on Google, I thought I'd post them here in case anyone else can use them.
Caveat: the code is not thread-safe. I didn't need thread safety, and getting it right is tricky, so I didn't spend the time on it. Do not use this code in multi-threaded scenarios!
First, the Memoizer class:
/// <summary>
/// Wraps functions so that repeated calls with equal arguments return a cached result
/// instead of invoking the underlying function again.
/// </summary>
/// <remarks>
/// Every call to a <c>Memoize</c> overload creates its own private <see cref="KeyMap{T}"/>.
/// Two distinct functions with the same signature therefore never share cached results.
/// The cache is a plain dictionary with no locking, so the returned delegates must not be
/// used from multiple threads.
/// </remarks>
public class Memoizer
{
    /// <summary>
    /// Memoizes a parameterless function: <paramref name="f"/> runs on the first
    /// successful call and its result is replayed afterwards.
    /// </summary>
    /// <typeparam name="T">The result type.</typeparam>
    /// <param name="f">The function to memoize.</param>
    /// <returns>A caching wrapper around <paramref name="f"/>.</returns>
    /// <exception cref="ArgumentNullException"><paramref name="f"/> is null.</exception>
    public static Func<T> Memoize<T>(Func<T> f)
    {
        if (f == null)
        {
            throw new ArgumentNullException("f");
        }

        // do NOT inline - it will create a new one on each subsequent call and thus not cache anything
        var map = new KeyMap<T>();
        return () => map.GetValue(new KeyStruct(), f);
    }

    /// <summary>
    /// Memoizes a one-argument function. Results are cached per argument value, using
    /// <see cref="KeyStruct"/> equality to compare arguments.
    /// </summary>
    /// <typeparam name="T1">The argument type.</typeparam>
    /// <typeparam name="T">The result type.</typeparam>
    /// <param name="f">The function to memoize.</param>
    /// <returns>A caching wrapper around <paramref name="f"/>.</returns>
    /// <exception cref="ArgumentNullException"><paramref name="f"/> is null.</exception>
    public static Func<T1, T> Memoize<T1, T>(Func<T1, T> f)
    {
        if (f == null)
        {
            throw new ArgumentNullException("f");
        }

        // do NOT inline - it will create a new one on each subsequent call and thus not cache anything
        var map = new KeyMap<T>();
        return arg => map.GetValue(new KeyStruct(arg), () => f(arg));
    }

    /// <summary>
    /// Memoizes a two-argument function. Results are cached per ordered argument pair,
    /// using <see cref="KeyStruct"/> equality to compare arguments.
    /// </summary>
    /// <typeparam name="T1">The first argument type.</typeparam>
    /// <typeparam name="T2">The second argument type.</typeparam>
    /// <typeparam name="T">The result type.</typeparam>
    /// <param name="f">The function to memoize.</param>
    /// <returns>A caching wrapper around <paramref name="f"/>.</returns>
    /// <exception cref="ArgumentNullException"><paramref name="f"/> is null.</exception>
    public static Func<T1, T2, T> Memoize<T1, T2, T>(Func<T1, T2, T> f)
    {
        if (f == null)
        {
            throw new ArgumentNullException("f");
        }

        // do NOT inline - it will create a new one on each subsequent call and thus not cache anything
        var map = new KeyMap<T>();
        return (a1, a2) => map.GetValue(new KeyStruct(a1, a2), () => f(a1, a2));
    }
}
It is tempting to make each method a one-liner, but as the comment points out, the map must not be created inside the lambda. Doing so would create a new, empty map on every call and disable caching entirely. Nor may a single map be created once and shared (for example in a constructor or a static field). Two distinct functions with the same signature would then share one cache: memoizing both Math.Sin and Math.Cos would be bad, even though both have the signature Func<double, double>.
The above class uses two others:
/// <summary>
/// A lazily populated cache from <see cref="KeyStruct"/> keys to computed values.
/// </summary>
/// <typeparam name="T">The type of the cached values.</typeparam>
public class KeyMap<T>
{
    // Results computed so far, indexed by the key they were computed for.
    private readonly Dictionary<KeyStruct, T> map = new Dictionary<KeyStruct, T>();

    /// <summary>
    /// Returns the value stored for <paramref name="key"/>. On a miss, invokes
    /// <paramref name="f"/>, stores its result under <paramref name="key"/> and returns it.
    /// If <paramref name="f"/> throws, nothing is stored.
    /// </summary>
    /// <param name="key">The lookup key.</param>
    /// <param name="f">Produces the value when it is not cached yet.</param>
    /// <returns>The cached or freshly computed value.</returns>
    public T GetValue(KeyStruct key, Func<T> f)
    {
        T cached;
        if (map.TryGetValue(key, out cached))
        {
            return cached;
        }

        var computed = f();
        map.Add(key, computed);
        return computed;
    }
}
and
/// <summary>
/// A composite dictionary key built from a list of argument values.
/// Two keys are equal when they hold the same number of arguments and each pair of
/// corresponding arguments is equal according to <see cref="object.Equals(object, object)"/>.
/// </summary>
/// <remarks>
/// Despite its name, this is a reference type. Null arguments are kept as null, so a null
/// argument is never equal to a non-null value such as 0.
/// </remarks>
public class KeyStruct : IEquatable<KeyStruct>
{
    /// <summary>Creates a key from the given argument values.</summary>
    /// <param name="args">The argument values; individual elements may be null.</param>
    /// <exception cref="ArgumentNullException"><paramref name="args"/> itself is null.</exception>
    public KeyStruct(params object[] args)
    {
        if (args == null)
        {
            throw new ArgumentNullException("args");
        }

        // Defensive copy: the caller may mutate its array after construction, which would
        // otherwise change this key's hash code while the key sits inside a dictionary.
        // Nulls are preserved. Previously they were replaced with 0, which made a null
        // argument collide with the argument 0.
        this.args = (object[])args.Clone();
    }

    /// <inheritdoc />
    public override bool Equals(object obj)
    {
        return Equals(obj as KeyStruct);
    }

    /// <summary>
    /// Compares the argument lists element by element; null elements are equal only to null.
    /// </summary>
    /// <param name="other">The key to compare with; may be null.</param>
    /// <returns><c>true</c> if both keys hold equal arguments in the same order.</returns>
    public bool Equals(KeyStruct other)
    {
        if (ReferenceEquals(other, null))
        {
            return false;
        }

        if (ReferenceEquals(this, other))
        {
            return true;
        }

        if (args.Length != other.args.Length)
        {
            return false;
        }

        for (var i = 0; i < args.Length; i++)
        {
            // Static object.Equals handles nulls and otherwise defers to args[i].Equals.
            if (!object.Equals(args[i], other.args[i]))
            {
                return false;
            }
        }

        return true;
    }

    /// <summary>
    /// Combines the arguments' hash codes (seed 17, multiplier 23). A null argument
    /// contributes 0, so an empty key hashes to 17 and a single 0 argument to 391.
    /// </summary>
    public override int GetHashCode()
    {
        unchecked
        {
            var hash = 17;
            foreach (var arg in args)
            {
                hash = hash * 23 + (arg == null ? 0 : arg.GetHashCode());
            }

            return hash;
        }
    }

    /// <summary>Value equality that tolerates null operands on either side.</summary>
    public static bool operator ==(KeyStruct obj1, KeyStruct obj2)
    {
        // ReferenceEquals avoids recursing into this operator and covers null == null.
        if (ReferenceEquals(obj1, obj2))
        {
            return true;
        }

        if (ReferenceEquals(obj1, null))
        {
            return false;
        }

        return obj1.Equals(obj2);
    }

    /// <summary>Negation of <see cref="op_Equality"/>.</summary>
    public static bool operator !=(KeyStruct obj1, KeyStruct obj2)
    {
        return !(obj1 == obj2);
    }

    // The key's argument values, in call order; elements may be null.
    private readonly object[] args;
}
The KeyStruct class exists so that I can use a dictionary even though a dictionary key must be a single value: it combines all of a call's arguments into one key object. (Despite its name, it is a class, not a struct.) For the hash calculation I adapted Jon Skeet's code here.
Since I developed these classes using TDD, here are the final tests:
/// <summary>
/// Unit tests for <see cref="KeyStruct"/>: hash-code values, equality semantics,
/// the equality operators and constructor argument validation.
/// </summary>
[TestClass]
public class KeyStructTests
{
    [TestMethod]
    public void EmptyListOfArgsReturnsHash17()
    {
        var sut = new KeyStruct(new object[0]);

        var result = sut.GetHashCode();

        Assert.AreEqual(17, result);
    }

    [TestMethod]
    public void TwoEmptyKeyStructsAreEqual()
    {
        var k1 = new KeyStruct();
        var k2 = new KeyStruct();

        var result = k1 == k2;

        Assert.IsTrue(result);
    }

    [TestMethod]
    public void OneArgEqualTo0ReturnsHash391()
    {
        var sut = new KeyStruct(new object[] { 0 });

        var result = sut.GetHashCode();

        Assert.AreEqual(391, result);
    }

    [TestMethod]
    public void OneArgEqualTo5ReturnsHash396()
    {
        var sut = new KeyStruct(new object[] { 5 });

        var result = sut.GetHashCode();

        Assert.AreEqual(396, result);
    }

    [TestMethod]
    public void EqualityReturnsTrueForEqualValues()
    {
        var s1 = new KeyStruct(new object[] { 1, "a", 3.5m });
        var s2 = new KeyStruct(new object[] { 1, "a", 3.5m });

        var result = s1 == s2;

        Assert.IsTrue(result);
    }

    [TestMethod]
    public void EqualityReturnsFalseForDifferentValues_1()
    {
        var s1 = new KeyStruct(new object[] { 1, "a", 3.5m });
        var s2 = new KeyStruct(new object[] { 2, "b", 4.5m });

        var result = s1 == s2;

        Assert.IsFalse(result);
    }

    [TestMethod]
    public void EqualityReturnsFalseForDifferentValues_2()
    {
        var s1 = new KeyStruct(new object[] { 1, "a", 3.5m });
        var s2 = new KeyStruct(new object[] { 1, "a" });

        var result = s1 == s2;

        Assert.IsFalse(result);
    }

    // The tests below cover public behavior that the original suite left untested.

    [TestMethod]
    public void EqualityReturnsFalseWhenArgumentOrderDiffers()
    {
        var s1 = new KeyStruct(new object[] { 1, 2 });
        var s2 = new KeyStruct(new object[] { 2, 1 });

        var result = s1 == s2;

        Assert.IsFalse(result);
    }

    [TestMethod]
    public void InequalityOperatorReturnsTrueForDifferentValues()
    {
        var s1 = new KeyStruct(new object[] { 1, "a" });
        var s2 = new KeyStruct(new object[] { 1, "b" });

        var result = s1 != s2;

        Assert.IsTrue(result);
    }

    [TestMethod]
    public void EqualKeysHaveEqualHashCodes()
    {
        var s1 = new KeyStruct(new object[] { 1, "a", 3.5m });
        var s2 = new KeyStruct(new object[] { 1, "a", 3.5m });

        Assert.AreEqual(s1.GetHashCode(), s2.GetHashCode());
    }

    [TestMethod]
    public void EqualsReturnsFalseForObjectThatIsNotAKeyStruct()
    {
        var sut = new KeyStruct(new object[] { 1 });

        var result = sut.Equals(1);

        Assert.IsFalse(result);
    }

    [TestMethod]
    public void ConstructorThrowsForNullArgsArray()
    {
        try
        {
            new KeyStruct((object[])null);
        }
        catch (ArgumentNullException)
        {
            return;
        }

        Assert.Fail("Expected ArgumentNullException for a null args array.");
    }
}
and
/// <summary>
/// Unit tests for <see cref="Memoizer"/>, grouped by the arity of the memoized function.
/// </summary>
public class MemoizerTests
{
    [TestClass]
    public class ZeroArgs
    {
        [TestMethod]
        public void CallsTheUnderlyingFuncOnFirstCall()
        {
            var called = false;
            Func<int> original = () =>
            {
                called = true;
                return 5;
            };
            var memoized = Memoizer.Memoize(original);

            var result = memoized.Invoke();

            Assert.AreEqual(5, result);
            Assert.IsTrue(called);
        }

        [TestMethod]
        public void ReturnsCachedValueForSubsequentCalls()
        {
            var called = 0;
            Func<int> original = () =>
            {
                called++;
                return 5;
            };
            var memoized = Memoizer.Memoize(original);

            memoized.Invoke();
            memoized.Invoke();
            var result = memoized.Invoke();

            Assert.AreEqual(5, result);
            Assert.AreEqual(1, called);
        }
    }

    [TestClass]
    public class OneArg
    {
        [TestMethod]
        public void CallsTheUnderlyingFuncOnFirstCall()
        {
            var called = false;
            Func<int, int> original = arg =>
            {
                called = true;
                return arg * 5;
            };
            var memoized = Memoizer.Memoize(original);

            var result = memoized.Invoke(7);

            Assert.AreEqual(35, result);
            Assert.IsTrue(called);
        }

        [TestMethod]
        public void ReturnsCachedValueForSubsequentCalls()
        {
            var called = 0;
            Func<int, int> original = arg =>
            {
                called++;
                return arg * 5;
            };
            var memoized = Memoizer.Memoize(original);

            memoized.Invoke(7);
            memoized.Invoke(7);
            var result = memoized.Invoke(7);

            Assert.AreEqual(35, result);
            Assert.AreEqual(1, called);
        }

        [TestMethod]
        public void CallsTheUnderlyingFuncWhenCalledWithDifferentArguments()
        {
            var called = new Dictionary<int, bool>();
            Func<int, int> original = arg =>
            {
                called[arg] = true;
                return arg * 5;
            };
            var memoized = Memoizer.Memoize(original);

            var result1 = memoized.Invoke(7);
            var result2 = memoized.Invoke(8);

            Assert.AreEqual(35, result1);
            Assert.AreEqual(40, result2);
            Assert.IsTrue(called[7]);
            Assert.IsTrue(called[8]);
        }

        [TestMethod]
        public void CallsTheUnderlyingFuncOncePerDistinctArgument()
        {
            var called = 0;
            Func<int, int> original = arg =>
            {
                called++;
                return arg * 5;
            };
            var memoized = Memoizer.Memoize(original);

            memoized.Invoke(7);
            memoized.Invoke(8);
            memoized.Invoke(7);
            memoized.Invoke(8);

            Assert.AreEqual(2, called);
        }

        [TestMethod]
        public void DistinctFunctionsWithTheSameSignatureDoNotShareACache()
        {
            // Each Memoize call must get its own cache (see the Sin/Cos remark in the article).
            Func<int, int> doubler = arg => arg * 2;
            Func<int, int> tripler = arg => arg * 3;
            var doubled = Memoizer.Memoize(doubler);
            var tripled = Memoizer.Memoize(tripler);

            var result1 = doubled.Invoke(5);
            var result2 = tripled.Invoke(5);

            Assert.AreEqual(10, result1);
            Assert.AreEqual(15, result2);
        }
    }

    [TestClass]
    public class TwoArgs
    {
        [TestMethod]
        public void CallsTheUnderlyingFuncOnFirstCall()
        {
            var called = false;
            Func<int, int, int> original = (a1, a2) =>
            {
                called = true;
                return a1 * a2 * 5;
            };
            var memoized = Memoizer.Memoize(original);

            var result = memoized.Invoke(7, 8);

            Assert.AreEqual(280, result);
            Assert.IsTrue(called);
        }

        [TestMethod]
        public void ReturnsCachedValueForSubsequentCalls()
        {
            var called = 0;
            Func<int, int, int> original = (a1, a2) =>
            {
                called++;
                return a1 * a2 * 5;
            };
            var memoized = Memoizer.Memoize(original);

            memoized.Invoke(7, 8);
            memoized.Invoke(7, 8);
            var result = memoized.Invoke(7, 8);

            Assert.AreEqual(280, result);
            Assert.AreEqual(1, called);
        }

        [TestMethod]
        public void CallsTheUnderlyingFuncWhenCalledWithDifferentArguments()
        {
            var called = new Dictionary<int, bool>();
            Func<int, int, int> original = (a1, a2) =>
            {
                // Encodes the single-digit pair (a1, a2) as a two-digit number.
                called[a1 * 10 + a2] = true;
                return a1 * a2 * 5;
            };
            var memoized = Memoizer.Memoize(original);

            var result1 = memoized.Invoke(7, 8);
            var result2 = memoized.Invoke(8, 7);

            Assert.AreEqual(280, result1);
            Assert.AreEqual(280, result2);
            Assert.IsTrue(called[78]);
            Assert.IsTrue(called[87]);
        }

        [TestMethod]
        public void CallsTheUnderlyingFuncOncePerOrderedArgumentPair()
        {
            var called = 0;
            Func<int, int, int> original = (a1, a2) =>
            {
                called++;
                return a1 * a2 * 5;
            };
            var memoized = Memoizer.Memoize(original);

            memoized.Invoke(7, 8);
            memoized.Invoke(8, 7);
            memoized.Invoke(7, 8);
            memoized.Invoke(8, 7);

            Assert.AreEqual(2, called);
        }
    }
}
As I said, I hope this helps someone. Please let me know if it does - or if you spot any problems.
Comments