import json import os from functools import partial from typing import Any, Callable, Dict, List, Optional import torch from torch._inductor.autoheuristic.autoheuristic_utils import ( AHContext, AHMetadata, AHOperation, Choice, CHOICE_COL, Feedback, FEEDBACK_COL, get_metadata_str_from_log, ) from torch._inductor.autoheuristic.learned_heuristic_controller import ( LearnedHeuristicController, ) from torch._inductor.ir import ChoiceCaller from torch._inductor.runtime.runtime_utils import cache_dir from torch._inductor.utils import get_gpu_shared_memory class LocalFeedback: """ To be able to collect data for a choice, a function providing feedback given a choice has to be provided. LocalFeedback can be used when AutoHeuristic should immediately run the function to collect feedback for each choice (see pad_mm.py, where the autotuning happens locally, for an example). """ def __init__(self, feedback_fn: Callable[[Choice], Feedback]) -> None: self.feedback_fn = feedback_fn def __call__(self, choice: Choice) -> Feedback: return self.feedback_fn(choice) class InconsistentMetadata(Exception): """ Exception that is thrown when AutoHeuristic tries to log data to a file where the metadata stored in the file does not match the metadata it would store if the file didn't exist. """ class AutoHeuristic: """ AutoHeuristic is a framework that allows one to collect data, learn a heuristic (i.e. a regression tree) and generate the heuristic to code. This class allows one to collect data. The collected data can then be used to train a heuristic (see torchgen/autoheuristic/). """ collected_feedback: Dict[Choice, Feedback] def __init__( self, fallback: Callable[[], Choice], choices: List[Choice], feedback: Optional[LocalFeedback], context: AHContext, name: str, augment_context: Optional[List[AHOperation]] = None, precondition: Optional[Callable[[AHMetadata, AHContext], bool]] = None, ) -> None: """ Initializes an instance of the AutoHeuristic class. Args: fallback: A callable that returns a Choice when the heuristic is unsure which choice to make, or AutoHeuristic is in data collection mode. choices: A list of possible choices the heuristic can make. feedback: An instance of LocalFeedback that provides feedback for a given choice. context: Context to store with each choice and feedback. name: A string that identifies the heuristic. augment_context: An optional list of AHOperation instances that augment the context. precondition: A callable that returns a boolean indicating whether AutoHeuristic should run. """ self.fallback = fallback self.choices = choices self.feedback = feedback self.context = context self.name = name self.collected_feedback = {} self.augment_context = augment_context self.metadata = AHMetadata( get_gpu_shared_memory(), torch.cuda.get_device_capability(), self.choices, self.name, ) self.precondition = precondition if not self.satisfies_precondition(): return if torch._inductor.config.autoheuristic_log_path == "DEFAULT": self.log_path = self.get_default_log_path() else: self.log_path = torch._inductor.config.autoheuristic_log_path if torch._inductor.config.collect_autoheuristic(self.name): if self.feedback is not None: for choice in self.choices: feedback_val = self.feedback(choice) self.save_data(choice, feedback_val) def satisfies_precondition(self) -> bool: return self.precondition is None or self.precondition( self.metadata, self.context ) def get_choice(self) -> Choice: """ Returns the chosen option based on the value of autoheuristic_use. If self.name is one of the comma separated strings in autoheuristic_use, it queries a learned heuristic to make a decision. Otherwise, it returns the fallback option. """ if not self.satisfies_precondition(): return self.fallback() if torch._inductor.config.use_autoheuristic(self.name): if self.augment_context is not None: self.context.apply_operations(self.augment_context) controller = LearnedHeuristicController( self.metadata, self.context, ) decision = controller.get_decision() if decision not in self.choices: # TODO(AlnisM): We might want to allow this in the future return self.fallback() if decision is not None: return decision return self.fallback() def get_top_k_choices( self, top_k: int, always_included: Optional[List[str]] = None ) -> Optional[List[Choice]]: if not self.satisfies_precondition(): return None if torch._inductor.config.use_autoheuristic(self.name): if self.augment_context is not None: self.context.apply_operations(self.augment_context) controller = LearnedHeuristicController( self.metadata, self.context, ) choices = controller.get_decisions_ranked(top_k) if choices is None: return None if always_included is not None: for choice in always_included: if choice not in choices: choices.append(choice) return choices return None def get_collected_feedback(self, choice: Choice) -> Any: return self.collected_feedback.get(choice, None) @staticmethod def get_device_identifier() -> str: # a heuristic might work well for one GPU, but not for another # we store the collected data per GPU model and learn a heuristic per GPU model # TODO(AlnisM): just using the device name for now, but the same GPU model can have different names device_name = torch.cuda.get_device_name().replace(" ", "_") return device_name def get_default_log_path(self) -> str: device_name = self.get_device_identifier() path = f"{cache_dir()}/autoheuristic/{device_name}/" os.makedirs(path, exist_ok=True) path += f"{self.name}.txt" return path def serialize_metadata(self) -> str: metadata_dict = self.metadata.to_dict() ( num_features, cat_features, ) = self.context.get_numerical_and_categorical_features() metadata_dict["numerical_features"] = num_features metadata_dict["categorical_features"] = cat_features return json.dumps(metadata_dict) def save_data(self, choice: Choice, feedback_val: Feedback) -> None: self.collected_feedback[choice] = feedback_val log_path = self.log_path lines = [] log_exists = os.path.exists(log_path) if log_exists: # if log already exists, make sure it is consistent metadata = self.serialize_metadata() existing_metadata = get_metadata_str_from_log(self.log_path) if existing_metadata != metadata: raise InconsistentMetadata( "Given metadata does not match existing metadata" ) else: lines.append(self.serialize_metadata()) feature_header = self.context.get_feature_names_csv() header = feature_header + "," + CHOICE_COL + "," + FEEDBACK_COL lines.append(header) line = "" feature_values = self.context.get_feature_values_csv() line += feature_values + "," + choice + "," + str(feedback_val) lines.append(line) with open(log_path, "a") as f: f.write("\n".join(lines) + "\n") class AutoHeuristicSelectAlgorithm(AutoHeuristic): """ AutoHeuristicSelectAlgorithm is a subclass of AutoHeuristic that allows one to collect data and learn a heuristic when one wants to use AutoHeuristic for kernel choice selection. """ def __init__( self, fallback: Callable[[], Optional[ChoiceCaller]], choices: List[ChoiceCaller], input_nodes: List[Any], context: AHContext, name: str, augment_context: Optional[List[AHOperation]] = None, precondition: Optional[Callable[[AHMetadata, AHContext], bool]] = None, ) -> None: """ The arguments choices, input_nodes and name have to match the ones used in the call to autotune_select_algorithm(), e.g. if the following call is made autotune_select_algorithm(name, choices, input_nodes, layout), the same name, choices and input_nodes have to be used here. """ self.input_nodes = input_nodes self.choicestr2choice: Dict[str, ChoiceCaller] = {} for choice in choices: self.choicestr2choice[choice.autoheuristic_id()] = choice choices_str = list(self.choicestr2choice.keys()) def fallback_str() -> str: fallback_choice = fallback() if fallback_choice is None: # TODO: Find a nicer way to handle this return "unsure" return fallback_choice.autoheuristic_id() super().__init__( fallback_str, choices_str, None, context, name, augment_context, precondition, ) if ( torch._inductor.config.collect_autoheuristic(self.name) and self.satisfies_precondition() ): self.register_global_feedback(input_nodes, choices) def register_global_feedback( self, input_nodes: List[Any], choices: List[ChoiceCaller] ) -> None: """ Registers a callback in select_algorithm, which is called with the timing of each choice. """ from torch._inductor.select_algorithm import ( add_feedback_saver, create_inputs_key, create_precompile_key, ) def store_global_feedback( ah_inputs_key: str, ah_precompile_key: str, timings: Dict[ChoiceCaller, float], name: str, input_nodes: List[Any], choices: List[ChoiceCaller], ) -> None: current_inputs_key = create_inputs_key(input_nodes) if current_inputs_key != ah_inputs_key: return current_precompile_key = create_precompile_key( name, current_inputs_key, choices ) if current_precompile_key != ah_precompile_key: return for choice, time in timings.items(): self.save_data(choice.autoheuristic_id(), time) inputs_key = create_inputs_key(input_nodes) precompile_key = create_precompile_key(self.name, inputs_key, choices) feedback_saver = partial(store_global_feedback, inputs_key, precompile_key) add_feedback_saver(feedback_saver) def get_choice_caller(self) -> Optional[ChoiceCaller]: choice = self.get_choice() return self.choicestr2choice.get(choice, None) def get_top_k_choices_caller( self, top_k: int, always_included: Optional[List[str]] = None ) -> Optional[List[ChoiceCaller]]: choices = self.get_top_k_choices(top_k, always_included) if choices is None: return None return [self.choicestr2choice[choice] for choice in choices]