Submission #1133238


Source Code Expand

import java.io.OutputStream;
import java.io.IOException;
import java.io.InputStream;
import java.io.OutputStream;
import java.io.PrintWriter;
import java.util.Arrays;
import java.util.Iterator;
import java.io.BufferedWriter;
import java.util.InputMismatchException;
import java.io.IOException;
import java.io.Writer;
import java.io.OutputStreamWriter;
import java.util.NoSuchElementException;
import java.io.InputStream;

/**
 * Built using CHelper plug-in
 * Actual solution is at the top
 *
 * @author Egor Kulikov (egor@egork.net)
 */
public class Main {
    /**
     * Program entry point: connects standard input and standard output to the solver and runs it
     * once with test number 1.
     *
     * @param args command-line arguments; not used
     */
    public static void main(String[] args) {
        InputReader in = new InputReader(System.in);
        OutputWriter out = new OutputWriter(System.out);
        new TaskD().solve(1, in, out);
        // The writer is buffered, so it has to be closed for the answer to reach standard output.
        out.close();
    }

    /**
     * Solver for task D: reads a tree, runs a bottom-up table DP over it and prints the sum of all
     * entries of the root's table modulo {@code MiscUtils.MOD7}.
     *
     * <p>Every DP table is square and indexed {@code [i][j]} with both indices bounded by
     * {@code limit}. NOTE(review): the two indices presumably measure path lengths for the two
     * possible edge orientations - confirm against the problem statement.
     *
     * <p>The {@code long[][]} fields are scratch buffers that are allocated once in
     * {@link #solve(int, InputReader, OutputWriter)} and reused by every recursive call.
     */
    static class TaskD {
        // Adjacency structure of the input tree (edges stored in both directions).
        Graph graph;
        // (diameter + 1) / 2; upper bound for both table indices.
        int limit;
        // Table of the child processed most recently, after addEdge() shifted it by one edge.
        long[][] callRes;
        // Largest index of callRes (and, inside join(), of res) that currently holds valid data.
        int callSize;
        // Prefix sums of callRes, computed by fill(): along a row, along a column, and 2D.
        long[][] rowCallRes;
        long[][] columnCallRes;
        long[][] sumCallRes;
        // Copy of the table accumulated so far for the current vertex, used inside join().
        long[][] res;
        // Prefix sums of res, computed by fill(): along a row, along a column, and 2D.
        long[][] rowRes;
        long[][] columnRes;
        long[][] sumRes;

        /**
         * Reads the tree, runs the DP from vertex 0 and prints the answer.
         *
         * <p>Input, as consumed here: the vertex count {@code n}, followed by {@code n - 1} pairs of
         * 1-based vertex numbers, one pair per edge.
         *
         * @param testNumber not used
         * @param in         source of the input
         * @param out        receives a single line with the answer modulo {@code MiscUtils.MOD7}
         */
        public void solve(int testNumber, InputReader in, OutputWriter out) {
            int n = in.readInt();
            int[] a = new int[n - 1];
            int[] b = new int[n - 1];
            IOUtils.readIntArrays(in, a, b);
            // Convert the 1-based vertex numbers to 0-based indices.
            MiscUtils.decreaseByOne(a, b);
            graph = BidirectionalGraph.createGraph(n, a, b);
            // Half of the value returned by getDiameter(), rounded up.
            limit = (getDiameter() + 1) / 2;
            // All scratch tables share the same (limit + 1) x (limit + 1) shape.
            callRes = new long[limit + 1][limit + 1];
            rowCallRes = new long[limit + 1][limit + 1];
            columnCallRes = new long[limit + 1][limit + 1];
            sumCallRes = new long[limit + 1][limit + 1];
            res = new long[limit + 1][limit + 1];
            columnRes = new long[limit + 1][limit + 1];
            rowRes = new long[limit + 1][limit + 1];
            sumRes = new long[limit + 1][limit + 1];
            int[][] result = solve(0, -1);
            // The answer is the sum of every entry of the root's table.
            long answer = 0;
            for (int i = 0; i < result.length; i++) {
                answer += ArrayUtils.sumArray(result[i]);
            }
            out.printLine(answer % MiscUtils.MOD7);
        }

        /**
         * Recursively computes the DP table of the subtree hanging below {@code vertex}.
         *
         * @param vertex vertex whose subtree is processed
         * @param last   neighbour the recursion came from (skipped), or -1 for the root
         * @return a square table with entries reduced modulo {@code MiscUtils.MOD7};
         *         {@code {{1}}} when {@code vertex} has no neighbour other than {@code last}
         */
        private int[][] solve(int vertex, int last) {
            int[][] result = null;
            for (int i = graph.firstOutbound(vertex); i != -1; i = graph.nextOutbound(i)) {
                int next = graph.destination(i);
                if (next == last) {
                    continue;
                }
                // Load the child's table, extended by the connecting edge, into callRes.
                addEdge(solve(next, vertex));
                if (result == null) {
                    // First child: the accumulated table is just a reduced copy of callRes.
                    result = trim();
                } else {
                    // Later children: combine the accumulated table with callRes.
                    result = join(result);
                }
            }
            if (result == null) {
                return new int[][]{{1}};
            }
            return result;
        }

