Memoization
I've spent a couple of hours writing a few helper classes for memoizing function calls. Since I couldn't find anything as elegant on Google, I thought I'd post them here in case anyone else can use them.
Caveat: the code is not thread-safe. I didn't need thread safety, and it is tricky to get right, so I didn't spend the time on it. Do not use this code from multiple threads!
First, the `Memoizer` class:
/// <summary>
/// Wraps functions of zero, one or two arguments so that each distinct combination of
/// argument values is computed only once; later calls with equal arguments return the
/// cached result instead of invoking the original function again.
/// </summary>
/// <remarks>
/// Not thread-safe: each cache is a plain, unsynchronized map.
/// </remarks>
public class Memoizer
{
    /// <summary>Memoizes a parameterless function.</summary>
    /// <typeparam name="T">The function's result type.</typeparam>
    /// <param name="f">The function to memoize.</param>
    /// <returns>
    /// A delegate that invokes <paramref name="f"/> on its first call and returns the
    /// cached result on every later call.
    /// </returns>
    /// <exception cref="ArgumentNullException"><paramref name="f"/> is null.</exception>
    public static Func<T> Memoize<T>(Func<T> f)
    {
        // Validate eagerly; otherwise a null f only surfaces later as a
        // NullReferenceException thrown from inside the returned delegate.
        if (f == null)
        {
            throw new ArgumentNullException("f");
        }

        // Do NOT inline the map into the lambda: a new map would be created on every call
        // and nothing would ever be cached. One map per Memoize call also keeps distinct
        // functions with the same signature from sharing a cache.
        var map = new KeyMap<T>();
        return () => map.GetValue(new KeyStruct(), f);
    }

    /// <summary>Memoizes a one-argument function, caching one result per distinct argument.</summary>
    /// <typeparam name="T1">The argument type.</typeparam>
    /// <typeparam name="T">The function's result type.</typeparam>
    /// <param name="f">The function to memoize.</param>
    /// <returns>
    /// A delegate that invokes <paramref name="f"/> the first time it sees a given argument
    /// and returns the cached result for that argument afterwards.
    /// </returns>
    /// <exception cref="ArgumentNullException"><paramref name="f"/> is null.</exception>
    public static Func<T1, T> Memoize<T1, T>(Func<T1, T> f)
    {
        if (f == null)
        {
            throw new ArgumentNullException("f");
        }

        // Do NOT inline - see the parameterless overload.
        var map = new KeyMap<T>();
        return arg => map.GetValue(new KeyStruct(arg), () => f(arg));
    }

    /// <summary>Memoizes a two-argument function, caching one result per distinct argument pair.</summary>
    /// <typeparam name="T1">The first argument type.</typeparam>
    /// <typeparam name="T2">The second argument type.</typeparam>
    /// <typeparam name="T">The function's result type.</typeparam>
    /// <param name="f">The function to memoize.</param>
    /// <returns>
    /// A delegate that invokes <paramref name="f"/> the first time it sees a given
    /// (a1, a2) pair and returns the cached result afterwards. Arguments are keyed by
    /// position, so (x, y) and (y, x) are cached separately.
    /// </returns>
    /// <exception cref="ArgumentNullException"><paramref name="f"/> is null.</exception>
    public static Func<T1, T2, T> Memoize<T1, T2, T>(Func<T1, T2, T> f)
    {
        if (f == null)
        {
            throw new ArgumentNullException("f");
        }

        // Do NOT inline - see the parameterless overload.
        var map = new KeyMap<T>();
        return (a1, a2) => map.GetValue(new KeyStruct(a1, a2), () => f(a1, a2));
    }
}
As the comment points out, it is tempting to make each method a one-liner, but the map creation must not be inlined into the lambda: a new map would then be created on every call, disabling any caching. Nor can a single map be created once (for example, in a constructor) and shared, because two distinct functions that share the same signature must not use the same cache. Caching `Math.Sin` and `Math.Cos` in one map would be bad, even though both have the signature `Func<double, double>`: the memoized `Math.Cos` could return the value cached for `Math.Sin` with the same argument.
The class above uses two others:
/// <summary>
/// A lazily populated cache from <c>KeyStruct</c> keys to computed values; each
/// <c>Memoizer.Memoize</c> call owns one instance.
/// </summary>
/// <typeparam name="T">The type of the cached values.</typeparam>
/// <remarks>Not thread-safe: backed by an unsynchronized <c>Dictionary</c>.</remarks>
public class KeyMap<T>
{
    // One entry per key whose value has been computed successfully.
    private readonly Dictionary<KeyStruct, T> map = new Dictionary<KeyStruct, T>();

