Suppose we have some code using Entity Framework that we want to unit test:
/// <summary>
/// MVC controller that displays customers looked up through an <see cref="ICustomerContext"/>.
/// </summary>
public class CustomerController : Controller
{
    // Held as the interface type so that a fake context can be supplied instead of the real one.
    private readonly ICustomerContext _context;

    /// <summary>Creates the controller over the given customer context.</summary>
    /// <param name="context">Context whose <c>Customers</c> set is used for lookups.</param>
    public CustomerController(ICustomerContext context)
    {
        this._context = context;
    }

    /// <summary>Renders the default view for the customer with the given key.</summary>
    /// <param name="id">Key value passed to <c>Customers.Find</c>.</param>
    /// <returns>A view whose model is the result of <c>Customers.Find(id)</c>.</returns>
    public ViewResult Index(int id)
    {
        return View(_context.Customers.Find(id));
    }
}
/// <summary>
/// Entity Framework context exposing the customer set; implements <see cref="ICustomerContext"/>
/// so consumers can depend on the interface rather than this concrete type.
/// </summary>
public class CustomerContext : DbContext, ICustomerContext
{
    /// <summary>Gets or sets the set of <see cref="Customer"/> entities.</summary>
    public DbSet<Customer> Customers
    {
        get;
        set;
    }
}
/// <summary>
/// A customer entity.
/// </summary>
public class Customer
{
    /// <summary>Gets or sets the customer's identifier.</summary>
    public int Id
    {
        get;
        set;
    }

    /// <summary>Gets or sets the customer's name.</summary>
    public string Name
    {
        get;
        set;
    }
}
/// <summary>
/// Abstraction over the customer data context, so that consumers can be given
/// either the real Entity Framework context or an in-memory fake.
/// </summary>
public interface ICustomerContext
{
    /// <summary>Gets the set of <see cref="Customer"/> entities.</summary>
    DbSet<Customer> Customers
    {
        get;
    }
}
To unit test this, we create a fake `ICustomerContext` backed by an in-memory list. First we need a fake `DbSet` implementation, which turns out to be fairly straightforward:
/// <summary>
/// In-memory test double for <see cref="DbSet{TEntity}"/>. Entities are stored in an
/// <see cref="ObservableCollection{T}"/>, and LINQ queries are answered by LINQ to Objects
/// over that collection.
/// </summary>
/// <typeparam name="T">The entity type.</typeparam>
/// <remarks>
/// <see cref="IQueryable"/> and <see cref="IEnumerable{T}"/> are re-implemented explicitly so that
/// query construction (<c>ElementType</c>, <c>Expression</c>, <c>Provider</c>) and enumeration are
/// served by the in-memory collection rather than by the base <see cref="DbSet{TEntity}"/>, which
/// has no backing context when used as a test double.
/// </remarks>
public class FakeDbSet<T> : DbSet<T>, IQueryable, IEnumerable<T>
    where T : class
{
    // Optional key matcher used by Find: returns true when the entity matches the supplied key values.
    private readonly Func<T, object[], bool> _find;

    // The backing store for all entities in this fake set.
    private readonly ObservableCollection<T> _items;

    // LINQ-to-Objects view over _items; supplies the query provider/expression/element type.
    private readonly IQueryable<T> _query;

    /// <summary>Creates an empty fake set.</summary>
    /// <param name="find">
    /// Predicate used by <c>Find</c> to decide whether an entity matches the key values.
    /// When null, <c>Find</c> throws <see cref="NotSupportedException"/>.
    /// </param>
    public FakeDbSet(Func<T, object[], bool> find = null)
    {
        _find = find;
        _items = new ObservableCollection<T>();
        _query = _items.AsQueryable();
    }

    /// <summary>Gets the in-memory collection holding every entity in the set.</summary>
    public override ObservableCollection<T> Local { get { return _items; } }

    Type IQueryable.ElementType { get { return _query.ElementType; } }

    Expression IQueryable.Expression { get { return _query.Expression; } }

    IQueryProvider IQueryable.Provider { get { return _query.Provider; } }

    IEnumerator<T> IEnumerable<T>.GetEnumerator()
    {
        return _items.GetEnumerator();
    }

    System.Collections.IEnumerator System.Collections.IEnumerable.GetEnumerator()
    {
        return _items.GetEnumerator();
    }

    /// <summary>Adds each entity to the in-memory collection.</summary>
    /// <returns>The same <paramref name="entities"/> sequence that was passed in.</returns>
    public override IEnumerable<T> AddRange(IEnumerable<T> entities)
    {
        foreach (var entity in entities)
        {
            _items.Add(entity);
        }
        return entities;
    }

    /// <summary>Adds the entity to the in-memory collection and returns it.</summary>
    public override T Add(T item)
    {
        _items.Add(item);
        return item;
    }

    /// <summary>Removes each entity from the in-memory collection (entities not present are ignored).</summary>
    /// <returns>The same <paramref name="entities"/> sequence that was passed in.</returns>
    public override IEnumerable<T> RemoveRange(IEnumerable<T> entities)
    {
        foreach (var entity in entities)
        {
            _items.Remove(entity);
        }
        return entities;
    }

    /// <summary>Removes the entity from the in-memory collection (no-op if absent) and returns it.</summary>
    public override T Remove(T item)
    {
        _items.Remove(item);
        return item;
    }

    /// <summary>Adds the entity to the in-memory collection and returns it; there is no change tracking here.</summary>
    public override T Attach(T item)
    {
        _items.Add(item);
        return item;
    }

    /// <summary>Creates a new, unattached instance of <typeparamref name="T"/> via <see cref="Activator"/>.</summary>
    public override T Create()
    {
        return Activator.CreateInstance<T>();
    }

    /// <summary>Creates a new, unattached instance of the derived entity type via <see cref="Activator"/>.</summary>
    public override TDerivedEntity Create<TDerivedEntity>()
    {
        return Activator.CreateInstance<TDerivedEntity>();
    }

    /// <summary>Finds the single entity for which the constructor's find predicate returns true.</summary>
    /// <param name="keyValues">Key values forwarded to the find predicate.</param>
    /// <returns>The matching entity, or null when none matches.</returns>
    /// <exception cref="NotSupportedException">No find predicate was supplied to the constructor.</exception>
    /// <exception cref="InvalidOperationException">More than one entity matches (from <c>SingleOrDefault</c>).</exception>
    public override T Find(params object[] keyValues)
    {
        if (_find == null)
        {
            throw new NotSupportedException(
                "FakeDbSet<T>.Find requires a find predicate to be passed to the constructor.");
        }
        return _items.SingleOrDefault(e => _find(e, keyValues));
    }
}
`Find` is slightly awkward to implement: you can use reflection, or you can simply pass in a lambda, as I have done here.

The fake customer context is now trivial to create:
/// <summary>
/// In-memory <see cref="ICustomerContext"/> for unit tests, backed by a <c>FakeDbSet</c>.
/// </summary>
public class FakeCustomerContext : ICustomerContext
{
    // Find matches on Customer.Id. This assumes the first key value is a boxed int;
    // any other key type makes the cast throw InvalidCastException.
    private readonly DbSet<Customer> _customers =
        new FakeDbSet<Customer>((customer, keyValues) => customer.Id == (int)keyValues[0]);

    /// <summary>Gets the fake customer set, which starts out empty.</summary>
    public DbSet<Customer> Customers { get { return _customers; } }
}
Now we can write the tests themselves:
/// <summary>
/// Verifies that <c>CustomerController.Index</c> passes the customer with the requested id to the view.
/// </summary>
[TestMethod]
public void ViewReturnsCorrectCustomer()
{
    // Arrange: a fake context seeded with a single known customer.
    var fakeContext = new FakeCustomerContext();
    var seeded = new Customer { Id = 42, Name = "Test Customer" };
    fakeContext.Customers.Add(seeded);
    var sut = new CustomerController(fakeContext);

    // Act
    var view = sut.Index(42);

    // Assert: the view model is the seeded customer.
    var model = (Customer)view.Model;
    Assert.AreEqual("Test Customer", model.Name);
}
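This approach is somewhat limited, but it works reasonably well. The main limitation is that queries against the fake run through LINQ to Objects rather than being translated to SQL, so a query that would fail or behave differently against a real database can still pass.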