        /**
         * Combines the table accumulated so far with the child table currently held in
         * {@code callRes}.
         *
         * <p>Both tables are first zero-extended to a common size and their prefix sums are computed
         * with {@link #fill}. May raise {@code callSize}.
         *
         * @param result accumulated table of the current vertex (square)
         * @return a new {@code (callSize + 1) x (callSize + 1)} table, reduced modulo
         *         {@code MiscUtils.MOD7}
         */
        private int[][] join(int[][] result) {
            // Copy the accumulated table into the long-typed scratch buffer (result is square).
            for (int i = 0; i < result.length; i++) {
                for (int j = 0; j < result.length; j++) {
                    res[i][j] = result[i][j];
                }
            }
            if (result.length > callSize) {
                // result reaches at least as far as callRes: zero the part of callRes outside its
                // valid square (stale data from earlier calls) up to index result.length - 1.
                for (int i = 0; i <= callSize; i++) {
                    Arrays.fill(callRes[i], callSize + 1, result.length, 0);
                }
                for (int i = callSize + 1; i < result.length; i++) {
                    Arrays.fill(callRes[i], 0, result.length, 0);
                }
                callSize = result.length - 1;
            } else {
                // callRes is the larger one: zero the part of res outside the copied square.
                for (int i = 0; i < result.length; i++) {
                    Arrays.fill(res[i], result.length, callSize + 1, 0);
                }
                for (int i = result.length; i <= callSize; i++) {
                    Arrays.fill(res[i], 0, callSize + 1, 0);
                }
            }
            // Note the argument order: the third parameter of fill() receives the column table.
            fill(res, rowRes, columnRes, sumRes, callSize);
            fill(callRes, rowCallRes, columnCallRes, sumCallRes, callSize);
            result = new int[callSize + 1][callSize + 1];
            for (int i = 0; i < result.length; i++) {
                for (int j = 0; j < result.length; j++) {
                    // Exclusive upper bounds for the other table's indices.
                    // NOTE(review): presumably this keeps i + (other j) and j + (other i) within
                    // limit - confirm.
                    int ai = Math.min(limit - j + 1, i);
                    int aj = Math.min(limit - i + 1, j);
                    // Entry (i, j) of one table paired with every entry (k, l), k < ai and l < aj,
                    // of the other table.
                    long current = res[i][j] * sumCallRes[ai][aj] +
                            callRes[i][j] * sumRes[ai][aj];
                    current %= MiscUtils.MOD7;
                    if (i + j <= limit) {
                        // Pairings in which one table supplies row index i and the other supplies
                        // column index j, plus the pairing of the two (i, j) entries themselves.
                        // The parenthesised part is reduced before the remaining products are
                        // added, presumably to keep the long sum in range; callRes entries are not
                        // reduced by addEdge() and may be up to twice as large as the others.
                        current += (rowRes[i][aj] * columnCallRes[ai][j] +
                                rowCallRes[i][aj] * columnRes[ai][j] +
                                res[i][j] * columnCallRes[ai][j] +
                                res[i][j] * rowCallRes[i][aj]) % MiscUtils.MOD7 +
                                callRes[i][j] * columnRes[ai][j] +
                                callRes[i][j] * rowRes[i][aj] +
                                res[i][j] * callRes[i][j];
                    }
                    current %= MiscUtils.MOD7;
                    result[i][j] = (int) current;
                }
            }
            return result;
        }

        /**
         * Computes exclusive prefix sums of {@code res} over indices {@code 0..size}, modulo
         * {@code MiscUtils.MOD7}.
         *
         * <p>Row 0 of the column table, column 0 of the row table, and row 0 / column 0 of the 2D
         * table are never written here; the code relies on them being zero.
         *
         * <p>NOTE: the third parameter is named {@code callRes} but receives the column-prefix table
         * at both call sites; inside this method it shadows the field of the same name.
         *
         * @param res     source table
         * @param rowRes  output: {@code rowRes[i][j]} = sum of {@code res[i][k]} for {@code k < j}
         * @param callRes output: {@code callRes[i][j]} = sum of {@code res[k][j]} for {@code k < i}
         * @param sumRes  output: {@code sumRes[i][j]} = sum of {@code res[k][l]} for {@code k < i}
         *                and {@code l < j}
         * @param size    largest index to fill (inclusive)
         */
        private void fill(long[][] res, long[][] rowRes, long[][] callRes, long[][] sumRes, int size) {
            for (int i = 0; i <= size; i++) {
                for (int j = 0; j <= size; j++) {
                    if (i > 0) {
                        callRes[i][j] = (callRes[i - 1][j] + res[i - 1][j]) % MiscUtils.MOD7;
                    }
                    if (j > 0) {
                        rowRes[i][j] = (rowRes[i][j - 1] + res[i][j - 1]) % MiscUtils.MOD7;
                    }
                    if (i > 0 && j > 0) {
                        // Inclusion-exclusion on the three neighbouring prefix sums.
                        sumRes[i][j] =
                                (sumRes[i - 1][j] + sumRes[i][j - 1] - sumRes[i - 1][j - 1] + res[i - 1][j - 1]) %
                                        MiscUtils.MOD7;
                        // The subtraction can make the remainder negative; bring it back to range.
                        if (sumRes[i][j] < 0) {
                            sumRes[i][j] += MiscUtils.MOD7;
                        }
                    }
                }
            }
        }