    /// <summary>
    /// Returns the value cached for <paramref name="key"/>, calling <paramref name="f"/>
    /// to compute and store it the first time the key is requested.
    /// </summary>
    /// <param name="key">The lookup key.</param>
    /// <param name="f">Produces the value when <paramref name="key"/> is not cached yet.</param>
    /// <returns>The cached or freshly computed value.</returns>
    /// <remarks>
    /// If <paramref name="f"/> throws, nothing is stored, so the next request for the
    /// same key calls it again.
    /// </remarks>
    public T GetValue(KeyStruct key, Func<T> f)
    {
        T cached;
        if (map.TryGetValue(key, out cached))
        {
            return cached;
        }

        T computed = f();
        map.Add(key, computed);
        return computed;
    }
}
and
/// <summary>
/// An immutable composite dictionary key built from an ordered list of arguments
/// (despite its name, a reference type). Two keys are equal when they hold the same
/// number of arguments and every pair of corresponding arguments is equal.
/// </summary>
/// <remarks>
/// Arguments are compared with <see cref="object.Equals(object, object)"/>, so null only
/// equals null, and argument types that do not override <c>Equals</c> (e.g. arrays)
/// only match the very same instance.
/// </remarks>
public class KeyStruct
{
    // Private copy of the constructor arguments; never mutated after construction.
    private readonly object[] args;

    /// <summary>Creates a key from the given arguments, in order.</summary>
    /// <param name="args">The arguments that make up the key; individual elements may be null.</param>
    /// <exception cref="ArgumentNullException"><paramref name="args"/> itself is null.</exception>
    public KeyStruct(params object[] args)
    {
        if (args == null)
        {
            throw new ArgumentNullException("args");
        }

        // Copy so later writes to the caller's array cannot change this key's hash or equality.
        // Nulls are kept as-is rather than replaced by a placeholder value, so a null argument
        // never compares equal to a real value such as a boxed 0 (e.g. Func<int?, T> with null vs 0).
        this.args = args.ToArray();
    }

    /// <summary>
    /// Returns true when <paramref name="obj"/> is a <c>KeyStruct</c> with the same number
    /// of arguments, each equal to the corresponding argument of this key.
    /// </summary>
    public override bool Equals(object obj)
    {
        var other = obj as KeyStruct;

        // ReferenceEquals, not ==, because == is overloaded below in terms of Equals.
        if (ReferenceEquals(other, null))
        {
            return false;
        }

        if (ReferenceEquals(this, other))
        {
            return true;
        }

        if (args.Length != other.args.Length)
        {
            return false;
        }

        for (var i = 0; i < args.Length; i++)
        {
            // Static object.Equals is null-safe: true for two nulls, false when only one is null.
            if (!object.Equals(args[i], other.args[i]))
            {
                return false;
            }
        }

        return true;
    }

    /// <summary>
    /// Combines the argument hashes (seed 17, multiply by 23 and add); a null argument
    /// contributes 0. An empty key hashes to 17.
    /// </summary>
    public override int GetHashCode()
    {
        unchecked
        {
            var hash = 17;
            foreach (var arg in args)
            {
                hash = hash * 23 + (arg == null ? 0 : arg.GetHashCode());
            }

            return hash;
        }
    }

    /// <summary>Value equality; null-safe on both operands.</summary>
    public static bool operator ==(KeyStruct obj1, KeyStruct obj2)
    {
        if (ReferenceEquals(obj1, obj2))
        {
            return true;
        }

        // Guard the left operand so a null KeyStruct does not throw NullReferenceException.
        if (ReferenceEquals(obj1, null))
        {
            return false;
        }

        return obj1.Equals(obj2);
    }

    /// <summary>Negation of <c>==</c>.</summary>
    public static bool operator !=(KeyStruct obj1, KeyStruct obj2)
    {
        return !(obj1 == obj2);
    }
}
The `KeyStruct` class exists only so that I can use a dictionary even though a dictionary key must be a single value: it bundles all of a call's arguments into one key object. For the hash calculation I adapted Jon Skeet's code here.
Since I developed these classes using TDD, here are the final tests:
/// <summary>Unit tests for <c>KeyStruct</c> hashing and equality.</summary>
[TestClass]
public class KeyStructTests
{
    [TestMethod]
    public void EmptyListOfArgsReturnsHash17()
    {
        var key = new KeyStruct(new object[0]);

        Assert.AreEqual(17, key.GetHashCode());
    }

    [TestMethod]
    public void TwoEmptyKeyStructsAreEqual()
    {
        var first = new KeyStruct();
        var second = new KeyStruct();

        Assert.IsTrue(first == second);
    }

    [TestMethod]
    public void OneArgEqualTo0ReturnsHash391()
    {
        // 17 * 23 + hash(0)
        var key = new KeyStruct(new object[] { 0 });

        Assert.AreEqual(391, key.GetHashCode());
    }

    [TestMethod]
    public void OneArgEqualTo5ReturnsHash396()
    {
        // 17 * 23 + hash(5)
        var key = new KeyStruct(new object[] { 5 });

        Assert.AreEqual(396, key.GetHashCode());
    }

    [TestMethod]
    public void EqualityReturnsTrueForEqualValues()
    {
        var left = new KeyStruct(new object[] { 1, "a", 3.5m });
        var right = new KeyStruct(new object[] { 1, "a", 3.5m });

        Assert.IsTrue(left == right);
    }

