""" Sum of digits sequence Problem 551 Let a(0), a(1),... be an integer sequence defined by: a(0) = 1 for n >= 1, a(n) is the sum of the digits of all preceding terms The sequence starts with 1, 1, 2, 4, 8, ... You are given a(10^6) = 31054319. Find a(10^15) """ ks = range(2, 20 + 1) base = [10**k for k in range(ks[-1] + 1)] memo: dict[int, dict[int, list[list[int]]]] = {} def next_term(a_i, k, i, n): """ Calculates and updates a_i in-place to either the n-th term or the smallest term for which c > 10^k when the terms are written in the form: a(i) = b * 10^k + c For any a(i), if digitsum(b) and c have the same value, the difference between subsequent terms will be the same until c >= 10^k. This difference is cached to greatly speed up the computation. Arguments: a_i -- array of digits starting from the one's place that represent the i-th term in the sequence k -- k when terms are written in the from a(i) = b*10^k + c. Term are calulcated until c > 10^k or the n-th term is reached. i -- position along the sequence n -- term to calculate up to if k is large enough Return: a tuple of difference between ending term and starting term, and the number of terms calculated. ex. if starting term is a_0=1, and ending term is a_10=62, then (61, 9) is returned. """ # ds_b - digitsum(b) ds_b = sum(a_i[j] for j in range(k, len(a_i))) c = sum(a_i[j] * base[j] for j in range(min(len(a_i), k))) diff, dn = 0, 0 max_dn = n - i sub_memo = memo.get(ds_b) if sub_memo is not None: jumps = sub_memo.get(c) if jumps is not None and len(jumps) > 0: # find and make the largest jump without going over max_jump = -1 for _k in range(len(jumps) - 1, -1, -1): if jumps[_k][2] <= k and jumps[_k][1] <= max_dn: max_jump = _k break if max_jump >= 0: diff, dn, _kk = jumps[max_jump] # since the difference between jumps is cached, add c new_c = diff + c for j in range(min(k, len(a_i))): new_c, a_i[j] = divmod(new_c, 10) if new_c > 0: add(a_i, k, new_c) else: sub_memo[c] = [] else: sub_memo = {c: []} memo[ds_b] = sub_memo if dn >= max_dn or c + diff >= base[k]: return diff, dn if k > ks[0]: while True: # keep doing smaller jumps _diff, terms_jumped = next_term(a_i, k - 1, i + dn, n) diff += _diff dn += terms_jumped if dn >= max_dn or c + diff >= base[k]: break else: # would be too small a jump, just compute sequential terms instead _diff, terms_jumped = compute(a_i, k, i + dn, n) diff += _diff dn += terms_jumped jumps = sub_memo[c] # keep jumps sorted by # of terms skipped j = 0 while j < len(jumps): if jumps[j][1] > dn: break j += 1 # cache the jump for this value digitsum(b) and c sub_memo[c].insert(j, (diff, dn, k)) return (diff, dn) def compute(a_i, k, i, n): """ same as next_term(a_i, k, i, n) but computes terms without memoizing results. """ if i >= n: return 0, i if k > len(a_i): a_i.extend([0 for _ in range(k - len(a_i))]) # note: a_i -> b * 10^k + c # ds_b -> digitsum(b) # ds_c -> digitsum(c) start_i = i ds_b, ds_c, diff = 0, 0, 0 for j in range(len(a_i)): if j >= k: ds_b += a_i[j] else: ds_c += a_i[j] while i < n: i += 1 addend = ds_c + ds_b diff += addend ds_c = 0 for j in range(k): s = a_i[j] + addend addend, a_i[j] = divmod(s, 10) ds_c += a_i[j] if addend > 0: break if addend > 0: add(a_i, k, addend) return diff, i - start_i def add(digits, k, addend): """ adds addend to digit array given in digits starting at index k """ for j in range(k, len(digits)): s = digits[j] + addend if s >= 10: quotient, digits[j] = divmod(s, 10) addend = addend // 10 + quotient else: digits[j] = s addend = addend // 10 if addend == 0: break while addend > 0: addend, digit = divmod(addend, 10) digits.append(digit) def solution(n: int = 10**15) -> int: """ returns n-th term of sequence >>> solution(10) 62 >>> solution(10**6) 31054319 >>> solution(10**15) 73597483551591773 """ digits = [1] i = 1 dn = 0 while True: diff, terms_jumped = next_term(digits, 20, i + dn, n) dn += terms_jumped if dn == n - i: break a_n = 0 for j in range(len(digits)): a_n += digits[j] * 10**j return a_n if __name__ == "__main__": print(f"{solution() = }")