        /**
         * Copies the valid square of {@code callRes} into a new {@code int} table, reducing every
         * entry modulo {@code MiscUtils.MOD7}.
         *
         * @return a {@code (callSize + 1) x (callSize + 1)} copy of {@code callRes}
         */
        private int[][] trim() {
            int[][] result = new int[callSize + 1][callSize + 1];
            for (int i = 0; i <= callSize; i++) {
                for (int j = 0; j <= callSize; j++) {
                    result[i][j] = (int) (callRes[i][j] % MiscUtils.MOD7);
                }
            }
            return result;
        }

        /**
         * Loads a child's table into {@code callRes}, extended by the edge to the parent.
         *
         * <p>Each target entry is the sum of the child's entry one step back in the first index and
         * the child's entry one step back in the second index (where those exist). The sum is not
         * reduced modulo {@code MiscUtils.MOD7}. NOTE(review): the two terms presumably correspond
         * to the two possible orientations of the connecting edge - confirm.
         *
         * <p>Sets {@code callSize} to {@code min(limit, call.length)}, so the table grows by one in
         * each dimension unless {@code limit} caps it.
         *
         * @param call table returned for the child (square)
         */
        private void addEdge(int[][] call) {
            callSize = Math.min(limit, call.length);
            for (int i = 0; i <= limit && i <= call.length; i++) {
                for (int j = 0; j <= limit && j <= call.length; j++) {
                    callRes[i][j] = 0;
                    if (i > 0 && j < call.length) {
                        callRes[i][j] += call[i - 1][j];
                    }
                    if (j > 0 && i < call.length) {
                        callRes[i][j] += call[i][j - 1];
                    }
                }
            }
        }

        /**
         * Measures the tree with two depth-first sweeps: the first finds a vertex farthest from
         * vertex 0, the second returns the largest distance (in edges) from that vertex.
         *
         * @return the distance found by the second sweep
         */
        private int getDiameter() {
            int farthest = get(0, -1).second;
            return get(farthest, -1).first;
        }

        /**
         * Finds the deepest vertex reachable from {@code vertex} without going back through
         * {@code last}.
         *
         * @param vertex start vertex
         * @param last   neighbour to skip, or -1 for none
         * @return a pair whose {@code first} is the distance in edges to the deepest vertex and
         *         whose {@code second} is that vertex; {@code (0, vertex)} when there is nowhere
         *         to go
         */
        private IntIntPair get(int vertex, int last) {
            int dist = 0;
            int farthest = vertex;
            for (int i = graph.firstOutbound(vertex); i != -1; i = graph.nextOutbound(i)) {
                int next = graph.destination(i);
                if (next == last) {
                    continue;
                }
                IntIntPair call = get(next, vertex);
                // ">=" lets the first child replace the initial (0, vertex); on ties the child
                // visited later wins.
                if (call.first >= dist) {
                    dist = call.first + 1;
                    farthest = call.second;
                }
            }
            return new IntIntPair(dist, farthest);
        }

    }

    /**
     * Immutable pair of two {@code int} values.
     *
     * <p>Pairs are equal when both components are equal, and are ordered by {@code first} and then
     * by {@code second}, so {@link #compareTo} is consistent with {@link #equals}.
     */
    static class IntIntPair implements Comparable<IntIntPair> {
        public final int first;
        public final int second;

        /**
         * @param first  first component
         * @param second second component
         */
        public IntIntPair(int first, int second) {
            this.first = first;
            this.second = second;
        }

        /**
         * @param o object to compare with
         * @return {@code true} when {@code o} is an {@code IntIntPair} of the same class holding
         *         the same two values
         */
        @Override
        public boolean equals(Object o) {
            if (this == o) {
                return true;
            }
            if (o == null || getClass() != o.getClass()) {
                return false;
            }

            IntIntPair pair = (IntIntPair) o;

            return first == pair.first && second == pair.second;
        }

        /**
         * @return {@code 31 * first + second}
         */
        @Override
        public int hashCode() {
            int result = first;
            result = 31 * result + second;
            return result;
        }

        /**
         * @return the pair formatted as {@code (first,second)}
         */
        @Override
        public String toString() {
            return "(" + first + "," + second + ")";
        }

        /**
         * Orders pairs by {@code first}, breaking ties by {@code second}.
         *
         * @param o pair to compare with
         * @return a negative value, zero or a positive value when this pair is smaller than, equal
         *         to or greater than {@code o}
         */
        @Override
        public int compareTo(IntIntPair o) {
            // Integer.compare avoids the overflow that subtracting the components could cause.
            int value = Integer.compare(first, o.first);
            if (value != 0) {
                return value;
            }
            return Integer.compare(second, o.second);
        }

    }

