class Solution:
    """Container for the ``leastInterval`` scheduling-length computation."""

    def leastInterval(self, tasks: List[str], n: int) -> int:
        """Return the closed-form schedule length for ``tasks`` with gap ``n``.

        The result is the larger of ``len(tasks)`` and
        ``(max_freq - 1) * (n + 1) + max_freq_count``, where ``max_freq`` is
        the highest occurrence count of any label in ``tasks`` and
        ``max_freq_count`` is the number of distinct labels reaching it.

        NOTE(review): this matches the usual "task scheduler" formulation in
        which identical tasks must be separated by at least ``n`` slots --
        confirm against the caller's problem statement.

        Args:
            tasks: Task labels; repeated labels denote repeated runs of the
                same task.
            n: Gap value used in the ``(n + 1)``-wide frame term.

        Returns:
            The computed length, or ``0`` when ``tasks`` is empty.
        """
        # An empty task list needs no slots; without this guard the max()
        # over the (empty) frequency values would raise ValueError.
        if not tasks:
            return 0

        frequencies = collections.Counter(tasks)
        max_freq = max(frequencies.values())
        # Number of distinct labels that share the highest frequency.
        max_freq_count = sum(
            1 for freq in frequencies.values() if freq == max_freq
        )

        # (max_freq - 1) full frames of width (n + 1), followed by one final
        # partial frame holding each of the most frequent labels once.
        res = (max_freq - 1) * (n + 1) + max_freq_count

        # When there are more tasks than frame slots, the total task count
        # is the binding lower bound.
        return max(res, len(tasks))