Java Streams for Python Programmers
With Java 8, streams were added to the language. This post describes how typical Python list comprehensions can be implemented in Java using streams.
Creating and Working With Streams
Let’s say in Python we have a list l.
>>> l = [1, 5, 1992]
If we wanted to create a list that contains the squares of all the values in l, we would write a list comprehension.
>>> [x**2 for x in l]
[1, 25, 3968064]
We can implement this almost as concisely in Java using streams, which live in the java.util.stream package. First we need to convert the data at hand into a Stream.
List<Integer> l = Arrays.asList(1, 5, 1992);
Stream<Integer> stream = l.stream();
Now that we have our stream, we want to square each number. This is a map operation.
Stream<Integer> squares = stream.map(x -> x * x);
The map method actually returns a new stream: a stream that takes the individual values x from the original source and applies the function x * x to each of them. But the stream only describes this computation; to actually run it, the stream needs to be consumed by a terminal operation. This is similar to how generators work in Python.
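To see this laziness in action, here is a small sketch that counts how often the mapping function runs. The class LazyDemo, the helper square and the counter are hypothetical names chosen for illustration; the point is only that nothing is computed until the terminal collect call.

```java
import java.util.Arrays;
import java.util.List;
import java.util.concurrent.atomic.AtomicInteger;
import java.util.stream.Collectors;
import java.util.stream.Stream;

public class LazyDemo {
    // Counts how many times the mapping function has actually been executed.
    static final AtomicInteger calls = new AtomicInteger();

    static int square(int x) {
        calls.incrementAndGet(); // record that the function really ran
        return x * x;
    }

    // Returns {calls seen right after map, calls seen after collect}.
    static int[] observe(List<Integer> l) {
        calls.set(0);
        // Building the pipeline only describes the computation...
        Stream<Integer> squares = l.stream().map(LazyDemo::square);
        int afterMap = calls.get();
        // ...the terminal operation is what actually pulls the values through.
        squares.collect(Collectors.toList());
        int afterCollect = calls.get();
        return new int[] {afterMap, afterCollect};
    }

    public static void main(String[] args) {
        int[] seen = observe(Arrays.asList(1, 5, 1992));
        // prints "after map: 0, after collect: 3"
        System.out.println("after map: " + seen[0] + ", after collect: " + seen[1]);
    }
}
```

Just like a Python generator expression, the pipeline does no work until something asks for its values.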
We can consume the generated stream piece by piece using the forEach method.
squares.forEach(System.out::println);
In this case, System.out.println is called on every element of the stream. The output is as expected.
1
25
3968064
Instead of using forEach, you can also use the iterator method to get, in this case, an Iterator<Integer>.
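A minimal sketch of the iterator route, roughly comparable to calling next() on a Python generator. The class IteratorDemo and the method squaresViaIterator are illustrative names, not part of any library.

```java
import java.util.ArrayList;
import java.util.Arrays;
import java.util.Iterator;
import java.util.List;

public class IteratorDemo {
    // Pulls the squares out of the stream one element at a time.
    static List<Integer> squaresViaIterator(List<Integer> l) {
        Iterator<Integer> it = l.stream().map(x -> x * x).iterator();
        List<Integer> out = new ArrayList<>();
        while (it.hasNext()) {
            out.add(it.next()); // each next() pulls one element through map
        }
        return out;
    }

    public static void main(String[] args) {
        // prints [1, 25, 3968064]
        System.out.println(squaresViaIterator(Arrays.asList(1, 5, 1992)));
    }
}
```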
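But we wanted something comparable to list comprehensions, so let’s convert the stream to a list. Keep in mind that a stream can be consumed only once: if squares has already been consumed with forEach, we have to create it again from l before collecting.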
List<Integer> res = squares.collect(Collectors.toList());
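Because a stream is single-use, running a second terminal operation on the same stream object fails at runtime with an IllegalStateException. The sketch below (class and method names are hypothetical) demonstrates this.

```java
import java.util.Arrays;
import java.util.stream.Collectors;
import java.util.stream.Stream;

public class SingleUseDemo {
    // Returns true if reusing an already consumed stream throws IllegalStateException.
    static boolean secondUseFails() {
        Stream<Integer> squares = Arrays.asList(1, 5, 1992).stream().map(x -> x * x);
        squares.forEach(System.out::println); // first terminal operation: fine
        try {
            squares.collect(Collectors.toList()); // second terminal operation on the same stream
            return false;
        } catch (IllegalStateException e) {
            // message: "stream has already been operated upon or closed"
            return true;
        }
    }

    public static void main(String[] args) {
        // prints the three squares, then true
        System.out.println(secondUseFails());
    }
}
```

If you need the values twice, either call l.stream() again or collect into a list once and reuse the list.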
In real code we wouldn’t create all these variables, of course. A practical version would look something like this:
List<Integer> l = Arrays.asList(1, 5, 1992);
List<Integer> res = l.stream()
.map(x -> x * x)
.collect(Collectors.toList());
Printing res, we get what we originally asked for.
[1, 25, 3968064]
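One difference from the Python version is worth knowing: Python integers never overflow, but x * x on Integer wraps around silently once the result exceeds Integer.MAX_VALUE (2147483647). The sketch below, with hypothetical names SafeSquares, intSquares and longSquares, shows the wrap-around and one way to avoid it by widening to long before multiplying.

```java
import java.util.Arrays;
import java.util.List;
import java.util.stream.Collectors;

public class SafeSquares {
    // int arithmetic: silently wraps around on overflow
    static List<Integer> intSquares(List<Integer> l) {
        return l.stream()
                .map(x -> x * x)
                .collect(Collectors.toList());
    }

    // widen each value to long first; the square of any int fits in a long
    static List<Long> longSquares(List<Integer> l) {
        return l.stream()
                .map(x -> (long) x * x)
                .collect(Collectors.toList());
    }

    public static void main(String[] args) {
        List<Integer> l = Arrays.asList(1992, 100000);
        System.out.println(intSquares(l));  // [3968064, 1410065408]
        System.out.println(longSquares(l)); // [3968064, 10000000000]
    }
}
```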
Filter
Now that we know the basics of streams, it’s time to pick up the pace. Given a list l in Python, if we want only those values that are, say, greater than 100, we can write the following.
>>> l = [1, 5, 15, 515, 15515]
>>> [x for x in l if x > 100]
[515, 15515]
This is called a filter operation because we filter out all those values that do not fit our requirement. It’s easy with Java streams.
List<Integer> l = Arrays.asList(1, 5, 15, 515, 15515);
List<Integer> results = l.stream()
.filter(x -> x > 100)
.collect(Collectors.toList());
Printing results gives us the expected output.
[515, 15515]
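Python lets a comprehension carry several if clauses, as in [x for x in l if x > 10 if x < 1000]. A minimal sketch of the Java counterpart chains one filter per condition; the class FilterDemo and method between are illustrative names.

```java
import java.util.Arrays;
import java.util.List;
import java.util.stream.Collectors;

public class FilterDemo {
    // Like [x for x in l if x > 10 if x < 1000]: each filter keeps only matching elements
    static List<Integer> between(List<Integer> l) {
        return l.stream()
                .filter(x -> x > 10)
                .filter(x -> x < 1000)
                .collect(Collectors.toList());
    }

    public static void main(String[] args) {
        List<Integer> l = Arrays.asList(1, 5, 15, 515, 15515);
        System.out.println(between(l)); // prints [15, 515]
    }
}
```

A single filter(x -> x > 10 && x < 1000) gives the same result; chaining is mostly a matter of readability.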
Now what if we want to square only the numbers that are greater than 100? In both Python and Java this is now easy. In Python you would write a list comprehension.
[x**2 for x in l if x > 100]
In Java we would first run filter and then map on the original stream.
l.stream()
.filter(x -> x > 100)
.map(x -> x * x)
.collect(Collectors.toList());
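The order of filter and map matters, since each step only sees the output of the previous one. This sketch (class and method names are hypothetical) contrasts filtering the original values with filtering the squares.

```java
import java.util.Arrays;
import java.util.List;
import java.util.stream.Collectors;

public class OrderDemo {
    // [x**2 for x in l if x > 100]: test the original values, then square them
    static List<Integer> filterThenMap(List<Integer> l) {
        return l.stream()
                .filter(x -> x > 100)
                .map(x -> x * x)
                .collect(Collectors.toList());
    }

    // [y for y in (x**2 for x in l) if y > 100]: square first, then test the squares
    static List<Integer> mapThenFilter(List<Integer> l) {
        return l.stream()
                .map(x -> x * x)
                .filter(y -> y > 100)
                .collect(Collectors.toList());
    }

    public static void main(String[] args) {
        List<Integer> l = Arrays.asList(1, 5, 15, 515, 15515);
        System.out.println(filterThenMap(l)); // [265225, 240715225]
        System.out.println(mapThenFilter(l)); // [225, 265225, 240715225]
    }
}
```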
One, Two, Three
There is one gotcha about streams related to the difference between primitive data types and their boxed counterparts. Let’s explore it by example.
In Python it is easy to create a list of integers in a range.
>>> [x for x in range(0, 10)]
[0, 1, 2, 3, 4, 5, 6, 7, 8, 9]
To do the same in Java, we would write something like this.
List<Integer> ls = IntStream.range(0, 10)
.boxed()
.collect(Collectors.toList());
Printing ls, we get the expected output.
[0, 1, 2, 3, 4, 5, 6, 7, 8, 9]
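Like Python’s range, IntStream.range excludes its end bound, while IntStream.rangeClosed includes it. IntStream also has its own filter and map, so a comprehension over a range can stay primitive until the final boxed call. A sketch with hypothetical method names:

```java
import java.util.List;
import java.util.stream.Collectors;
import java.util.stream.IntStream;

public class RangeDemo {
    // Like list(range(1, 11)): rangeClosed includes the end bound
    static List<Integer> oneToTen() {
        return IntStream.rangeClosed(1, 10)
                .boxed()
                .collect(Collectors.toList());
    }

    // Like [x * x for x in range(0, 10) if x % 2 == 0]
    static List<Integer> evenSquares() {
        return IntStream.range(0, 10)
                .filter(x -> x % 2 == 0) // IntPredicate on int values, no boxing yet
                .map(x -> x * x)         // IntUnaryOperator, still an IntStream
                .boxed()                 // only now convert to Stream<Integer>
                .collect(Collectors.toList());
    }

    public static void main(String[] args) {
        System.out.println(oneToTen());    // [1, 2, 3, 4, 5, 6, 7, 8, 9, 10]
        System.out.println(evenSquares()); // [0, 4, 16, 36, 64]
    }
}
```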
Note the call to the boxed method. It converts the IntStream, a stream of int values, to a Stream<Integer>, a stream of Integer values. This is required because there can be no List<int> in Java, only a List<Integer>. Similarly, there is DoubleStream, which serves the same purpose for floating-point numbers (and LongStream for long values).
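A short sketch of the same idea for DoubleStream, plus mapToObj, which boxes and converts in a single step. The class PrimitiveStreams and its methods are illustrative names.

```java
import java.util.List;
import java.util.stream.Collectors;
import java.util.stream.DoubleStream;
import java.util.stream.IntStream;

public class PrimitiveStreams {
    // boxed() turns a DoubleStream into a Stream<Double>, which can be collected into a List<Double>
    static List<Double> boxedDoubles() {
        return DoubleStream.of(1.5, 2.5, 3.0)
                .boxed()
                .collect(Collectors.toList());
    }

    // DoubleStream offers sum() directly, without boxing
    static double doubleSum() {
        return DoubleStream.of(1.5, 2.5, 3.0).sum();
    }

    // mapToObj leaves the primitive world and maps each int to an object, here a String
    static List<String> labels() {
        return IntStream.range(0, 3)
                .mapToObj(i -> "item" + i)
                .collect(Collectors.toList());
    }

    public static void main(String[] args) {
        System.out.println(boxedDoubles()); // [1.5, 2.5, 3.0]
        System.out.println(doubleSum());    // 7.0
        System.out.println(labels());       // [item0, item1, item2]
    }
}
```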
We can also move the other way, that is, from Stream<Integer> to IntStream, and indeed from any Stream<MyCoolType> to IntStream, using the mapToInt method of Stream. We might want to do this because IntStream has handy methods like sum, which Stream lacks. Consider the following code, in which we sum up the prices of multiple products in a shopping cart. In Python we would write something like this.
from collections import namedtuple
from typing import List
Product = namedtuple('Product', ('name', 'price'))
def totalPrice(cart: List[Product]):
    return sum(p.price for p in cart)
We can write similar code in Java now.
class Product {
    String name;
    int price;
}
int totalPrice(List<Product> cart) {
    return cart.stream()
            .mapToInt(p -> p.price)
            .sum();
}
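Here is a self-contained, runnable version of the cart example. It assumes a constructor and final fields on Product that the snippet above leaves out, and the cart contents are made-up sample data. It also shows mapToInt(Integer::intValue) for the plain Stream<Integer> case.

```java
import java.util.Arrays;
import java.util.List;
import java.util.stream.Stream;

public class Cart {
    // Assumed shape of Product: a constructor and final fields, added for this sketch
    static class Product {
        final String name;
        final int price;

        Product(String name, int price) {
            this.name = name;
            this.price = price;
        }
    }

    // mapToInt turns Stream<Product> into an IntStream, which offers sum()
    static int totalPrice(List<Product> cart) {
        return cart.stream()
                .mapToInt(p -> p.price)
                .sum();
    }

    // The same trick for Stream<Integer>: unbox each element with Integer::intValue
    static int sumOf(Stream<Integer> numbers) {
        return numbers.mapToInt(Integer::intValue).sum();
    }

    public static void main(String[] args) {
        List<Product> cart = Arrays.asList(   // hypothetical sample data
                new Product("apple", 3),
                new Product("bread", 5),
                new Product("milk", 2));
        System.out.println(totalPrice(cart));              // prints 10
        System.out.println(sumOf(Stream.of(1, 5, 1992))); // prints 1998
    }
}
```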
Summary
We can now translate Python list comprehensions to Java using streams. I didn’t show a comparison here, but this usually leads to shorter and cleaner code than traditional loops.
I do have to mention that streams are far more powerful than what was shown here. Interesting features include advanced reduction (think sum on steroids) and parallel execution. If you are interested, consider reading this Java 8 Stream Tutorial or, if you are feeling courageous, start with the official documentation.
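For a small taste of both features, here is a sketch using Stream.reduce and IntStream.parallel. The method names are illustrative. Per the reduce documentation, the identity value must really be an identity for the operation and the operation must be associative, which is what makes the parallel split safe.

```java
import java.util.Arrays;
import java.util.List;
import java.util.stream.IntStream;

public class ReduceDemo {
    // reduce generalizes sum: start from an identity value and fold in each element
    static int product(List<Integer> l) {
        return l.stream().reduce(1, (a, b) -> a * b);
    }

    // sum written as an explicit reduction
    static int sum(List<Integer> l) {
        return l.stream().reduce(0, Integer::sum);
    }

    // parallel() lets the runtime split the range across threads; the result is unchanged
    static int parallelSum(int n) {
        return IntStream.rangeClosed(1, n).parallel().sum();
    }

    public static void main(String[] args) {
        System.out.println(product(Arrays.asList(1, 5, 15))); // 75
        System.out.println(sum(Arrays.asList(1, 5, 1992)));   // 1998
        System.out.println(parallelSum(100));                 // 5050
    }
}
```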