    /**
     * An {@link IntStream} that can report how many elements it holds.
     */
    static interface IntCollection extends IntStream {
        /**
         * @return the number of elements in this collection
         */
        int size();

    }

    /**
     * Cursor over a sequence of primitive {@code int} values. Unlike {@code java.util.Iterator}
     * the cursor points at a current element: read it with {@link #value()}, move on with
     * {@link #advance()}, and test the position with {@link #isValid()}.
     */
    static interface IntIterator {
        /**
         * @return the element at the current position
         * @throws NoSuchElementException declared for the case that there is no current element
         */
        int value() throws NoSuchElementException;

        /**
         * Moves the cursor to the next position.
         *
         * @return implementation-defined; the implementation in this file returns the result of
         *         {@link #isValid()} after moving
         */
        boolean advance();

        /**
         * @return {@code true} while the cursor points at an element
         */
        boolean isValid();

    }

    static class IntArray extends IntAbstractStream implements IntList {
        private int[] data;

        public IntArray(int[] arr) {
            data = arr;
        }

        public int size() {
            return data.length;
        }

        public int get(int at) {
            return data[at];
        }

        public void removeAt(int index) {
            throw new UnsupportedOperationException();
        }

    }

    /**
     * Base class that gives every {@link IntStream} element-wise {@code toString}, {@code equals}
     * and {@code hashCode}, all driven by {@code intIterator()}.
     */
    static abstract class IntAbstractStream implements IntStream {

        /**
         * @return the elements in iteration order, separated by single spaces; an empty string for
         *         an empty stream
         */
        @Override
        public String toString() {
            StringBuilder text = new StringBuilder();
            IntIterator cursor = intIterator();
            // The first element is written without a leading separator.
            if (cursor.isValid()) {
                text.append(cursor.value());
                cursor.advance();
            }
            while (cursor.isValid()) {
                text.append(' ');
                text.append(cursor.value());
                cursor.advance();
            }
            return text.toString();
        }

        /**
         * @param o object to compare with
         * @return {@code true} when {@code o} is an {@link IntStream} that yields the same elements
         *         in the same order
         */
        @Override
        public boolean equals(Object o) {
            if (!(o instanceof IntStream)) {
                return false;
            }
            IntIterator mine = intIterator();
            IntIterator theirs = ((IntStream) o).intIterator();
            for (; mine.isValid() && theirs.isValid(); mine.advance(), theirs.advance()) {
                if (mine.value() != theirs.value()) {
                    return false;
                }
            }
            // Equal only when both sequences ran out at the same time.
            return !(mine.isValid() || theirs.isValid());
        }

        /**
         * @return a polynomial hash of the elements with multiplier 31, starting from 0
         */
        @Override
        public int hashCode() {
            int hash = 0;
            IntIterator cursor = intIterator();
            while (cursor.isValid()) {
                hash = 31 * hash + cursor.value();
                cursor.advance();
            }
            return hash;
        }

    }

    /**
     * Sequence of primitive {@code int} values that is traversed with an {@link IntIterator}.
     *
     * <p>The interface also adapts the sequence to {@code Iterable<Integer>} (boxing each element)
     * and orders streams lexicographically.
     */
    static interface IntStream extends Iterable<Integer>, Comparable<IntStream> {
        /**
         * @return a fresh cursor positioned at the first element
         */
        public IntIterator intIterator();

        /**
         * Adapts {@link #intIterator()} to a boxed {@code java.util.Iterator}.
         *
         * @return an iterator over the elements in order; its {@code next()} throws
         *         {@link NoSuchElementException} once the elements are exhausted
         */
        @Override
        default public Iterator<Integer> iterator() {
            return new Iterator<Integer>() {
                private IntIterator it = intIterator();

                @Override
                public boolean hasNext() {
                    return it.isValid();
                }

                @Override
                public Integer next() {
                    // Honour the Iterator contract instead of leaking whatever value() would
                    // throw when the cursor has run past the last element.
                    if (!it.isValid()) {
                        throw new NoSuchElementException();
                    }
                    int result = it.value();
                    it.advance();
                    return result;
                }
            };
        }

        /**
         * Compares two streams lexicographically, element by element.
         *
         * @param c stream to compare with
         * @return -1, 0 or 1; when one stream is a proper prefix of the other, the shorter one is
         *         the smaller
         */
        @Override
        default public int compareTo(IntStream c) {
            IntIterator it = intIterator();
            IntIterator jt = c.intIterator();
            while (it.isValid() && jt.isValid()) {
                int i = it.value();
                int j = jt.value();
                if (i < j) {
                    return -1;
                } else if (i > j) {
                    return 1;
                }
                it.advance();
                jt.advance();
            }
            // All shared positions matched: whichever stream still has elements is the larger.
            if (it.isValid()) {
                return 1;
            }
            if (jt.isValid()) {
                return -1;
            }
            return 0;
        }

