#!/usr/bin/env python3
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

# pyre-strict

from typing import List, Tuple

import fbgemm_gpu
import numpy as np
import torch
from hypothesis import settings, Verbosity

# pyre-fixme[16]: Module `fbgemm_gpu` has no attribute `open_source`.
open_source: bool = getattr(fbgemm_gpu, "open_source", False)

if open_source:
    # pyre-ignore[21]
    from test_utils import (
        gpu_unavailable,
        running_in_oss,
        running_on_github,
        TEST_WITH_ROCM,
    )
else:
    torch.ops.load_library("//deeplearning/fbgemm/fbgemm_gpu:cumem_utils")
    from fbgemm_gpu.test.test_utils import (  # noqa F401
        gpu_unavailable,
        running_in_oss,
        running_on_github,
        TEST_WITH_ROCM,
    )


# Import fbgemm_gpu.sparse_ops through torch.ops when this module is loaded.
# NOTE(review): presumably this makes the sparse operators available to the
# tests that use these helpers — confirm.
torch.ops.import_module("fbgemm_gpu.sparse_ops")
# Register a Hypothesis settings profile named "derandomize" (created with
# derandomize=True) and load it as the active profile.
settings.register_profile("derandomize", derandomize=True)
settings.load_profile("derandomize")


# NOTE(review): presumably the default Hypothesis example count for tests that
# import this module; it is not referenced in this part of the file — confirm
# at the call sites.
MAX_EXAMPLES = 40

# For long running tests reduce the number of iterations to reduce timeout errors.
MAX_EXAMPLES_LONG_RUNNING = 15

# NOTE(review): looks like an upper bound on threads for forward kernels; it is
# not referenced in this part of the file — confirm against its users.
FORWARD_MAX_THREADS = 512

# Hypothesis verbosity level exported for tests to reuse.
VERBOSITY: Verbosity = Verbosity.verbose


def gen_mixed_B_batch_sizes(B: int, T: int) -> Tuple[List[List[int]], List[int]]:
    """Draw random per-rank batch sizes for each of ``T`` features.

    The number of ranks is sampled once from ``{1, 2, 3}`` and shared by all
    features. Each per-rank batch size is then sampled from the half-open
    range ``[max(int(0.25 * B), 1), int(B))``. When both bounds coincide, no
    sampling is done and every rank simply gets ``B``.

    Args:
        B: Base batch size that bounds the sampled per-rank batch sizes.
        T: Number of features to generate batch sizes for.

    Returns:
        A pair ``(Bs_rank_feature, Bs)`` where ``Bs_rank_feature[t][r]`` is the
        batch size of feature ``t`` on rank ``r`` and ``Bs[t]`` is the sum of
        feature ``t``'s batch sizes over all ranks.
    """
    # The rank count is drawn first so the RNG consumption order stays fixed.
    num_ranks = np.random.randint(low=1, high=4)
    lower = max(int(0.25 * B), 1)
    upper = int(B)

    Bs_rank_feature: List[List[int]] = []
    for _ in range(T):
        if lower == upper:
            # Degenerate range: nothing to sample, use B on every rank.
            per_rank_sizes = [B] * num_ranks
        else:
            per_rank_sizes = np.random.randint(
                low=lower, high=upper, size=num_ranks
            ).tolist()
        Bs_rank_feature.append(per_rank_sizes)

    Bs = list(map(sum, Bs_rank_feature))
    return Bs_rank_feature, Bs


def format_ref_tensors_in_mixed_B_layout(
    ref_tensors: List[torch.Tensor], Bs_rank_feature: List[List[int]]
) -> torch.Tensor:
    """Re-lay out per-table reference tensors in (rank, table, local batch) order.

    Each ``ref_tensors[t]`` is split along dim 0 into per-rank pieces whose
    sizes are given by ``Bs_rank_feature[t]``. Every piece is flattened, and
    the pieces are concatenated rank by rank, with tables in their original
    order inside each rank.

    Args:
        ref_tensors: One reference tensor per table. The first dimension of
            ``ref_tensors[t]`` must equal ``sum(Bs_rank_feature[t])``.
        Bs_rank_feature: Per-table list of per-rank batch sizes. The number of
            ranks is taken from the first entry.

    Returns:
        A 1-D tensor holding all flattened pieces in the new order.
    """
    num_ranks = len(Bs_rank_feature[0])
    # pieces_by_rank[r] collects the flattened pieces of every table for rank r.
    pieces_by_rank: List[List[torch.Tensor]] = [[] for _ in range(num_ranks)]

    for table_idx, table_tensor in enumerate(ref_tensors):
        rank_sizes = Bs_rank_feature[table_idx]
        assert table_tensor.shape[0] == sum(rank_sizes)
        for rank_idx, piece in enumerate(table_tensor.split(rank_sizes)):
            pieces_by_rank[rank_idx].append(piece.flatten())

    # Flatten the (rank, table) nesting into one rank-major list.
    ordered_pieces = [
        piece for rank_pieces in pieces_by_rank for piece in rank_pieces
    ]
    return torch.cat(ordered_pieces, dim=0)


def assert_torch_equal(tensor_a: torch.Tensor, tensor_b: torch.Tensor) -> None:
    """Assert that ``torch.equal`` reports the two tensors as equal.

    Raises:
        AssertionError: If ``torch.equal(tensor_a, tensor_b)`` is false.
    """
    tensors_match = torch.equal(tensor_a, tensor_b)
    assert tensors_match


def get_max_thread_blocks(stream: torch.cuda.streams.Stream) -> int:
    """Return the maximum grid size to use on the device that owns ``stream``.

    Args:
        stream: CUDA stream whose ``device_index`` selects the device to query.

    Returns:
        64 times the device's ``multi_processor_count``.
    """
    # Empirical studies showed that capping the grid at 64x the number of SMs
    # gives good performance across the board.
    blocks_per_sm = 64
    device_properties = torch.cuda.get_device_properties(stream.device_index)
    return blocks_per_sm * device_properties.multi_processor_count
