andreas.schaertl

Java Streams for Python Programmers

Java 8 added streams to the standard library. This post shows how typical Python list comprehensions can be implemented in Java using streams.

Creating and Working With Streams

Let’s say in Python we have a list l.

>>> l = [1, 5, 1992]

If we wanted to create a list that contains all the squares of the values in l, we would write a list comprehension.

>>> [x**2 for x in l]
[1, 25, 3968064]

We can implement this almost as concisely in Java using streams, which live in the java.util.stream package. First we need to convert the data at hand into a Stream.

List<Integer> l = Arrays.asList(1, 5, 1992);
Stream<Integer> stream = l.stream();

Now that we have our stream, we want to square each number. This is a map operation.

Stream<Integer> squares = stream.map(x -> x * x);

The map method returns a new stream: one that takes each value x from the original source and applies the function x -> x * x to it. But this stream only describes the computation. To actually run it, the stream needs to be consumed. This is similar to how generators work in Python.
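To make the laziness visible, here is a minimal sketch that records when the mapping function actually runs. The class name LazyStreamDemo and the log list are made up for this illustration. They are not part of the stream API.

```java
import java.util.ArrayList;
import java.util.Arrays;
import java.util.List;
import java.util.stream.Collectors;
import java.util.stream.Stream;

// Hypothetical demo class; only the stream calls are real API.
class LazyStreamDemo {
    // Records every step so we can see the order in which things happen.
    static final List<String> log = new ArrayList<>();

    static List<Integer> run() {
        log.clear();
        List<Integer> l = Arrays.asList(1, 5, 1992);

        // Building the pipeline does not square anything yet.
        Stream<Integer> squares = l.stream().map(x -> {
            log.add("squaring " + x);
            return x * x;
        });
        log.add("pipeline built");

        // collect is a terminal operation; only now is the lambda invoked.
        return squares.collect(Collectors.toList());
    }

    public static void main(String[] args) {
        System.out.println(run()); // [1, 25, 3968064]
        System.out.println(log);   // [pipeline built, squaring 1, squaring 5, squaring 1992]
    }
}
```

Note that "pipeline built" is logged before any "squaring" entry, much like a Python generator expression does no work until something iterates over it.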

We can consume the generated stream piece by piece using the forEach method.

squares.forEach(System.out::println);

In this case, System.out.println is called on every element of the stream. The output is as expected.

1
25
3968064

Instead of using forEach, you can also use the iterator method to get, in this case, an Iterator<Integer>.
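As a rough sketch of how the iterator method might be used, the following pulls squares out of the stream one at a time and stops at the first one above a threshold, a bit like calling next on a Python generator. The class IteratorDemo and the method firstSquareAbove are hypothetical names chosen for this example.

```java
import java.util.Arrays;
import java.util.Iterator;
import java.util.List;

// Hypothetical demo class; Stream.iterator() is the real API being shown.
class IteratorDemo {
    // Returns the first square greater than threshold, or -1 if there is none.
    static int firstSquareAbove(List<Integer> values, int threshold) {
        Iterator<Integer> it = values.stream()
            .map(x -> x * x)
            .iterator();
        while (it.hasNext()) {
            int square = it.next();
            if (square > threshold) {
                return square;
            }
        }
        return -1;
    }

    public static void main(String[] args) {
        List<Integer> l = Arrays.asList(1, 5, 1992);
        System.out.println(firstSquareAbove(l, 10)); // 25
    }
}
```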

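But we wanted something comparable to list comprehensions. So instead of calling forEach, let’s convert the stream to a list. Keep in mind that a stream can only be consumed once: calling collect on a stream that forEach has already consumed throws an IllegalStateException.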

List<Integer> res = squares.collect(Collectors.toList());

In real code we wouldn’t create all these variables, of course. A practical version would look something along these lines:

List<Integer> l = Arrays.asList(1, 5, 1992);
List<Integer> res = l.stream()
	.map(x -> x * x)
	.collect(Collectors.toList());

Printing res, we get what we originally asked for.

[1, 25, 3968064]

Filter

Now that we have gotten to know streams, it’s time to pick up the pace. Given a list l in Python, if we want only those values that are, say, greater than 100, we can write the following.

>>> l = [1, 5, 15, 515, 15515]
>>> [x for x in l if x > 100]
[515, 15515]

This is called a filter operation because we filter out all those values which do not fit our requirement. It’s easy with Java streams.

List<Integer> l = Arrays.asList(1, 5, 15, 515, 15515);
List<Integer> results = l.stream()
	.filter(x -> x > 100)
	.collect(Collectors.toList());

Printing results gives us the expected output.

[515, 15515]

Now what if we want to square only the numbers which are greater than 100? Both in Python and Java this is now easy. In Python you would write a list comprehension.

[x**2 for x in l if x > 100]

In Java we would first run filter and then map on the original stream.

l.stream()
	.filter(x -> x > 100)
	.map(x -> x * x)
	.collect(Collectors.toList());
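For completeness, here is a self-contained sketch of the filter-then-map pipeline with the imports it needs. The class name FilterMapDemo and the method squaresAbove are assumptions made for this example.

```java
import java.util.Arrays;
import java.util.List;
import java.util.stream.Collectors;

// Hypothetical demo class wrapping the pipeline from the text.
class FilterMapDemo {
    // Python equivalent: [x**2 for x in values if x > limit]
    static List<Integer> squaresAbove(List<Integer> values, int limit) {
        return values.stream()
            .filter(x -> x > limit) // keep only the values above the limit
            .map(x -> x * x)        // then square the survivors
            .collect(Collectors.toList());
    }

    public static void main(String[] args) {
        List<Integer> l = Arrays.asList(1, 5, 15, 515, 15515);
        System.out.println(squaresAbove(l, 100)); // [265225, 240715225]
    }
}
```

One difference from Python worth remembering: Java's int is a 32-bit type, so x * x silently overflows for large inputs, whereas Python integers grow as needed. The values used here still fit.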

One, Two, Three

There is one gotcha about streams related to the difference between primitive data types and their boxed counterparts. Let’s explore it by example.

In Python it is easy to create a list of integers in a range.

>>> [x for x in range(0, 10)]
[0, 1, 2, 3, 4, 5, 6, 7, 8, 9]

To do the same in Java, we would write something like this.

List<Integer> ls = IntStream.range(0, 10)
	.boxed()
	.collect(Collectors.toList());

Printing ls, we get the expected output.

[0, 1, 2, 3, 4, 5, 6, 7, 8, 9]

Note the call to the boxed method. It converts the IntStream, a stream of int values, to a Stream<Integer>, a stream of Integer values. This is required because generic type parameters cannot be primitive types: there can be no List<int> in Java, only a List<Integer>. Similarly, there is DoubleStream, which serves the same purpose for floating-point numbers.
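The same boxing step applies to DoubleStream. The sketch below, with the made-up class BoxingDemo and method halves, maps an IntStream to a DoubleStream and boxes it to obtain a List<Double>.

```java
import java.util.List;
import java.util.stream.Collectors;
import java.util.stream.IntStream;

// Hypothetical demo class; mapToDouble and boxed are real stream methods.
class BoxingDemo {
    // Python equivalent: [x / 2 for x in range(0, n)]
    static List<Double> halves(int n) {
        return IntStream.range(0, n)
            .mapToDouble(x -> x / 2.0) // IntStream -> DoubleStream
            .boxed()                   // DoubleStream -> Stream<Double>
            .collect(Collectors.toList());
    }

    public static void main(String[] args) {
        System.out.println(halves(4)); // [0.0, 0.5, 1.0, 1.5]
    }
}
```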

We can also go the other way, from Stream<Integer> to IntStream, or indeed from any Stream<MyCoolType> to IntStream, using the mapToInt method of Stream. We might want to do this because IntStream has useful methods like sum, which Stream lacks. Consider the following code, in which we sum up the prices of multiple products in a shopping cart. In Python we would write something like this.

from collections import namedtuple
from typing import List

Product = namedtuple('Product', ('name', 'price'))

def totalPrice(cart: List[Product]):
	return sum(p.price for p in cart)

We can write similar code in Java now.

class Product {
	String name;
	int price;
}

int totalPrice(List<Product> cart) {
	return cart.stream()
		.mapToInt(p -> p.price)
		.sum();
}
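To try this out, a runnable variant might look like the sketch below. The wrapper class CartDemo, the constructor on Product, and the sample products and prices are all assumptions added for illustration.

```java
import java.util.Arrays;
import java.util.Collections;
import java.util.List;

// Hypothetical wrapper class so the example runs on its own.
class CartDemo {
    static class Product {
        final String name;
        final int price;

        // Constructor added for this example; the original class has none.
        Product(String name, int price) {
            this.name = name;
            this.price = price;
        }
    }

    static int totalPrice(List<Product> cart) {
        return cart.stream()
            .mapToInt(p -> p.price) // Stream<Product> -> IntStream
            .sum();                 // sum of an empty IntStream is 0
    }

    public static void main(String[] args) {
        List<Product> cart = Arrays.asList(
            new Product("coffee", 499),
            new Product("bagel", 250));
        System.out.println(totalPrice(cart));                                // 749
        System.out.println(totalPrice(Collections.<Product>emptyList()));   // 0
    }
}
```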

Summary

We can now translate Python list comprehensions to Java using streams. I didn’t show it here, but this often leads to shorter and cleaner code than traditional loops.

I do have to mention that streams are way more powerful than what was shown here. Interesting features include advanced reduction (think sum on steroids) and parallel execution. If you are interested, consider reading this Java 8 Stream Tutorial or, if you are feeling courageous, start with the official documentation.
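As a small taste of those two features, here is a hedged sketch using reduce and parallel. The class ReduceDemo and its method names are invented for this example, and it only scratches the surface of what the API offers.

```java
import java.util.Arrays;
import java.util.List;
import java.util.stream.LongStream;

// Hypothetical demo class; reduce and parallel are real stream methods.
class ReduceDemo {
    // reduce folds the stream into one value, starting from the identity 1.
    static int product(List<Integer> values) {
        return values.stream().reduce(1, (a, b) -> a * b);
    }

    // parallel() asks the runtime to split the work across threads.
    // For a range this small it is unlikely to be faster; it only shows the call.
    static long parallelSumOfSquares(int n) {
        return LongStream.range(0, n)
            .parallel()
            .map(x -> x * x)
            .sum();
    }

    public static void main(String[] args) {
        System.out.println(product(Arrays.asList(1, 5, 1992))); // 9960
        System.out.println(parallelSumOfSquares(10));           // 285
    }
}
```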