        /**
         * @return the sum of all elements, accumulated in a {@code long}
         */
        default public long sum() {
            long result = 0;
            for (IntIterator it = intIterator(); it.isValid(); it.advance()) {
                result += it.value();
            }
            return result;
        }

    }

    /**
     * Marker sub-interface of {@link IntCollection}; as shown in this file it declares no members
     * of its own.
     */
    static interface IntReversableCollection extends IntCollection {
    }

    /**
     * Marker interface for edge objects; it declares no members. In this file it is implemented by
     * {@code Graph.GraphEdge}.
     */
    static interface Edge {
    }

    /**
     * Small helpers shared by the solver.
     */
    static class MiscUtils {
        /** The modulus 1,000,000,007 (10^9 + 7). */
        public static final int MOD7 = 1_000_000_007;

        /**
         * Subtracts one from every element of every given array, in place.
         *
         * @param arrays arrays to modify; an array passed more than once is decremented once per
         *               occurrence
         */
        public static void decreaseByOne(int[]... arrays) {
            for (int which = 0; which < arrays.length; which++) {
                int[] current = arrays[which];
                for (int pos = 0; pos < current.length; pos++) {
                    current[pos] -= 1;
                }
            }
        }

    }

    /**
     * Input helpers built on top of {@code InputReader}.
     */
    static class IOUtils {
        /**
         * Fills several arrays from interleaved input: for each index {@code i}, one integer is
         * read for every array in argument order and stored at position {@code i}.
         *
         * <p>The number of rows read is the length of the first array; the other arrays are
         * expected to be at least that long. Calling the method without any arrays reads nothing.
         *
         * @param in     source of the integers
         * @param arrays arrays to fill, column by column
         */
        public static void readIntArrays(InputReader in, int[]... arrays) {
            // Without this guard an empty argument list would fail on arrays[0].
            if (arrays.length == 0) {
                return;
            }
            int length = arrays[0].length;
            for (int i = 0; i < length; i++) {
                for (int j = 0; j < arrays.length; j++) {
                    arrays[j][i] = in.readInt();
                }
            }
        }

    }

    /**
     * Minimal buffered tokenizer that reads whitespace-separated integers from an
     * {@link InputStream}.
     */
    static class InputReader {
        private InputStream stream;
        private byte[] buf = new byte[1024];
        // Index of the next unread byte in buf.
        private int curChar;
        // Number of valid bytes in buf; -1 once the stream has reported end of input.
        private int numChars;
        // Optional custom separator test; never assigned inside this class, so it stays null here.
        private InputReader.SpaceCharFilter filter;

        /**
         * @param stream stream to read from
         */
        public InputReader(InputStream stream) {
            this.stream = stream;
        }

        /**
         * Returns the next byte of input, refilling the buffer when it is exhausted.
         *
         * @return the next byte (as stored in the {@code byte} buffer), or -1 when the refill
         *         yields no data
         * @throws InputMismatchException when called again after the stream reported end of input,
         *                                or when the underlying read fails; in the latter case the
         *                                {@link IOException} is attached as the cause
         */
        public int read() {
            if (numChars == -1) {
                throw new InputMismatchException();
            }
            if (curChar >= numChars) {
                curChar = 0;
                try {
                    numChars = stream.read(buf);
                } catch (IOException e) {
                    // InputMismatchException has no cause-taking constructor, so attach the
                    // original failure explicitly instead of dropping it.
                    InputMismatchException failure = new InputMismatchException(e.getMessage());
                    failure.initCause(e);
                    throw failure;
                }
                if (numChars <= 0) {
                    return -1;
                }
            }
            return buf[curChar++];
        }

        /**
         * Skips leading separators and parses one decimal integer with an optional leading
         * {@code '-'}. No overflow check is performed.
         *
         * @return the parsed value
         * @throws InputMismatchException when a character other than a digit is found inside the
         *                                token, or when no token is left
         */
        public int readInt() {
            int c = read();
            while (isSpaceChar(c)) {
                c = read();
            }
            int sgn = 1;
            if (c == '-') {
                sgn = -1;
                c = read();
            }
            int res = 0;
            do {
                if (c < '0' || c > '9') {
                    throw new InputMismatchException();
                }
                res *= 10;
                res += c - '0';
                c = read();
            } while (!isSpaceChar(c));
            return res * sgn;
        }

        /**
         * @param c character code, or -1 for end of input
         * @return whether {@code c} separates tokens, as decided by the filter when one is set and
         *         by {@link #isWhitespace(int)} otherwise
         */
        public boolean isSpaceChar(int c) {
            if (filter != null) {
                return filter.isSpaceChar(c);
            }
            return isWhitespace(c);
        }

        /**
         * @param c character code, or -1 for end of input
         * @return {@code true} for space, line feed, carriage return, tab and -1
         */
        public static boolean isWhitespace(int c) {
            return c == ' ' || c == '\n' || c == '\r' || c == '\t' || c == -1;
        }