    [TestMethod]
    public void EqualityReturnsFalseForDifferentValues_1()
    {
        var left = new KeyStruct(new object[] { 1, "a", 3.5m });
        var right = new KeyStruct(new object[] { 2, "b", 4.5m });

        Assert.IsFalse(left == right);
    }

    [TestMethod]
    public void EqualityReturnsFalseForDifferentValues_2()
    {
        // Same leading values, but a different number of arguments.
        var left = new KeyStruct(new object[] { 1, "a", 3.5m });
        var right = new KeyStruct(new object[] { 1, "a" });

        Assert.IsFalse(left == right);
    }
}
and
/// <summary>
/// Unit tests for <c>Memoizer.Memoize</c>, grouped by the arity of the wrapped function.
/// </summary>
public class MemoizerTests
{
    /// <summary>Memoizing a parameterless function.</summary>
    [TestClass]
    public class ZeroArgs
    {
        [TestMethod]
        public void CallsTheUnderlyingFuncOnFirstCall()
        {
            var invoked = false;
            Func<int> source = () =>
            {
                invoked = true;
                return 5;
            };
            var cached = Memoizer.Memoize(source);

            var value = cached();

            Assert.AreEqual(5, value);
            Assert.IsTrue(invoked);
        }

        [TestMethod]
        public void ReturnsCachedValueForSubsequentCalls()
        {
            var invocations = 0;
            Func<int> source = () =>
            {
                invocations++;
                return 5;
            };
            var cached = Memoizer.Memoize(source);

            cached();
            cached();
            var value = cached();

            Assert.AreEqual(5, value);
            Assert.AreEqual(1, invocations);
        }
    }

    /// <summary>Memoizing a one-argument function.</summary>
    [TestClass]
    public class OneArg
    {
        [TestMethod]
        public void CallsTheUnderlyingFuncOnFirstCall()
        {
            var invoked = false;
            Func<int, int> source = x =>
            {
                invoked = true;
                return x * 5;
            };
            var cached = Memoizer.Memoize(source);

            var value = cached(7);

            Assert.AreEqual(35, value);
            Assert.IsTrue(invoked);
        }

        [TestMethod]
        public void ReturnsCachedValueForSubsequentCalls()
        {
            var invocations = 0;
            Func<int, int> source = x =>
            {
                invocations++;
                return x * 5;
            };
            var cached = Memoizer.Memoize(source);

            cached(7);
            cached(7);
            var value = cached(7);

            Assert.AreEqual(35, value);
            Assert.AreEqual(1, invocations);
        }

        [TestMethod]
        public void CallsTheUnderlyingFuncWhenCalledWithDifferentArguments()
        {
            // Records which arguments actually reached the underlying function.
            var seen = new Dictionary<int, bool>();
            Func<int, int> source = x =>
            {
                seen[x] = true;
                return x * 5;
            };
            var cached = Memoizer.Memoize(source);

            var forSeven = cached(7);
            var forEight = cached(8);

            Assert.AreEqual(35, forSeven);
            Assert.AreEqual(40, forEight);
            Assert.IsTrue(seen[7]);
            Assert.IsTrue(seen[8]);
        }
    }

    /// <summary>Memoizing a two-argument function.</summary>
    [TestClass]
    public class TwoArgs
    {
        [TestMethod]
        public void CallsTheUnderlyingFuncOnFirstCall()
        {
            var invoked = false;
            Func<int, int, int> source = (x, y) =>
            {
                invoked = true;
                return x * y * 5;
            };
            var cached = Memoizer.Memoize(source);

            var value = cached(7, 8);

            Assert.AreEqual(280, value);
            Assert.IsTrue(invoked);
        }

        [TestMethod]
        public void ReturnsCachedValueForSubsequentCalls()
        {
            var invocations = 0;
            Func<int, int, int> source = (x, y) =>
            {
                invocations++;
                return x * y * 5;
            };
            var cached = Memoizer.Memoize(source);

            cached(7, 8);
            cached(7, 8);
            var value = cached(7, 8);

            Assert.AreEqual(280, value);
            Assert.AreEqual(1, invocations);
        }

        [TestMethod]
        public void CallsTheUnderlyingFuncWhenCalledWithDifferentArguments()
        {
            // Key is x * 10 + y, so (7, 8) -> 78 and (8, 7) -> 87: argument order must matter.
            var seen = new Dictionary<int, bool>();
            Func<int, int, int> source = (x, y) =>
            {
                seen[x * 10 + y] = true;
                return x * y * 5;
            };
            var cached = Memoizer.Memoize(source);

            var forward = cached(7, 8);
            var reversed = cached(8, 7);

            Assert.AreEqual(280, forward);
            Assert.AreEqual(280, reversed);
            Assert.IsTrue(seen[78]);
            Assert.IsTrue(seen[87]);
        }
    }
}
As I said, I hope this helps someone. Please let me know if it does - or if you spot any problems.
Comments