Monday, March 18, 2013

Memoization

I've spent a couple of hours writing some helper classes for memoizing function calls. Since I haven't found anything as elegant on Google, I thought I'd post them here in case anyone else can use them.

Caveat: the code is not thread-safe. I didn't have a use case for thread safety, and it is tricky to get right, so I didn't spend the time on it. Do not use this code in multi-threaded programs!

First, the Memoizer class:

  /// <summary>
  /// Wraps functions so that each distinct combination of arguments is evaluated only once;
  /// later calls with equal arguments return the cached result.
  /// Not thread-safe: the cache behind each memoized function is a plain dictionary.
  /// </summary>
  public class Memoizer
  {
    /// <summary>Returns a memoized version of a parameterless function.</summary>
    /// <param name="f">The function to memoize.</param>
    /// <returns>A function that calls <paramref name="f"/> on first use and returns the cached value afterwards.</returns>
    /// <exception cref="ArgumentNullException"><paramref name="f"/> is null.</exception>
    public static Func<T> Memoize<T>(Func<T> f)
    {
      // Fail fast here rather than with a NullReferenceException on the first invocation.
      if (f == null)
        throw new ArgumentNullException("f");

      // do NOT inline - it will create a new one on each subsequent call and thus not cache anything
      var map = new KeyMap<T>();

      return () => map.GetValue(new KeyStruct(), f);
    }

    /// <summary>Returns a memoized version of a one-argument function.</summary>
    /// <param name="f">The function to memoize.</param>
    /// <returns>A function that caches the result of <paramref name="f"/> per distinct argument value.</returns>
    /// <exception cref="ArgumentNullException"><paramref name="f"/> is null.</exception>
    public static Func<T1, T> Memoize<T1, T>(Func<T1, T> f)
    {
      // Fail fast here rather than with a NullReferenceException on the first invocation.
      if (f == null)
        throw new ArgumentNullException("f");

      // do NOT inline - it will create a new one on each subsequent call and thus not cache anything
      var map = new KeyMap<T>();

      return arg => map.GetValue(new KeyStruct(arg), () => f(arg));
    }

    /// <summary>Returns a memoized version of a two-argument function.</summary>
    /// <param name="f">The function to memoize.</param>
    /// <returns>A function that caches the result of <paramref name="f"/> per distinct (ordered) argument pair.</returns>
    /// <exception cref="ArgumentNullException"><paramref name="f"/> is null.</exception>
    public static Func<T1, T2, T> Memoize<T1, T2, T>(Func<T1, T2, T> f)
    {
      // Fail fast here rather than with a NullReferenceException on the first invocation.
      if (f == null)
        throw new ArgumentNullException("f");

      // do NOT inline - it will create a new one on each subsequent call and thus not cache anything
      var map = new KeyMap<T>();

      return (a1, a2) => map.GetValue(new KeyStruct(a1, a2), () => f(a1, a2));
    }
  }

As the comment points out, it is tempting to make each method a one-liner, but the map must not be created inside the lambda. If it were, a new, empty map would be created on every call, which disables caching entirely. The map also must not be created just once (e.g. in a constructor) and shared, because two distinct functions with the same signature must not share a cache. Caching Math.Sin and Math.Cos in the same map would be bad, even though both have the same Func<double, double> signature. Creating the map in the Memoize method, outside the lambda, gives each memoized function its own cache that lives as long as the returned delegate.

The above class uses two others:

  /// <summary>
  /// Lazily populated cache that maps argument keys to computed values.
  /// </summary>
  /// <typeparam name="T">Type of the cached values.</typeparam>
  public class KeyMap<T>
  {
    // Backing store; entries are only ever added, never removed or replaced.
    private readonly Dictionary<KeyStruct, T> cache;

    public KeyMap()
    {
      cache = new Dictionary<KeyStruct, T>();
    }

    /// <summary>
    /// Returns the value cached for <paramref name="key"/>, computing and storing it with
    /// <paramref name="f"/> when the key has not been seen yet.
    /// </summary>
    /// <param name="key">Lookup key.</param>
    /// <param name="f">Factory invoked only on a cache miss; if it throws, nothing is stored.</param>
    /// <returns>The cached or freshly computed value.</returns>
    public T GetValue(KeyStruct key, Func<T> f)
    {
      T cached;
      if (cache.TryGetValue(key, out cached))
        return cached;

      var computed = f();
      cache.Add(key, computed);
      return computed;
    }
  }

and

  /// <summary>
  /// Composite dictionary key built from an ordered list of argument values.
  /// Two keys are equal when they hold the same number of values and every pair of
  /// corresponding values is equal according to <see cref="object.Equals(object, object)"/>.
  /// Null values are allowed and are only equal to other nulls.
  /// </summary>
  public class KeyStruct : IEquatable<KeyStruct>
  {
    /// <summary>Creates a key from the given argument values.</summary>
    /// <param name="args">The values that make up the key; individual elements may be null.</param>
    /// <exception cref="ArgumentNullException"><paramref name="args"/> itself is null.</exception>
    public KeyStruct(params object[] args)
    {
      if (args == null)
        throw new ArgumentNullException("args");

      // Defensive copy: later changes to the caller's array must not alter this key (or its
      // hash code) while it sits in a dictionary. Nulls are kept as-is rather than replaced
      // by a placeholder, so a null argument never collides with a real value such as 0.
      this.args = (object[]) args.Clone();
    }

    public override bool Equals(object obj)
    {
      return Equals(obj as KeyStruct);
    }

    /// <summary>Element-wise, order-sensitive, null-safe comparison with another key.</summary>
    public bool Equals(KeyStruct other)
    {
      // ReferenceEquals avoids recursing into the overloaded == operator.
      if (ReferenceEquals(other, null))
        return false;
      if (ReferenceEquals(this, other))
        return true;
      if (args.Length != other.args.Length)
        return false;

      for (var i = 0; i < args.Length; i++)
      {
        // Static object.Equals handles null on either side.
        if (!object.Equals(args[i], other.args[i]))
          return false;
      }

      return true;
    }

    public override int GetHashCode()
    {
      unchecked
      {
        // Jon Skeet's 17/23 recipe; a null element contributes 0.
        return args.Aggregate(17, (acc, arg) => acc * 23 + (arg == null ? 0 : arg.GetHashCode()));
      }
    }

    public static bool operator ==(KeyStruct obj1, KeyStruct obj2)
    {
      // Null-safe: a null left operand no longer throws NullReferenceException.
      if (ReferenceEquals(obj1, null))
        return ReferenceEquals(obj2, null);

      return obj1.Equals(obj2);
    }

    public static bool operator !=(KeyStruct obj1, KeyStruct obj2)
    {
      return !(obj1 == obj2);
    }

    // Immutable snapshot of the key's values (may contain nulls).
    private readonly object[] args;
  }

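The KeyStruct class exists simply to let me use a dictionary even though a dictionary key can only be a single value. It bundles all of a call's arguments into one key object. (Despite its name, it is a class, not a struct.) For the hash calculation I adapted Jon Skeet's well-known GetHashCode recipe (start at 17, multiply by 23 and add each element's hash).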

Since I developed these classes using TDD, here are the final tests:

  /// <summary>
  /// Hash-code and equality behaviour of <see cref="KeyStruct"/>.
  /// </summary>
  [TestClass]
  public class KeyStructTests
  {
    // Shorthand for building a key from an explicit list of values.
    private static KeyStruct Key(params object[] values)
    {
      return new KeyStruct(values);
    }

    [TestMethod]
    public void EmptyListOfArgsReturnsHash17()
    {
      var key = Key(new object[] { });

      var hash = key.GetHashCode();

      // Only the seed remains when there are no elements.
      Assert.AreEqual(17, hash);
    }

    [TestMethod]
    public void TwoEmptyKeyStructsAreEqual()
    {
      var first = new KeyStruct();
      var second = new KeyStruct();

      var areEqual = first == second;

      Assert.IsTrue(areEqual);
    }

    [TestMethod]
    public void OneArgEqualTo0ReturnsHash391()
    {
      var key = Key(0);

      var hash = key.GetHashCode();

      // seed * multiplier + hash(0)
      Assert.AreEqual(17 * 23 + 0, hash);
    }

    [TestMethod]
    public void OneArgEqualTo5ReturnsHash396()
    {
      var key = Key(5);

      var hash = key.GetHashCode();

      // seed * multiplier + hash(5)
      Assert.AreEqual(17 * 23 + 5, hash);
    }

    [TestMethod]
    public void EqualityReturnsTrueForEqualValues()
    {
      var left = Key(1, "a", 3.5m);
      var right = Key(1, "a", 3.5m);

      var areEqual = left == right;

      Assert.IsTrue(areEqual);
    }

    [TestMethod]
    public void EqualityReturnsFalseForDifferentValues_1()
    {
      var left = Key(1, "a", 3.5m);
      var right = Key(2, "b", 4.5m);

      var areEqual = left == right;

      Assert.IsFalse(areEqual);
    }

    [TestMethod]
    public void EqualityReturnsFalseForDifferentValues_2()
    {
      // Same prefix, but the right-hand key is one element shorter.
      var left = Key(1, "a", 3.5m);
      var right = Key(1, "a");

      var areEqual = left == right;

      Assert.IsFalse(areEqual);
    }
  }

and

  /// <summary>
  /// Tests for <see cref="Memoizer"/>, grouped by the arity of the memoized function.
  /// </summary>
  public class MemoizerTests
  {
    [TestClass]
    public class ZeroArgs
    {
      [TestMethod]
      public void CallsTheUnderlyingFuncOnFirstCall()
      {
        var called = false;
        Func<int> original = () =>
        {
          called = true;
          return 5;
        };
        var memoized = Memoizer.Memoize(original);

        var result = memoized.Invoke();

        Assert.AreEqual(5, result);
        Assert.IsTrue(called);
      }

      [TestMethod]
      public void ReturnsCachedValueForSubsequentCalls()
      {
        var called = 0;
        Func<int> original = () =>
        {
          called++;
          return 5;
        };
        var memoized = Memoizer.Memoize(original);

        memoized.Invoke();
        memoized.Invoke();
        var result = memoized.Invoke();

        Assert.AreEqual(5, result);
        Assert.AreEqual(1, called);
      }
    }

    [TestClass]
    public class OneArg
    {
      [TestMethod]
      public void CallsTheUnderlyingFuncOnFirstCall()
      {
        var called = false;
        Func<int, int> original = arg =>
        {
          called = true;
          return arg * 5;
        };
        var memoized = Memoizer.Memoize(original);

        var result = memoized.Invoke(7);

        Assert.AreEqual(35, result);
        Assert.IsTrue(called);
      }

      [TestMethod]
      public void ReturnsCachedValueForSubsequentCalls()
      {
        var called = 0;
        Func<int, int> original = arg =>
        {
          called++;
          return arg * 5;
        };
        var memoized = Memoizer.Memoize(original);

        memoized.Invoke(7);
        memoized.Invoke(7);
        var result = memoized.Invoke(7);

        Assert.AreEqual(35, result);
        Assert.AreEqual(1, called);
      }

      [TestMethod]
      public void CallsTheUnderlyingFuncWhenCalledWithDifferentArguments()
      {
        // Record every argument the underlying function receives: a missed call then fails
        // as an assertion (not a KeyNotFoundException), and duplicate evaluations are caught.
        var received = new List<int>();
        Func<int, int> original = arg =>
        {
          received.Add(arg);
          return arg * 5;
        };
        var memoized = Memoizer.Memoize(original);

        var result1 = memoized.Invoke(7);
        var result2 = memoized.Invoke(8);

        Assert.AreEqual(35, result1);
        Assert.AreEqual(40, result2);
        CollectionAssert.AreEqual(new[] { 7, 8 }, received);
      }

      [TestMethod]
      public void SeparatelyMemoizedFuncsDoNotShareACache()
      {
        // Two functions with the same signature must each get their own cache.
        Func<int, int> timesFive = arg => arg * 5;
        Func<int, int> timesSeven = arg => arg * 7;
        var memoizedFive = Memoizer.Memoize(timesFive);
        var memoizedSeven = Memoizer.Memoize(timesSeven);

        var result1 = memoizedFive.Invoke(3);
        var result2 = memoizedSeven.Invoke(3);

        Assert.AreEqual(15, result1);
        Assert.AreEqual(21, result2);
      }
    }

    [TestClass]
    public class TwoArgs
    {
      [TestMethod]
      public void CallsTheUnderlyingFuncOnFirstCall()
      {
        var called = false;
        Func<int, int, int> original = (a1, a2) =>
        {
          called = true;
          return a1 * a2 * 5;
        };
        var memoized = Memoizer.Memoize(original);

        var result = memoized.Invoke(7, 8);

        Assert.AreEqual(280, result);
        Assert.IsTrue(called);
      }

      [TestMethod]
      public void ReturnsCachedValueForSubsequentCalls()
      {
        var called = 0;
        Func<int, int, int> original = (a1, a2) =>
        {
          called++;
          return a1 * a2 * 5;
        };
        var memoized = Memoizer.Memoize(original);

        memoized.Invoke(7, 8);
        memoized.Invoke(7, 8);
        var result = memoized.Invoke(7, 8);

        Assert.AreEqual(280, result);
        Assert.AreEqual(1, called);
      }

      [TestMethod]
      public void CallsTheUnderlyingFuncWhenCalledWithDifferentArguments()
      {
        // (7, 8) and (8, 7) produce the same result, so this also proves argument order is
        // part of the cache key. Each pair is encoded as a1 * 10 + a2 for easy comparison.
        var received = new List<int>();
        Func<int, int, int> original = (a1, a2) =>
        {
          received.Add(a1 * 10 + a2);

          return a1 * a2 * 5;
        };
        var memoized = Memoizer.Memoize(original);

        var result1 = memoized.Invoke(7, 8);
        var result2 = memoized.Invoke(8, 7);

        Assert.AreEqual(280, result1);
        Assert.AreEqual(280, result2);
        CollectionAssert.AreEqual(new[] { 78, 87 }, received);
      }
    }
  }

As I said, I hope this helps someone. Please let me know if it does - or if you spot any problems.