        /**
         * Strategy that decides which characters separate tokens.
         */
        public interface SpaceCharFilter {
            /**
             * @param ch character code, or -1 for end of input
             * @return {@code true} when {@code ch} separates tokens
             */
            public boolean isSpaceChar(int ch);

        }

    }

    static class ArrayUtils {
        public static long sumArray(int[] array) {
            return new IntArray(array).sum();
        }

    }

    static interface IntList extends IntReversableCollection {
        public abstract int get(int index);

        public abstract void removeAt(int index);

        default public IntIterator intIterator() {
            return new IntIterator() {
                private int at;
                private boolean removed;

                public int value() {
                    if (removed) {
                        throw new IllegalStateException();
                    }
                    return get(at);
                }

                public boolean advance() {
                    at++;
                    removed = false;
                    return isValid();
                }

                public boolean isValid() {
                    return !removed && at < size();
                }

                public void remove() {
                    removeAt(at);
                    at--;
                    removed = true;
                }
            };
        }

    }

    /**
     * Directed multigraph whose edges live in parallel arrays indexed by edge id.
     *
     * <p>The outgoing edges of a vertex form a singly linked list: {@code firstOutbound[v]} is the
     * id of the most recently added edge leaving {@code v} and {@code nextOutbound[e]} is the next
     * edge of the same list, with -1 terminating the list. Edge ids are assigned sequentially
     * starting at 0.
     *
     * <p>{@code weight}, {@code capacity} and {@code reverseEdge} are allocated lazily, the first
     * time an edge needs them. {@code firstInbound}, {@code nextInbound} and {@code edges} are
     * never allocated by the code in this class, so the branches that use them are inactive here.
     */
    static class Graph {
        /** Bit of an edge's flag word that marks the edge as removed. */
        public static final int REMOVED_BIT = 0;
        protected int vertexCount;
        /** Number of edges added so far; also the id the next edge will get. */
        protected int edgeCount;
        private int[] firstOutbound;
        private int[] firstInbound;
        private Edge[] edges;
        private int[] nextInbound;
        private int[] nextOutbound;
        private int[] from;
        private int[] to;
        private long[] weight;
        public long[] capacity;
        private int[] reverseEdge;
        private int[] flags;

        /**
         * Creates a graph with room for {@code vertexCount} edges.
         *
         * @param vertexCount number of vertices
         */
        public Graph(int vertexCount) {
            this(vertexCount, vertexCount);
        }

        /**
         * @param vertexCount  number of vertices
         * @param edgeCapacity initial capacity of the edge arrays; they grow on demand
         */
        public Graph(int vertexCount, int edgeCapacity) {
            this.vertexCount = vertexCount;
            firstOutbound = new int[vertexCount];
            Arrays.fill(firstOutbound, -1);

            from = new int[edgeCapacity];
            to = new int[edgeCapacity];
            nextOutbound = new int[edgeCapacity];
            flags = new int[edgeCapacity];
        }

        /**
         * Appends one directed edge and links it at the head of {@code fromID}'s outgoing list.
         *
         * @param fromID      source vertex
         * @param toID        target vertex
         * @param weight      edge weight; stored only when non-zero
         * @param capacity    edge capacity; stored only when non-zero
         * @param reverseEdge id of the paired reverse edge, or -1 when there is none
         * @return the id of the new edge
         */
        public int addEdge(int fromID, int toID, long weight, long capacity, int reverseEdge) {
            ensureEdgeCapacity(edgeCount + 1);
            if (firstOutbound[fromID] != -1) {
                nextOutbound[edgeCount] = firstOutbound[fromID];
            } else {
                nextOutbound[edgeCount] = -1;
            }
            firstOutbound[fromID] = edgeCount;
            if (firstInbound != null) {
                if (firstInbound[toID] != -1) {
                    nextInbound[edgeCount] = firstInbound[toID];
                } else {
                    nextInbound[edgeCount] = -1;
                }
                firstInbound[toID] = edgeCount;
            }
            this.from[edgeCount] = fromID;
            this.to[edgeCount] = toID;
            if (capacity != 0) {
                if (this.capacity == null) {
                    this.capacity = new long[from.length];
                }
                this.capacity[edgeCount] = capacity;
            }
            if (weight != 0) {
                if (this.weight == null) {
                    this.weight = new long[from.length];
                }
                this.weight[edgeCount] = weight;
            }
            if (reverseEdge != -1 && this.reverseEdge == null) {
                this.reverseEdge = new int[from.length];
                // Every edge added before the array existed has no reverse edge.
                Arrays.fill(this.reverseEdge, 0, edgeCount, -1);
            }
            if (this.reverseEdge != null) {
                // Written for every edge once the array exists, so that an edge without a reverse
                // is recorded as -1 rather than keeping the array default 0 (a valid edge id).
                this.reverseEdge[edgeCount] = reverseEdge;
            }
            if (edges != null) {
                edges[edgeCount] = createEdge(edgeCount);
            }
            return edgeCount++;
        }

