Suppose we have some code using Entity Framework that we want to unit test:

 1
 2
 3
 4
 5
 6
 7
 8
 9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
/// <summary>
/// MVC controller that displays customers obtained through an <see cref="ICustomerContext"/>.
/// </summary>
public class CustomerController : Controller
{
    // The data context is injected rather than created here, which is what allows
    // a test to hand the controller an in-memory fake instead of a real database context.
    private readonly ICustomerContext _context;

    /// <summary>Creates the controller over the supplied customer context.</summary>
    /// <param name="context">Source of customer data for this controller.</param>
    public CustomerController(ICustomerContext context)
    {
        this._context = context;
    }

    /// <summary>
    /// Looks up a customer with <c>DbSet.Find</c> using <paramref name="id"/> as the key
    /// and renders the default view with the result as its model.
    /// </summary>
    /// <param name="id">Key value passed to <c>Find</c>.</param>
    /// <returns>The default view, with whatever <c>Find</c> returned as the model.</returns>
    public ViewResult Index(int id)
    {
        DbSet<Customer> customers = _context.Customers;
        return View(customers.Find(id));
    }
}

/// <summary>
/// Entity Framework context exposing customers. It implements <see cref="ICustomerContext"/>
/// so that consumers can depend on the interface instead of this concrete type.
/// </summary>
public class CustomerContext : DbContext, ICustomerContext
{
    /// <summary>The customers exposed by this context.</summary>
    // NOTE(review): public setter presumably kept so DbContext can populate the set
    // when the context is constructed — confirm before making it get-only.
    public DbSet<Customer> Customers { get; set; }
}

/// <summary>A customer entity with an identifier and a display name.</summary>
public class Customer
{
    /// <summary>
    /// Customer identifier. The name <c>Id</c> is what EF's conventions pick up as the key;
    /// it is also the value the examples look customers up by.
    /// </summary>
    public int Id { get; set; }

    /// <summary>Customer name.</summary>
    public string Name { get; set; }
}

/// <summary>
/// Abstraction over the customer data context. Code that depends on this interface
/// (such as <c>CustomerController</c>) can be given a real EF context or an in-memory fake.
/// </summary>
public interface ICustomerContext
{
    /// <summary>The customers available through this context.</summary>
    DbSet<Customer> Customers { get; }
}

To unit test this, we create a fake ICustomerContext backed by an in-memory collection. First we need a fake DbSet implementation, which turns out to be fairly straightforward.

 1
 2
 3
 4
 5
 6
 7
 8
 9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
/// <summary>
/// In-memory test double for <see cref="DbSet{T}"/>, backed by an
/// <see cref="ObservableCollection{T}"/> and queried through LINQ to Objects.
/// </summary>
/// <typeparam name="T">The entity type held by the set.</typeparam>
/// <remarks>
/// <see cref="IQueryable"/> and <see cref="IEnumerable{T}"/> are re-implemented here so that
/// querying and enumerating the set are served by the in-memory collection rather than by
/// the members inherited from <see cref="DbSet{T}"/>.
/// </remarks>
public class FakeDbSet<T> : DbSet<T>, IQueryable, IEnumerable<T>
    where T : class
{
    // Optional predicate used by Find: given an entity and the key values, does it match?
    private readonly Func<T, object[], bool> _find;

    // Backing store for every entity in the set; also exposed through Local.
    private readonly ObservableCollection<T> _items;

    // LINQ-to-Objects view over _items that supplies the IQueryable members.
    private readonly IQueryable<T> _query;

    /// <summary>Creates an empty fake set.</summary>
    /// <param name="find">
    /// Predicate deciding whether an entity matches the key values passed to <see cref="Find"/>.
    /// When null, <see cref="Find"/> throws <see cref="NotSupportedException"/>.
    /// </param>
    public FakeDbSet(Func<T, object[], bool> find = null)
    {
        _find = find;
        _items = new ObservableCollection<T>();
        _query = _items.AsQueryable();
    }

    /// <summary>The live in-memory collection holding the set's entities.</summary>
    public override ObservableCollection<T> Local { get { return _items; } }

    // IQueryable members delegate to the LINQ-to-Objects query over _items.
    Type IQueryable.ElementType { get { return _query.ElementType; } }
    Expression IQueryable.Expression { get { return _query.Expression; } }
    IQueryProvider IQueryable.Provider { get { return _query.Provider; } }

    // Enumeration (foreach, ToList(), ...) reads straight from the in-memory collection.
    IEnumerator<T> IEnumerable<T>.GetEnumerator()
    {
        return _items.GetEnumerator();
    }

    System.Collections.IEnumerator System.Collections.IEnumerable.GetEnumerator()
    {
        return _items.GetEnumerator();
    }

    /// <summary>Appends each entity to the in-memory collection.</summary>
    /// <returns>The same <paramref name="entities"/> sequence that was passed in.</returns>
    public override IEnumerable<T> AddRange(IEnumerable<T> entities)
    {
        foreach (var entity in entities)
        {
            _items.Add(entity);
        }
        return entities;
    }

    /// <summary>Appends <paramref name="item"/> to the in-memory collection.</summary>
    /// <returns>The same <paramref name="item"/>.</returns>
    public override T Add(T item)
    {
        _items.Add(item);
        return item;
    }

    /// <summary>Removes each entity (first occurrence) from the in-memory collection.</summary>
    /// <returns>The same <paramref name="entities"/> sequence that was passed in.</returns>
    public override IEnumerable<T> RemoveRange(IEnumerable<T> entities)
    {
        foreach (var entity in entities)
        {
            _items.Remove(entity);
        }
        return entities;
    }

    /// <summary>Removes the first occurrence of <paramref name="item"/>, if present.</summary>
    /// <returns>The same <paramref name="item"/>.</returns>
    public override T Remove(T item)
    {
        _items.Remove(item);
        return item;
    }

    /// <summary>
    /// Appends <paramref name="item"/> to the in-memory collection. No change-tracking state is
    /// modelled, so this behaves exactly like <see cref="Add"/>.
    /// </summary>
    /// <returns>The same <paramref name="item"/>.</returns>
    public override T Attach(T item)
    {
        _items.Add(item);
        return item;
    }

    /// <summary>
    /// Creates a new, unattached <typeparamref name="T"/> via its public parameterless constructor.
    /// </summary>
    public override T Create()
    {
        return Activator.CreateInstance<T>();
    }

    /// <summary>
    /// Creates a new, unattached <typeparamref name="TDerivedEntity"/> via its public
    /// parameterless constructor.
    /// </summary>
    public override TDerivedEntity Create<TDerivedEntity>()
    {
        return Activator.CreateInstance<TDerivedEntity>();
    }

    /// <summary>Finds the single entity for which the configured predicate matches the key values.</summary>
    /// <param name="keyValues">Key values handed unchanged to the predicate.</param>
    /// <returns>The matching entity, or null when none matches.</returns>
    /// <exception cref="NotSupportedException">No predicate was supplied to the constructor.</exception>
    /// <exception cref="InvalidOperationException">More than one entity matches.</exception>
    public override T Find(params object[] keyValues)
    {
        if (_find == null)
        {
            throw new NotSupportedException();
        }
        return _items.SingleOrDefault(e => _find(e, keyValues));
    }
}

Find is slightly awkward to implement: you can locate the key property via reflection, or simply supply a lambda that decides whether an entity matches the given key values, as I have done here.

The fake customer context is now trivial to create:

1
2
3
4
5
6
7
8
/// <summary>
/// In-memory <see cref="ICustomerContext"/> for unit tests, backed by a <see cref="FakeDbSet{T}"/>.
/// </summary>
public class FakeCustomerContext : ICustomerContext
{
    // Find(id) matches on Customer.Id. The first key value is cast to int, so callers must
    // pass exactly an int key (a long or other type throws InvalidCastException).
    private readonly DbSet<Customer> _customers =
        new FakeDbSet<Customer>((customer, keyValues) => customer.Id == (int) keyValues[0]);

    /// <summary>The customers held in memory by this fake context.</summary>
    public DbSet<Customer> Customers { get { return _customers; } }
}

Now to write the tests themselves:

 1
 2
 3
 4
 5
 6
 7
 8
 9
10
[TestMethod]
public void ViewReturnsCorrectCustomer()
{
    // Arrange: an in-memory context holding one known customer, wired into the controller.
    const int customerId = 42;
    var fakeContext = new FakeCustomerContext();
    fakeContext.Customers.Add(new Customer { Id = customerId, Name = "Test Customer" });
    var sut = new CustomerController(fakeContext);

    // Act: request the customer by the same key.
    ViewResult view = sut.Index(customerId);

    // Assert: the view's model is the customer stored above.
    var model = (Customer)view.Model;
    Assert.AreEqual("Test Customer", model.Name);
}

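This approach is somewhat limited, but works reasonably well. For example, the fake set runs queries with LINQ to Objects (via AsQueryable), so any behaviour specific to the real database provider is not exercised, and Find only works when a matching lambda is supplied.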