
I've spent a couple of hours writing a few helper classes for memoizing function calls. Since I haven't found anything as elegant on Google, I thought I'd post them here in case anyone else can use them.

Caveat: the code is not thread-safe. I didn't have a use case for thread safety, and it is tricky to get right, so I didn't want to spend the time on it. Do not use this code in multi-threaded programs!

First, the Memoizer class:

  /// <summary>
  /// Wraps functions in caching delegates. The first call for a given argument
  /// combination invokes the wrapped function; later calls with equal arguments
  /// return the stored result. The cache has no synchronization, so the returned
  /// delegates must not be shared across threads.
  /// </summary>
  public class Memoizer
  {
    /// <summary>Memoizes a parameterless function.</summary>
    /// <param name="f">Function to cache; invoked on the first call of the returned delegate.</param>
    /// <returns>A delegate returning the cached result of <paramref name="f"/>.</returns>
    /// <exception cref="ArgumentNullException"><paramref name="f"/> is null.</exception>
    public static Func<T> Memoize<T>(Func<T> f)
    {
      if (f == null)
      {
        throw new ArgumentNullException("f");
      }

      // do NOT inline - it will create a new one on each subsequent call and thus not cache anything.
      // One map per Memoize call also keeps distinct functions with the same signature from sharing a cache.
      var map = new KeyMap<T>();

      return () => map.GetValue(new KeyStruct(), f);
    }

    /// <summary>Memoizes a one-argument function, caching one result per distinct argument value.</summary>
    /// <param name="f">Function to cache.</param>
    /// <returns>A delegate that invokes <paramref name="f"/> only for arguments not seen before.</returns>
    /// <exception cref="ArgumentNullException"><paramref name="f"/> is null.</exception>
    public static Func<T1, T> Memoize<T1, T>(Func<T1, T> f)
    {
      if (f == null)
      {
        throw new ArgumentNullException("f");
      }

      // do NOT inline - it will create a new one on each subsequent call and thus not cache anything
      var map = new KeyMap<T>();

      return arg => map.GetValue(new KeyStruct(arg), () => f(arg));
    }

    /// <summary>Memoizes a two-argument function, caching one result per distinct (a1, a2) pair.</summary>
    /// <param name="f">Function to cache.</param>
    /// <returns>A delegate that invokes <paramref name="f"/> only for argument pairs not seen before.</returns>
    /// <exception cref="ArgumentNullException"><paramref name="f"/> is null.</exception>
    public static Func<T1, T2, T> Memoize<T1, T2, T>(Func<T1, T2, T> f)
    {
      if (f == null)
      {
        throw new ArgumentNullException("f");
      }

      // do NOT inline - it will create a new one on each subsequent call and thus not cache anything
      var map = new KeyMap<T>();

      return (a1, a2) => map.GetValue(new KeyStruct(a1, a2), () => f(a1, a2));
    }
  }

As the comment points out, it is tempting to make each method a one-liner, but the map must not be created inside the returned lambda: a new map would then be created on every call, which disables caching entirely. Nor should a single map be created once and shared (for example, in a constructor), because you do not want two distinct functions with the same signature to share one cache. Memoizing both Math.Sin and Math.Cos would be bad, even though both have the Func<double, double> signature.

The above class uses two others:

  /// <summary>
  /// Lazily-populated cache of computed values, indexed by <see cref="KeyStruct"/>.
  /// A value is computed on the first request for its key and stored for later requests.
  /// </summary>
  public class KeyMap<T>
  {
    // Backing store; entries are only ever added, never removed.
    private readonly Dictionary<KeyStruct, T> cache;

    public KeyMap()
    {
      cache = new Dictionary<KeyStruct, T>();
    }

    /// <summary>Returns the value stored for <paramref name="key"/>, computing it with <paramref name="f"/> on a miss.</summary>
    /// <param name="key">Lookup key.</param>
    /// <param name="f">Factory invoked only when <paramref name="key"/> is not cached yet.</param>
    /// <returns>The cached or newly computed value.</returns>
    public T GetValue(KeyStruct key, Func<T> f)
    {
      T cached;
      if (cache.TryGetValue(key, out cached))
      {
        return cached;
      }

      // Nothing is stored if f throws, so a later call retries the computation.
      var computed = f();
      cache.Add(key, computed);
      return computed;
    }
  }


  /// <summary>
  /// Composite dictionary key built from an ordered list of arguments. Two keys are equal
  /// when they hold the same number of arguments and each pair of arguments is equal
  /// according to <see cref="object.Equals(object, object)"/>. Despite its name, this is a
  /// class (reference type), not a struct.
  /// </summary>
  public class KeyStruct
  {
    /// <summary>Creates a key from the given arguments.</summary>
    /// <param name="args">Arguments forming the key; individual elements may be null.</param>
    /// <exception cref="ArgumentNullException"><paramref name="args"/> itself is null.</exception>
    public KeyStruct(params object[] args)
    {
      if (args == null)
      {
        throw new ArgumentNullException("args");
      }

      // Defensive copy so later changes to the caller's array cannot alter this key's hash.
      // Null elements are kept as-is: replacing them with 0 made a null argument and a
      // boxed 0 compare equal, so f(null) and f(0) shared one cache entry.
      this.args = args.ToArray();
    }

    public override bool Equals(object obj)
    {
      var compareTo = obj as KeyStruct;
      // ReferenceEquals avoids recursing through the overloaded == operator.
      if (ReferenceEquals(compareTo, null))
      {
        return false;
      }

      if (args.Length != compareTo.args.Length)
      {
        return false;
      }

      for (var i = 0; i < args.Length; i++)
      {
        // The static object.Equals tolerates null on either side.
        if (!object.Equals(args[i], compareTo.args[i]))
        {
          return false;
        }
      }

      return true;
    }

    public override int GetHashCode()
    {
      // Overflow is expected while combining hashes; unchecked keeps it from throwing
      // even when the project is compiled with /checked.
      unchecked
      {
        var hash = 17;
        foreach (var arg in args)
        {
          hash = hash * 23 + (arg == null ? 0 : arg.GetHashCode());
        }

        return hash;
      }
    }

    public static bool operator ==(KeyStruct obj1, KeyStruct obj2)
    {
      if (ReferenceEquals(obj1, obj2))
      {
        return true;
      }

      // Guard against calling Equals on a null left operand.
      if (ReferenceEquals(obj1, null))
      {
        return false;
      }

      return obj1.Equals(obj2);
    }

    public static bool operator !=(KeyStruct obj1, KeyStruct obj2)
    {
      return !(obj1 == obj2);
    }

    private readonly object[] args;
  }

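The KeyStruct class simply lets me use a dictionary even though a dictionary key can only be a single value: it bundles all of the arguments into one key object that compares them element by element. For the hash calculation I adapted Jon Skeet's hash-combining code (seed 17, multiplier 23).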

Since I have developed these classes using TDD, here are the final tests:

  // NOTE(review): no test-framework attributes appear in this listing (e.g. [TestFixture]/[Test]
  // for NUnit or [TestClass]/[TestMethod] for MSTest). Confirm the framework in use and restore
  // them; otherwise the runner will not discover these tests.
  public class KeyStructTests
  {
    public void EmptyListOfArgsReturnsHash17()
    {
      var sut = new KeyStruct(new object[0]);

      var result = sut.GetHashCode();

      Assert.AreEqual(17, result);
    }

    public void TwoEmptyKeyStructsAreEqual()
    {
      var k1 = new KeyStruct();
      var k2 = new KeyStruct();

      var result = k1 == k2;

      Assert.IsTrue(result);
    }

    public void OneArgEqualTo0ReturnsHash391()
    {
      var sut = new KeyStruct(new object[] { 0 });

      var result = sut.GetHashCode();

      // 17 * 23 + 0.GetHashCode()
      Assert.AreEqual(391, result);
    }

    public void OneArgEqualTo5ReturnsHash396()
    {
      var sut = new KeyStruct(new object[] { 5 });

      var result = sut.GetHashCode();

      // 17 * 23 + 5.GetHashCode()
      Assert.AreEqual(396, result);
    }

    public void EqualityReturnsTrueForEqualValues()
    {
      var s1 = new KeyStruct(new object[] { 1, "a", 3.5m });
      var s2 = new KeyStruct(new object[] { 1, "a", 3.5m });

      var result = s1 == s2;

      Assert.IsTrue(result);
    }

    public void EqualityReturnsFalseForDifferentValues_1()
    {
      var s1 = new KeyStruct(new object[] { 1, "a", 3.5m });
      var s2 = new KeyStruct(new object[] { 2, "b", 4.5m });

      var result = s1 == s2;

      Assert.IsFalse(result);
    }

    public void EqualityReturnsFalseForDifferentValues_2()
    {
      var s1 = new KeyStruct(new object[] { 1, "a", 3.5m });
      var s2 = new KeyStruct(new object[] { 1, "a" });

      var result = s1 == s2;

      Assert.IsFalse(result);
    }
  }


  // NOTE(review): no test-framework attributes appear in this listing (e.g. [TestFixture]/[Test]
  // for NUnit); confirm the framework in use and restore them so these tests are discovered.
  public class MemoizerTests
  {
    public class ZeroArgs
    {
      public void CallsTheUnderlyingFuncOnFirstCall()
      {
        var called = false;
        Func<int> original = () =>
        {
          called = true;
          return 5;
        };
        var memoized = Memoizer.Memoize(original);

        var result = memoized.Invoke();

        Assert.AreEqual(5, result);
        Assert.IsTrue(called);
      }

      public void ReturnsCachedValueForSubsequentCalls()
      {
        // Counts real invocations of the underlying function; caching must keep it at 1.
        var called = 0;
        Func<int> original = () =>
        {
          called++;
          return 5;
        };
        var memoized = Memoizer.Memoize(original);

        memoized.Invoke();
        memoized.Invoke();
        var result = memoized.Invoke();

        Assert.AreEqual(5, result);
        Assert.AreEqual(1, called);
      }
    }

    public class OneArg
    {
      public void CallsTheUnderlyingFuncOnFirstCall()
      {
        var called = false;
        Func<int, int> original = arg =>
        {
          called = true;
          return arg * 5;
        };
        var memoized = Memoizer.Memoize(original);

        var result = memoized.Invoke(7);

        Assert.AreEqual(35, result);
        Assert.IsTrue(called);
      }

      public void ReturnsCachedValueForSubsequentCalls()
      {
        // Counts real invocations of the underlying function; caching must keep it at 1.
        var called = 0;
        Func<int, int> original = arg =>
        {
          called++;
          return arg * 5;
        };
        var memoized = Memoizer.Memoize(original);

        memoized.Invoke(7);
        memoized.Invoke(7);
        var result = memoized.Invoke(7);

        Assert.AreEqual(35, result);
        Assert.AreEqual(1, called);
      }

      public void CallsTheUnderlyingFuncWhenCalledWithDifferentArguments()
      {
        var called = new Dictionary<int, bool>();
        Func<int, int> original = arg =>
        {
          called[arg] = true;

          return arg * 5;
        };
        var memoized = Memoizer.Memoize(original);

        var result1 = memoized.Invoke(7);
        var result2 = memoized.Invoke(8);

        Assert.AreEqual(35, result1);
        Assert.AreEqual(40, result2);
        Assert.IsTrue(called.ContainsKey(7));
        Assert.IsTrue(called.ContainsKey(8));
      }
    }

    public class TwoArgs
    {
      public void CallsTheUnderlyingFuncOnFirstCall()
      {
        var called = false;
        Func<int, int, int> original = (a1, a2) =>
        {
          called = true;
          return a1 * a2 * 5;
        };
        var memoized = Memoizer.Memoize(original);

        var result = memoized.Invoke(7, 8);

        Assert.AreEqual(280, result);
        Assert.IsTrue(called);
      }

      public void ReturnsCachedValueForSubsequentCalls()
      {
        // Counts real invocations of the underlying function; caching must keep it at 1.
        var called = 0;
        Func<int, int, int> original = (a1, a2) =>
        {
          called++;
          return a1 * a2 * 5;
        };
        var memoized = Memoizer.Memoize(original);

        memoized.Invoke(7, 8);
        memoized.Invoke(7, 8);
        var result = memoized.Invoke(7, 8);

        Assert.AreEqual(280, result);
        Assert.AreEqual(1, called);
      }

      public void CallsTheUnderlyingFuncWhenCalledWithDifferentArguments()
      {
        // Key a1 * 10 + a2 records which argument pairs actually reached the function.
        var called = new Dictionary<int, bool>();
        Func<int, int, int> original = (a1, a2) =>
        {
          called[a1 * 10 + a2] = true;

          return a1 * a2 * 5;
        };
        var memoized = Memoizer.Memoize(original);

        var result1 = memoized.Invoke(7, 8);
        var result2 = memoized.Invoke(8, 7);

        Assert.AreEqual(280, result1);
        Assert.AreEqual(280, result2);
        // Same product but different argument order: both pairs must reach the function.
        Assert.IsTrue(called.ContainsKey(78));
        Assert.IsTrue(called.ContainsKey(87));
      }
    }
  }

As I said, I hope this helps someone. Please let me know if it does - or if you spot any problems.


Ivan Stepaniuk said…
Kudos for including the tests in your posts!
Marcel said…
Thank you. Yes, I prefer to write everything test-first, it helps in a lot of ways.