        /**
         * @param id edge id
         * @return a new edge object wrapping {@code id}
         */
        protected final GraphEdge createEdge(int id) {
            return new GraphEdge(id);
        }

        /**
         * Adds an edge with weight and capacity. With a non-zero capacity a reverse edge
         * ({@code to -> from}, negated weight, zero capacity) is added first and the two edges are
         * registered as each other's reverse.
         *
         * @param from     source vertex
         * @param to       target vertex
         * @param weight   edge weight
         * @param capacity edge capacity; 0 adds a single plain edge
         * @return the id of the forward edge
         */
        public final int addFlowWeightedEdge(int from, int to, long weight, long capacity) {
            if (capacity == 0) {
                return addEdge(from, to, weight, 0, -1);
            } else {
                int lastEdgeCount = edgeCount;
                // The forward edge is added after the reverse one, i.e. entriesPerEdge() slots on.
                addEdge(to, from, -weight, 0, lastEdgeCount + entriesPerEdge());
                return addEdge(from, to, weight, capacity, lastEdgeCount);
            }
        }

        /**
         * @return how many edge slots one call to {@link #addEdge} consumes; 1 for this class
         */
        protected int entriesPerEdge() {
            return 1;
        }

        /**
         * @param from   source vertex
         * @param to     target vertex
         * @param weight edge weight
         * @return the id of the new edge
         */
        public final int addWeightedEdge(int from, int to, long weight) {
            return addFlowWeightedEdge(from, to, weight, 0);
        }

        /**
         * @param from source vertex
         * @param to   target vertex
         * @return the id of the new edge
         */
        public final int addSimpleEdge(int from, int to) {
            return addWeightedEdge(from, to, 0);
        }

        /**
         * @return the current length of the edge arrays
         */
        protected final int edgeCapacity() {
            return from.length;
        }

        /**
         * @param vertex vertex to inspect
         * @return the id of the first non-removed edge leaving {@code vertex}, or -1 when there is
         *         none
         */
        public final int firstOutbound(int vertex) {
            int id = firstOutbound[vertex];
            while (id != -1 && isRemoved(id)) {
                id = nextOutbound[id];
            }
            return id;
        }

        /**
         * @param id id of an edge
         * @return the id of the next non-removed edge in the same outgoing list, or -1 at the end
         */
        public final int nextOutbound(int id) {
            id = nextOutbound[id];
            while (id != -1 && isRemoved(id)) {
                id = nextOutbound[id];
            }
            return id;
        }

        /**
         * @param id edge id
         * @return the target vertex of the edge
         */
        public final int destination(int id) {
            return to[id];
        }

        /**
         * @param id  edge id
         * @param bit index of the flag bit
         * @return whether that bit is set in the edge's flag word
         */
        public final boolean flag(int id, int bit) {
            return (flags[id] >> bit & 1) != 0;
        }

        /**
         * @param id edge id
         * @return whether the edge is marked as removed
         */
        public final boolean isRemoved(int id) {
            return flag(id, REMOVED_BIT);
        }

        /**
         * Grows every allocated per-edge array so that it can hold at least {@code size} edges. The
         * new length is the larger of {@code size} and twice the current length.
         *
         * @param size required number of edge slots
         */
        protected void ensureEdgeCapacity(int size) {
            if (from.length < size) {
                int newSize = Math.max(size, 2 * from.length);
                if (edges != null) {
                    edges = resize(edges, newSize);
                }
                from = resize(from, newSize);
                to = resize(to, newSize);
                nextOutbound = resize(nextOutbound, newSize);
                if (nextInbound != null) {
                    nextInbound = resize(nextInbound, newSize);
                }
                if (weight != null) {
                    weight = resize(weight, newSize);
                }
                if (capacity != null) {
                    capacity = resize(capacity, newSize);
                }
                if (reverseEdge != null) {
                    reverseEdge = resize(reverseEdge, newSize);
                }
                flags = resize(flags, newSize);
            }
        }

        /**
         * @param array array to copy
         * @param size  length of the copy; must be at least {@code array.length}
         * @return a new array that starts with the contents of {@code array}
         */
        protected final int[] resize(int[] array, int size) {
            int[] newArray = new int[size];
            System.arraycopy(array, 0, newArray, 0, array.length);
            return newArray;
        }

        // Same as resize(int[], int), for long arrays.
        private long[] resize(long[] array, int size) {
            long[] newArray = new long[size];
            System.arraycopy(array, 0, newArray, 0, array.length);
            return newArray;
        }

        // Same as resize(int[], int), for edge-object arrays.
        private Edge[] resize(Edge[] array, int size) {
            Edge[] newArray = new Edge[size];
            System.arraycopy(array, 0, newArray, 0, array.length);
            return newArray;
        }

        /**
         * Edge object that carries only the edge id.
         */
        protected class GraphEdge implements Edge {
            protected int id;

            protected GraphEdge(int id) {
                this.id = id;
            }

        }

    }

    static class BidirectionalGraph extends Graph {
        public int[] transposedEdge;

        public BidirectionalGraph(int vertexCount) {
            this(vertexCount, vertexCount);
        }

