Fold in Java

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67

/**
 * A linked list encoded with the visitor pattern: instead of exposing its
 * structure, each implementation (Nil, Cons in this file) dispatches to the
 * matching case of a {@link ListVisitor}.
 *
 * @param <T> element type
 */
interface List<T> {
  /**
   * Hands this list to {@code visitor}, which picks the case for its shape.
   *
   * @param visitor handler for the empty and non-empty cases
   * @param <Out> type produced by the visitor
   * @return the value returned by the selected visitor case
   */
  <Out> Out accept(ListVisitor<T, Out> visitor);
}

/**
 * Case analysis over a {@link List}: one callback per list shape.
 *
 * @param <T> element type of the visited list
 * @param <R> result type of the visit
 */
interface ListVisitor<T,R> {
  /**
   * Handles a non-empty list.
   *
   * @param first the first element of the list
   * @param rest the list that follows {@code first}
   * @return the visit result for this cell
   */
  R visitCons(T first, List<T> rest);

  /**
   * Handles the empty list.
   *
   * @return the visit result for the empty list
   */
  R visitNil();
}

/**
 * The empty list.
 *
 * @param <T> element type
 */
class Nil<T> implements List<T> {
  /** Always dispatches to {@link ListVisitor#visitNil()}. */
  @Override
  public <R> R accept(ListVisitor<T, R> visitor) {
    R outcome = visitor.visitNil();
    return outcome;
  }
}

/**
 * A non-empty list cell: one element followed by the rest of the list.
 *
 * @param <T> element type
 */
class Cons<T> implements List<T> {
  /** The first element of this list. */
  public final T head;
  /** The remaining list after {@link #head}; not checked for null. */
  public final List<T> tail;

  /**
   * Builds a cell.
   *
   * @param first value stored as {@link #head}
   * @param rest list stored as {@link #tail}
   */
  public Cons(T first, List<T> rest) {
    head = first;
    tail = rest;
  }

  /** Dispatches to {@link ListVisitor#visitCons}, passing this cell's head and tail. */
  @Override
  public <R> R accept(ListVisitor<T, R> visitor) {
    final T element = this.head;
    final List<T> remainder = this.tail;
    return visitor.visitCons(element, remainder);
  }
}

/**
 * A one-argument function. Multi-argument functions are expressed in curried
 * form by returning another {@code Function}.
 *
 * @param <In> argument type
 * @param <Out> result type
 */
interface Function<In, Out> {
  /**
   * Applies this function to {@code input}.
   *
   * @param input the argument
   * @return the result
   */
  Out apply(In input);
}

/**
 * Fold utilities over the visitor-encoded {@link List}, plus a small demo.
 */
class Fold {
  /**
   * Right-associative fold: for a list {@code [x1, x2, ..., xn]} computes
   * {@code f(x1)(f(x2)( ... f(xn)(init) ... ))}.
   *
   * <p>The traversal is iterative, so the Java call stack no longer grows with
   * the length of the list. The earlier recursive form used several frames per
   * element and could throw {@link StackOverflowError} on long lists.
   *
   * <p>The order of calls into {@code f} matches the recursive form:
   * <ol>
   *   <li>{@code f.apply(head)} is invoked for every element, front to back.</li>
   *   <li>The resulting functions are then applied back to front, starting from
   *       {@code init}.</li>
   * </ol>
   *
   * @param f curried combiner; {@code f.apply(x).apply(acc)} combines element
   *     {@code x} with the fold of everything to its right
   * @param init starting value; returned unchanged for an empty list
   * @param list the list to fold; a {@code null} list, or a {@code null} tail
   *     inside it, raises {@link NullPointerException}
   * @param <A> element type
   * @param <B> result type
   * @return the folded value
   */
  public static <A,B> B foldr(final Function<A,Function<B,B>> f, final B init, List<A> list) {
    // Phase 1: walk front to back, recording f.apply(head) for each cell.
    Unroller<A, B> unroller = new Unroller<A, B>(f);
    List<A> cursor = list;
    while (!unroller.reachedNil) {
      // A null cursor (null list or null tail) throws NPE here, as the recursive form did.
      cursor = cursor.accept(unroller);
    }
    // Phase 2: combine from the last element back to the first.
    B acc = init;
    java.util.ArrayList<Function<B, B>> pending = unroller.pending;
    for (int i = pending.size() - 1; i >= 0; i--) {
      acc = pending.get(i).apply(acc);
    }
    return acc;
  }

  /**
   * Visitor that records {@code f.apply(head)} for each cell it sees and
   * returns the cell's tail so the caller can keep walking.
   */
  private static final class Unroller<A, B> implements ListVisitor<A, List<A>> {
    private final Function<A, Function<B, B>> f;
    /**
     * Partially applied combiners in list order. An ArrayList is used because
     * it accepts null, so a null from f.apply fails later, at apply time.
     */
    final java.util.ArrayList<Function<B, B>> pending = new java.util.ArrayList<Function<B, B>>();
    /** Set once the terminating empty list has been visited. */
    boolean reachedNil;

    Unroller(Function<A, Function<B, B>> f) {
      this.f = f;
    }

    @Override
    public List<A> visitNil() {
      reachedNil = true;
      return null; // never dereferenced: the walk loop stops on reachedNil
    }

    @Override
    public List<A> visitCons(A head, List<A> tail) {
      pending.add(f.apply(head));
      return tail;
    }
  }

  /** Curried integer addition: {@code plus.apply(a).apply(b)} yields {@code a + b}. */
  static Function<Integer,Function<Integer,Integer>> plus = new Function<Integer, Function<Integer,Integer>>() {
    @Override
    public Function<Integer, Integer> apply(final Integer a) {
      return new Function<Integer, Integer>() {
        @Override
        public Integer apply(Integer b) {
          return a + b;
        }
      };
    }
  };

  /** Sample list {@code [8, 3, 5]}. */
  static List<Integer> list0 = new Cons<Integer>(8, new Cons<Integer>(3, new Cons<Integer>(5, new Nil<Integer>())));

  /** Prints the sum of {@link #list0} using {@link #foldr}, with no trailing newline. */
  public static void main(String[] args) {
    System.out.print(foldr(plus, 0, list0));
  }
}