        public BidirectionalGraph(int vertexCount, int edgeCapacity) {
            super(vertexCount, 2 * edgeCapacity);
            transposedEdge = new int[2 * edgeCapacity];
        }

        public static BidirectionalGraph createGraph(int vertexCount, int[] from, int[] to) {
            BidirectionalGraph graph = new BidirectionalGraph(vertexCount, from.length);
            for (int i = 0; i < from.length; i++) {
                graph.addSimpleEdge(from[i], to[i]);
            }
            return graph;
        }


        public int addEdge(int fromID, int toID, long weight, long capacity, int reverseEdge) {
            int lastEdgeCount = edgeCount;
            super.addEdge(fromID, toID, weight, capacity, reverseEdge);
            super.addEdge(toID, fromID, weight, capacity, reverseEdge == -1 ? -1 : reverseEdge + 1);
            this.transposedEdge[lastEdgeCount] = lastEdgeCount + 1;
            this.transposedEdge[lastEdgeCount + 1] = lastEdgeCount;
            return lastEdgeCount;
        }


        protected int entriesPerEdge() {
            return 2;
        }


        protected void ensureEdgeCapacity(int size) {
            if (size > edgeCapacity()) {
                super.ensureEdgeCapacity(size);
                transposedEdge = resize(transposedEdge, edgeCapacity());
            }
        }

    }

    /**
     * Thin wrapper around a {@link PrintWriter} used to emit the answer.
     */
    static class OutputWriter {
        private final PrintWriter writer;

        /**
         * Writes to a byte stream through a buffered character writer.
         *
         * @param outputStream destination stream
         */
        public OutputWriter(OutputStream outputStream) {
            this(new BufferedWriter(new OutputStreamWriter(outputStream)));
        }

        /**
         * @param writer destination writer
         */
        public OutputWriter(Writer writer) {
            this.writer = new PrintWriter(writer);
        }

        /**
         * Closes the underlying writer.
         */
        public void close() {
            this.writer.close();
        }

        /**
         * Prints the number followed by a line terminator.
         *
         * @param i value to print
         */
        public void printLine(long i) {
            this.writer.println(i);
        }

    }
}

Submission Info

Submission Time
Task D - Oriented Tree
User Egor
Language Java8 (OpenJDK 1.8.0)
Score 1800
Code Size 25241 Byte
Status AC
Exec Time 1283 ms
Memory 186092 KB

Judge Result

Set Name Sample All
Score / Max Score 0 / 0 1800 / 1800
Status
AC × 4
AC × 33
Set Name Test Cases
Sample 0_00.txt, 0_01.txt, 0_02.txt, 0_03.txt
All 0_00.txt, 0_01.txt, 0_02.txt, 0_03.txt, 1_00.txt, 1_01.txt, 1_02.txt, 1_03.txt, 1_04.txt, 1_05.txt, 1_06.txt, 1_07.txt, 1_08.txt, 1_09.txt, 1_10.txt, 1_11.txt, 1_12.txt, 1_13.txt, 1_14.txt, 1_15.txt, 1_16.txt, 1_17.txt, 1_18.txt, 1_19.txt, 1_20.txt, 1_21.txt, 1_22.txt, 1_23.txt, 1_24.txt, 1_25.txt, 1_26.txt, 1_27.txt, 1_28.txt
Case Name Status Exec Time Memory
0_00.txt AC 76 ms 18004 KB
0_01.txt AC 75 ms 21332 KB
0_02.txt AC 77 ms 22740 KB
0_03.txt AC 77 ms 19156 KB
1_00.txt AC 76 ms 19668 KB
1_01.txt AC 1283 ms 186092 KB
1_02.txt AC 1263 ms 178348 KB
1_03.txt AC 88 ms 16596 KB
1_04.txt AC 91 ms 21716 KB
1_05.txt AC 108 ms 19376 KB
1_06.txt AC 109 ms 20616 KB
1_07.txt AC 113 ms 20180 KB
1_08.txt AC 113 ms 20436 KB
1_09.txt AC 155 ms 22244 KB
1_10.txt AC 221 ms 22100 KB
1_11.txt AC 301 ms 32484 KB
1_12.txt AC 303 ms 44316 KB
1_13.txt AC 513 ms 49328 KB
1_14.txt AC 1059 ms 91872 KB
1_15.txt AC 1103 ms 165244 KB
1_16.txt AC 909 ms 151996 KB
1_17.txt AC 98 ms 20560 KB
1_18.txt AC 107 ms 22428 KB
1_19.txt AC 104 ms 20108 KB
1_20.txt AC 103 ms 19900 KB
1_21.txt AC 103 ms 23124 KB
1_22.txt AC 90 ms 20948 KB
1_23.txt AC 107 ms 22416 KB
1_24.txt AC 90 ms 20948 KB
1_25.txt AC 108 ms 22996 KB
1_26.txt AC 105 ms 22868 KB
1_27.txt AC 106 ms 23124 KB
1_28.txt AC 106 ms 20820 KB