require "llvm"
require "../syntax/parser"
require "../syntax/visitor"
require "../semantic"
require "../program"
require "./llvm_builder_helper"

module Crystal
  # Symbol names the code generator refers to directly. MAIN_NAME is the
  # program entry function created by CodeGenVisitor; the others name runtime
  # functions that CodegenWellKnownFunctions generates before everything else.
  MAIN_NAME           = "__crystal_main"
  RAISE_NAME          = "__crystal_raise"
  RAISE_OVERFLOW_NAME = "__crystal_raise_overflow"
  MALLOC_NAME         = "__crystal_malloc64"
  MALLOC_ATOMIC_NAME  = "__crystal_malloc_atomic64"
  REALLOC_NAME        = "__crystal_realloc64"
  GET_EXCEPTION_NAME  = "__crystal_get_exception"
  ONCE_INIT           = "__crystal_once_init"
  ONCE                = "__crystal_once"

  # Program-level entry points into code generation, plus layout queries
  # (sizes and offsets) answered through a dedicated LLVMTyper.
  class Program
    # Parses, normalizes and semantically analyzes *code*, then JIT-runs the
    # resulting node through `evaluate`.
    def run(code, filename = nil, debug = Debug::Default)
      parser = new_parser(code)
      parser.filename = filename
      node = parser.parse
      node = normalize node
      node = semantic node
      evaluate node, debug: debug
    end

    # Generates *node* into a single LLVM module and runs it with LLVM's JIT,
    # returning the result of `jit.run_function` on an argument-less wrapper
    # around the main function.
    def evaluate(node, debug = Debug::Default)
      visitor = CodeGenVisitor.new self, node, single_module: true, debug: debug
      visitor.accept node
      visitor.process_finished_hooks
      visitor.finish

      llvm_mod = visitor.modules[""].mod
      llvm_mod.target = target_machine.triple

      main = visitor.typed_fun?(llvm_mod, MAIN_NAME).not_nil!

      # It seems the JIT doesn't like it if we return an empty type (struct {})
      llvm_context = llvm_mod.context
      main_return_type = main.type.return_type
      main_return_type = llvm_context.void if node.type.nil_type?

      # The wrapper calls main with argc = 0 and a null argv, and returns
      # main's result unless the node's type is Void or Nil.
      wrapper_type = LLVM::Type.function([] of LLVM::Type, main_return_type)
      wrapper = llvm_mod.functions.add("__evaluate_wrapper", wrapper_type) do |func|
        func.basic_blocks.append "entry" do |builder|
          argc = llvm_context.int32.const_int(0)
          argv = llvm_context.void_pointer.pointer.null
          ret = builder.call(main.type, main.func, [argc, argv])
          (node.type.void? || node.type.nil_type?) ? builder.ret : builder.ret(ret)
        end
      end

      llvm_mod.verify

      # We use the block form so we dispose the JIT as soon as possible.
      # This isn't really needed, but on LLVM 3.9 it sometimes messes the
      # stack trace for unwind, or so it seems, or there's something else
      # that for now we can't understand.
      #
      # Since this only affects specs (not the code that is generated by
      # the compiler) we can postpone the understanding and continue
      # running specs just fine.
      #
      # See https://github.com/crystal-lang/crystal/pull/3439
      LLVM::JITCompiler.new(llvm_mod) do |jit|
        jit.run_function wrapper, [] of LLVM::GenericValue, llvm_context
      end
    end

    # Generates code for *node* and returns all generated modules keyed by
    # name ("" is the main module).
    def codegen(node, single_module = false, debug = Debug::Default)
      visitor = CodeGenVisitor.new self, node, single_module: single_module, debug: debug
      visitor.accept node
      visitor.process_finished_hooks
      visitor.finish

      visitor.modules
    end

    # Lazily created, memoized LLVMId for this program.
    def llvm_id
      @llvm_id ||= LLVMId.new(self)
    end

    # Lazily created, memoized typer bound to its own fresh LLVM::Context;
    # used by the size/offset queries below.
    def llvm_typer
      @llvm_typer ||= LLVMTyper.new(self, LLVM::Context.new)
    end

    # Size in bytes of a value of *type*, as reported by LLVM, except that
    # Void and Bool are reported as 1.
    def size_of(type)
      if type.void?
        # We need `sizeof(Void)` to be 1 because doing
        # `Pointer(Void).malloc` must work like `Pointer(UInt8).malloc`,
        # that is, consider Void like the size of a byte.
        1
      elsif type.is_a?(BoolType)
        # LLVM reports 0 for bool (i1) but it must be 1 because it does occupy memory
        1
      else
        llvm_typer.size_of(llvm_typer.llvm_type(type))
      end
    end

    # Size in bytes of the struct that backs an instance of *type*.
    def instance_size_of(type)
      llvm_typer.size_of(llvm_typer.llvm_struct_type(type))
    end

    # Byte offset of element *element_index* within a value of *type*.
    # Always 0 for extern unions and static arrays.
    def offset_of(type, element_index)
      return 0_u64 if type.extern_union? || type.is_a?(StaticArrayInstanceType)
      llvm_typer.offset_of(llvm_typer.llvm_type(type), element_index)
    end

    # Byte offset of instance variable *element_index* within the struct that
    # backs an instance of *type*. The `+ 1` skips the struct's first element
    # (presumably the type-id header — confirm in LLVMTyper).
    def instance_offset_of(type, element_index)
      # extern unions and static arrays must be value types, which always use
      # the above `offset_of` instead
      llvm_typer.offset_of(llvm_typer.llvm_struct_type(type), element_index + 1)
    end
  end

  class CodeGenVisitor < Visitor
    # Name of the global array holding the string of every program symbol
    # (declared by `define_symbol_table`).
    SYMBOL_TABLE_NAME = ":symbol_table"

    include LLVMBuilderHelper

    # Module and builder that code is currently emitted into (initially the
    # main module and main builder).
    getter llvm_mod : LLVM::Module
    getter builder : CrystalLLVMBuilder
    # The `__crystal_main` function.
    getter main : LLVM::Function
    # Every generated module by name; "" is the main module.
    getter modules : Hash(String, ModuleInfo)
    getter context : Context
    getter llvm_typer : LLVMTyper
    getter alloca_block : LLVM::BasicBlock
    getter entry_block : LLVM::BasicBlock
    # Name of the exception-handling personality function for the target.
    getter personality_name : String
    # Value produced by the most recently generated node.
    property last : LLVM::Value

    # A variable known to codegen: a pointer to its storage plus its Crystal type.
    class LLVMVar
      getter pointer : LLVM::Value
      getter type : Type

      # Normally a variable is associated with an alloca.
      # So for example, if you have a "x = Reference.new" you will have
      # an "Reference**" llvm value and you need to load that value
      # to access it.
      # However, the "self" argument is not copied to a local variable:
      # it's accessed from the arguments list, and is a "Reference*"
      # llvm value, so in a way it's "already loaded".
      # This field is true if that's the case.
      getter already_loaded : Bool
      # NOTE(review): presumably records whether debug info for this variable
      # was already emitted; in the visible code it is only set via the constructor.
      getter debug_variable_created : Bool

      def initialize(@pointer, @type, @already_loaded = false, @debug_variable_created = false)
      end
    end

    # Maps variable names to their LLVMVar.
    alias LLVMVars = Hash(String, LLVMVar)

    # An exception handler node paired with the Context it belongs to.
    record Handler, node : ExceptionHandler, context : Context
    # Key of the @strings cache: one string constant per (module, string).
    record StringKey, mod : LLVM::Module, string : String
    # A generated LLVM module together with its typer and builder.
    record ModuleInfo, mod : LLVM::Module, typer : LLVMTyper, builder : CrystalLLVMBuilder

    @abi : LLVM::ABI
    # Type returned by `__crystal_main` (the program's top-level type or Nil).
    @main_ret_type : Type
    @argc : LLVM::Value
    @argv : LLVM::Value
    @rescue_block : LLVM::BasicBlock?
    @catch_pad : LLVM::Value?
    # LLVM function types of generated functions, keyed by (module, name).
    @fun_types : Hash({LLVM::Module, String}, LLVM::Type)
    @sret_value : LLVM::Value?
    @cant_pass_closure_to_c_exception_call : Call?
    @main_llvm_context : LLVM::Context
    @main_llvm_typer : LLVMTyper
    @main_module_info : ModuleInfo
    @main_builder : CrystalLLVMBuilder
    @call_location : Location?

    # Runtime functions; NOTE(review): presumably looked up lazily on first use.
    @malloc_fun : LLVMTypedFunction?
    @malloc_atomic_fun : LLVMTypedFunction?
    @realloc_fun : LLVMTypedFunction?
    @raise_overflow_fun : LLVMTypedFunction?
    @c_malloc_fun : LLVMTypedFunction?
    @c_realloc_fun : LLVMTypedFunction?

    # Sets up the main LLVM module, the `__crystal_main(argc, argv)` function and
    # the main builder, then generates, in this order: the well-known runtime
    # functions, predefined constants, once-initialization, simple constants,
    # and allocas for the program's top-level variables.
    def initialize(@program : Program, @node : ASTNode, single_module = false, @debug = Debug::Default)
      @single_module = !!single_module
      @abi = @program.target_machine.abi
      @llvm_context = LLVM::Context.new
      # LLVM::Context.register(@llvm_context, "main")
      @llvm_mod = @llvm_context.new_module("main_module")
      @main_mod = @llvm_mod
      @main_llvm_context = @main_mod.context
      @llvm_typer = LLVMTyper.new(@program, @llvm_context)
      @main_llvm_typer = @llvm_typer
      # main returns the program's top-level type (Nil when the node is untyped)
      @main_ret_type = node.type? || @program.nil_type
      ret_type = @llvm_typer.llvm_return_type(@main_ret_type)
      # main(argc : Int32, argv : Void**)
      main_type = LLVM::Type.function([llvm_context.int32, llvm_context.void_pointer.pointer], ret_type)
      @main = @llvm_mod.functions.add(MAIN_NAME, main_type)
      @fun_types = { {@llvm_mod, MAIN_NAME} => main_type }

      # Windows uses `__CxxFrameHandler3` as personality (also attached to main);
      # other targets use `__crystal_personality`.
      if @program.has_flag? "windows"
        @personality_name = "__CxxFrameHandler3"
        @main.personality_function = windows_personality_fun.func
      else
        @personality_name = "__crystal_personality"
      end

      @context = Context.new @main, main_type, @program
      @context.return_type = @main_ret_type

      @argc = @main.params[0]
      @argc.name = "argc"

      @argv = @main.params[1]
      @argv.name = "argv"

      @builder = new_builder(@main_llvm_context)
      @main_builder = @builder

      # The main module is registered under the empty name ""
      @main_module_info = ModuleInfo.new(@main_mod, @main_llvm_typer, @builder)
      @modules = {"" => @main_module_info} of String => ModuleInfo
      @types_to_modules = {} of Type => ModuleInfo

      set_internal_fun_debug_location(@main, MAIN_NAME, nil)

      @alloca_block, @entry_block = new_entry_block_chain "alloca", "entry"

      @in_lib = false
      @strings = {} of StringKey => LLVM::Value
      # Each program symbol is numbered by its position in program.symbols,
      # and its name is kept as a string constant for the symbol table global.
      @symbols = {} of String => Int32
      @symbols_by_index = [] of String
      @symbol_table_values = [] of LLVM::Value
      program.symbols.each_with_index do |sym, index|
        @symbols[sym] = index
        @symbols_by_index << sym
        @symbol_table_values << build_string_constant(sym, sym)
      end

      unless program.symbols.empty?
        symbol_table = define_symbol_table @llvm_mod, @llvm_typer
        symbol_table.initializer = llvm_type(@program.string).const_array(@symbol_table_values)
      end

      @last = llvm_nil
      @fun_literal_count = 0

      # This flag is to generate less code. If there's an if in the middle
      # of a series of expressions we don't need the result, so there's no
      # need to build a phi for it.
      # Also, we don't need the value of unions returned from calls if they
      # are not going to be used.
      @needs_value = true

      # Funs seen but not yet called; generated again in `finish`
      @unused_fun_defs = [] of FunDef
      # Per-name counters used to make proc literal function names unique
      @proc_counts = Hash(String, Int32).new(0)

      @llvm_mod.data_layout = self.data_layout

      # We need to define __crystal_malloc and __crystal_realloc as soon as possible,
      # to avoid some memory being allocated with plain malloc.
      codgen_well_known_functions @node

      initialize_predefined_constants

      if @debug.line_numbers?
        set_current_debug_location Location.new(@program.filename || "(no name)", 1, 1)
      end

      once_init
      initialize_simple_constants

      alloca_vars @program.vars, @program

      emit_vars_debug_info(@program.vars) if @debug.variables?
    end

    # The LLVM context owning the main module.
    def llvm_context
      @llvm_context
    end

    # Returns a brand-new LLVM builder for *llvm_context*, wrapped as a
    # CrystalLLVMBuilder via `wrap_builder`.
    def new_builder(llvm_context)
      raw_builder = llvm_context.new_builder
      wrap_builder(raw_builder)
    end

    # Initializes only the "simple" constants, i.e. those whose values are
    # plain literals such as 1 or "foo". Constants that have a compile-time
    # value are skipped because they are inlined at every use site.
    def initialize_simple_constants
      @program.const_initializers.each do |const_initializer|
        next if const_initializer.compile_time_value || !const_initializer.simple?

        initialize_simple_const(const_initializer)
      end
    end

    # Wraps a raw LLVM builder in a CrystalLLVMBuilder configured with this
    # visitor's typer and the `c_printf_fun` helper.
    def wrap_builder(builder)
      typer = llvm_typer
      CrystalLLVMBuilder.new(builder, typer, c_printf_fun)
    end

    # Declares in *llvm_mod* the global named SYMBOL_TABLE_NAME: an array of
    # String with one slot per program symbol. The caller sets its initializer.
    def define_symbol_table(llvm_mod, llvm_typer)
      string_type = llvm_typer.llvm_type(@program.string)
      table_type = string_type.array(@symbol_table_values.size)
      llvm_mod.globals.add table_type, SYMBOL_TABLE_NAME
    end

    # Data layout of the program's target machine.
    def data_layout
      target_machine = @program.target_machine
      target_machine.data_layout
    end

    # Pre-pass visitor that only descends through file nodes and expression
    # lists, generating the runtime functions (allocation, raise, personality,
    # once, ...) that must exist before the rest of the program. All other
    # nodes are skipped.
    class CodegenWellKnownFunctions < Visitor
      @codegen : CodeGenVisitor

      def initialize(@codegen)
      end

      def visit(node : FileNode)
        true
      end

      def visit(node : Expressions)
        true
      end

      def visit(node : FunDef)
        @codegen.accept(node) if well_known_fun_name?(node.name)
        false
      end

      def visit(node : ASTNode)
        false
      end

      # Whether *name* belongs to a function that must be generated early.
      private def well_known_fun_name?(name)
        case name
        when MALLOC_NAME, MALLOC_ATOMIC_NAME, REALLOC_NAME, RAISE_NAME,
             @codegen.personality_name, GET_EXCEPTION_NAME, RAISE_OVERFLOW_NAME,
             ONCE_INIT, ONCE
          true
        else
          false
        end
      end
    end

    # Runs the CodegenWellKnownFunctions pre-pass over *node*.
    def codgen_well_known_functions(node)
      node.accept CodegenWellKnownFunctions.new(self)
    end

    # Nodes are only visited while `@builder.end` is falsy (presumably: the
    # current basic block has not been terminated yet).
    def visit_any(node)
      return false if @builder.end
      true
    end

    # Type of the current context; raises if the context has none.
    def type
      current_type = context.type
      current_type.not_nil!
    end

    # Completes code generation: emits main's return, links or drops the
    # alloca block, generates the funs deferred by `visit(node : FunDef)`,
    # optionally dumps modules, and verifies every module.
    def finish
      codegen_return @main_ret_type

      # If there are no instructions in the alloca block we just remove it
      # (less noise); otherwise it is chained to the entry block.
      if alloca_block.instructions.empty?
        alloca_block.delete
      else
        br_block_chain alloca_block, entry_block
      end

      @unused_fun_defs.each do |node|
        codegen_fun node.real_name, node.external, @program, is_exported_fun: true
      end

      # DUMP=1 dumps every module; any other value is a regex matched
      # against module names.
      env_dump = ENV["DUMP"]?
      case env_dump
      when Nil
        # Nothing
      when "1"
        dump_all_llvm = true
      else
        dump_llvm_regex = Regex.new(env_dump)
      end

      @modules.each do |name, info|
        mod = info.mod
        push_debug_info_metadata(mod) unless @debug.none?

        mod.dump if dump_all_llvm || name =~ dump_llvm_regex

        # Always run verifications so we can catch bugs earlier and more often.
        # We can probably remove this, or only enable this when compiling in
        # release mode, once we reach 1.0.
        mod.verify
      end

      dump_type_id if ENV["CRYSTAL_DUMP_TYPE_ID"]? == "1"
    end

    # Prints every type with its {min - max} type id range (enabled by
    # CRYSTAL_DUMP_TYPE_ID=1). Ranges are sorted by min, then by descending max,
    # and each is indented under the previously printed ranges that contain it.
    private def dump_type_id
      ids = @program.llvm_id.@ids.to_a
      ids.sort_by! { |_, (min, max)| {min, -max} }

      puts "CRYSTAL_DUMP_TYPE_ID"
      # Stack of ranges enclosing the current one
      parent_ids = [] of {Int32, Int32}
      ids.each do |type, (min, max)|
        # Pop ranges that do not contain {min, max}
        while parent_id = parent_ids.last?
          break if min >= parent_id[0] && max <= parent_id[1]
          parent_ids.pop
        end
        indent = " " * (2 * parent_ids.size)

        show_generic_args = type.is_a?(GenericInstanceType) ||
                            type.is_a?(GenericClassInstanceMetaclassType) ||
                            type.is_a?(GenericModuleInstanceMetaclassType)
        puts "#{indent}{#{min} - #{max}}: #{type.to_s(generic_args: show_generic_args)}"
        parent_ids << {min, max}
      end
      puts
    end

    # Annotations produce no code.
    def visit(node : Annotation)
      false
    end

    # Generates a top-level `fun` definition exactly once. Funs visited while
    # inside a `lib` are skipped.
    def visit(node : FunDef)
      return false if @in_lib

      external = node.external
      # Already generated (e.g. well-known functions like __crystal_raise are
      # generated early and then seen again in the normal pass).
      return false if external.dead?
      external.dead = true

      if external.used?
        codegen_fun node.real_name, external, @program, is_exported_fun: true
      else
        # Unused funs are generated again at the end (see `finish`) so that
        # constants they use are declared first; LLVM still needs a definition
        # now so that calls can find them.
        codegen_fun node.real_name, external, @program, is_exported_fun: false
        @unused_fun_defs << node
      end

      false
    end

    # Generates a file's top-level code in a fresh context. When the file
    # module has variables, allocas (and optional debug info) are emitted first.
    def visit(node : FileNode)
      file_context = Context.new(context.fun, context.fun_type, context.type)
      with_context(file_context) do
        filename = node.filename
        file_module = @program.file_module(filename)
        if file_vars = file_module.vars?
          if @debug.line_numbers?
            set_current_debug_location Location.new(filename, 1, 1)
          end
          alloca_vars file_vars, file_module

          emit_vars_debug_info(file_vars) if @debug.variables?
        end
        accept node.node
        @last = llvm_nil
      end

      false
    end

    # A no-op evaluates to nil.
    def visit(node : Nop)
      @last = llvm_nil
    end

    # `nil` evaluates to the nil value.
    def visit(node : NilLiteral)
      @last = llvm_nil
    end

    # `true`/`false` become the `int1` constants 1/0.
    def visit(node : BoolLiteral)
      bit = node.value ? 1 : 0
      @last = int1(bit)
    end

    # A Char is generated as its codepoint via `int32`.
    def visit(node : CharLiteral)
      codepoint = node.value.ord
      @last = int32(codepoint)
    end

    # Generates a constant of the literal's width. Unsigned kinds use the
    # same-width integer helper with the unsigned-converted value.
    def visit(node : NumberLiteral)
      value = node.value
      @last = case node.kind
              in .i8?
                int8(value.to_i8)
              in .u8?
                int8(value.to_u8)
              in .i16?
                int16(value.to_i16)
              in .u16?
                int16(value.to_u16)
              in .i32?
                int32(value.to_i32)
              in .u32?
                int32(value.to_u32)
              in .i64?
                int64(value.to_i64)
              in .u64?
                int64(value.to_u64)
              in .i128?
                int128(value.to_i128)
              in .u128?
                int128(value.to_u128)
              in .f32?
                float32(value)
              in .f64?
                float64(value)
              end
    end

    # A string literal becomes a string constant built from its value.
    def visit(node : StringLiteral)
      text = node.value
      @last = build_string_constant(text, text)
    end

    # A symbol is generated as its index into the program's symbol table.
    def visit(node : SymbolLiteral)
      symbol_index = @symbols[node.value]
      @last = int(symbol_index)
    end

    # Allocates a tuple. With splats (`*t`) each splatted tuple contributes all
    # of its members, so every element is generated up front; without splats
    # each element is generated on demand while filling its slot.
    def visit(node : TupleLiteral)
      request_value do
        tuple_type = node.type.as(TupleInstanceType)
        elements = node.elements

        if elements.any?(Splat)
          slot_count = elements.sum do |element|
            element.is_a?(Splat) ? element.type.as(TupleInstanceType).tuple_types.size : 1
          end
          slots = Array({Type, LLVM::Value}).new(slot_count)

          elements.each do |element|
            accept element

            if element.is_a?(Splat)
              splat_type = element.type.as(TupleInstanceType)
              splat_type.tuple_types.each_with_index do |member_type, member_index|
                slots << {member_type, codegen_tuple_indexer(splat_type, @last, member_index)}
              end
            else
              slots << {element.type, @last}
            end
          end

          @last = allocate_tuple(tuple_type) { |_, slot_index| slots[slot_index] }
        else
          @last = allocate_tuple(tuple_type) do |_, slot_index|
            element = elements[slot_index]
            accept element
            {element.type, @last}
          end
        end
      end
      false
    end

    # Allocates a named tuple and stores each entry's value into the slot
    # that its key has in the named tuple type.
    def visit(node : NamedTupleLiteral)
      request_value do
        named_tuple_type = node.type.as(NamedTupleInstanceType)
        struct_type = llvm_type(named_tuple_type)
        tuple_ptr = alloca struct_type
        node.entries.each do |entry|
          value_node = entry.value
          accept value_node
          slot = named_tuple_type.name_index(entry.key).not_nil!
          slot_type = named_tuple_type.entries[slot].type
          assign aggregate_index(struct_type, tuple_ptr, slot), slot_type, value_node.type, @last
        end
        @last = tuple_ptr
      end
      false
    end

    # `pointerof(...)`: leaves in @last the address of a local variable,
    # instance variable, class variable, constant, or lib external var.
    def visit(node : PointerOf)
      @last = case node_exp = node.exp
              when Var
                context.vars[node_exp.name].pointer
              when InstanceVar
                instance_var_ptr context.type.remove_typedef, node_exp.name, llvm_self_ptr
              when ClassVar
                # Make sure the class var is initialized before taking a pointer of it
                if node_exp.var.initializer
                  initialize_class_var(node_exp)
                end
                get_global class_var_global_name(node_exp.var), node_exp.type, node_exp.var
              when Global
                node.raise "BUG: there should be no use of global variables other than $~ and $?"
              when Path
                # Make sure the constant is initialized before taking a pointer of it
                const = node_exp.target_const.not_nil!
                read_const_pointer(const)
              when ReadInstanceVar
                accept node_exp.obj
                instance_var_ptr (node_exp.obj.type), node_exp.name, @last
              when Call
                # lib external var
                extern = node_exp.dependencies.first.as(External)
                var = get_external_var(extern)
                check_c_fun extern.type, var
              else
                raise "BUG: #{node}"
              end
      false
    end

    # Generates the function for a proc literal, then builds the proc value
    # from its function pointer and a context pointer: the closure data for
    # closures, a null pointer otherwise.
    def visit(node : ProcLiteral)
      proc_name = fun_literal_name(node)
      closure = node.def.closure?

      # The def's type can't be adjusted during type inference because of
      # bindings and type propagation, so it is fixed here: Nil when the proc's
      # return value isn't needed, otherwise the proc literal's return type,
      # which may be broader than the body's type
      # (for example, return type: Int32 | String, body: String).
      if node.force_nil?
        node.def.set_type @program.nil
      else
        node.def.set_type node.return_type
      end

      proc_fun = codegen_fun proc_name, node.def, context.type, fun_module_info: @main_module_info, is_fun_literal: true, is_closure: closure
      proc_fun = check_main_fun proc_name, proc_fun

      set_current_debug_location(node) if @debug.line_numbers?
      fun_ptr = cast_to_void_pointer(proc_fun.func)
      ctx_ptr = if closure
                  cast_to_void_pointer(context.closure_ptr.not_nil!)
                else
                  llvm_context.void_pointer.null
                end
      @last = make_fun node.type, fun_ptr, ctx_ptr

      false
    end

    # Returns a unique LLVM function name for a proc literal. Typed literals
    # with a location are named after `~proc<type>@<file>:<line>` (passed
    # through `Crystal.safe_mangling`); all others share `~fun_literal`.
    # From the second occurrence of a name on, a counter is inserted after the
    # first 5 characters (the `~proc` prefix) or appended to `~fun_literal`.
    def fun_literal_name(node : ProcLiteral)
      location = node.location.try &.expanded_location
      if location && (proc_type = node.type?)
        located = true
        filename = location.filename.as(String)
        base_name = Crystal.safe_mangling(@program, "~proc#{proc_type}@#{Crystal.relative_filename(filename)}:#{location.line_number}")
      else
        located = false
        base_name = "~fun_literal"
      end

      occurrence = @proc_counts[base_name] + 1
      @proc_counts[base_name] = occurrence
      return base_name unless occurrence > 1

      if located
        "#{base_name[0...5]}#{occurrence}#{base_name[5..]}"
      else
        "#{base_name}#{occurrence}"
      end
    end

    # `->obj.method`: builds a proc pointing at the call's target def. The
    # receiver (explicit, or `self` when the owner is passed as self) becomes
    # the context pointer, except for metaclass and lib owners, which get null.
    def visit(node : ProcPointer)
      target_def = node.call.target_def
      owner = target_def.owner

      receiver = if obj = node.obj
                   accept obj
                   @last
                 elsif owner.passed_as_self?
                   llvm_self
                 end

      target_fun = target_def_fun(target_def, owner)

      set_current_debug_location(node) if @debug.line_numbers?
      fun_ptr = cast_to_void_pointer(target_fun.func)
      ctx_ptr = if receiver && !owner.metaclass? && !owner.is_a?(LibType)
                  cast_to_void_pointer(receiver)
                else
                  llvm_context.void_pointer.null
                end
      @last = make_fun node.type, fun_ptr, ctx_ptr

      false
    end

    # Generates each expression in order. Values are not requested for them,
    # except for the last expression when the enclosing node needed a value.
    def visit(node : Expressions)
      outer_needs_value = @needs_value
      @needs_value = false

      expressions = node.expressions
      final_index = expressions.size - 1
      expressions.each_with_index do |expression, index|
        @needs_value = true if outer_needs_value && index == final_index
        accept expression
      end

      @needs_value = outer_needs_value
      false
    end

    # `return`: generates the optional value, then returns it from the def.
    def visit(node : Return)
      codegen_return_node(node, accept_control_expression(node))
      false
    end

    # Runs the pending `ensure` blocks up to the target def (preserving @last),
    # then adds the value to the context's return phi if there is one, or
    # emits a return otherwise.
    def codegen_return_node(node, node_type)
      value = @last
      execute_ensures_until(node.target.as(Def))
      @last = value

      if phi = context.return_phi
        phi.add @last, node_type
      else
        codegen_return node_type
      end
    end

    # A NoReturn (or missing) type means control never reaches the return.
    def codegen_return(type : NoReturnType | Nil)
      unreachable
    end

    # Emits the return of the current def. Void/Nil defs get a bare `ret`,
    # NoReturn defs an `unreachable`; otherwise @last is upcast from *type* to
    # the def's return type. Nothing is emitted if the builder already ended.
    def codegen_return(type : Type)
      return if @builder.end

      return_type = context.return_type.not_nil!
      if return_type.void? || return_type.nil_type?
        ret
      elsif return_type.no_return?
        unreachable
      else
        upcasted = upcast(@last, return_type, type)
        ret to_rhs(upcasted, return_type)
      end
    end

    # A class body is generated in place, after its hook expansions; the
    # definition itself evaluates to nil.
    def visit(node : ClassDef)
      if hooks = node.hook_expansions
        hooks.each { |hook| accept hook }
      end
      accept node.body
      @last = llvm_nil
      false
    end

    # A module body is generated in place; the definition evaluates to nil.
    def visit(node : ModuleDef)
      body = node.body
      accept body
      @last = llvm_nil
      false
    end

    # While inside a lib body, `visit(node : FunDef)` skips fun definitions.
    def visit(node : LibDef)
      @in_lib = true
      body = node.body
      accept body
      @in_lib = false
      @last = llvm_nil
      false
    end

    # C struct/union declarations produce no code.
    def visit(node : CStructOrUnionDef)
      @last = llvm_nil
      false
    end

    # Only enum members written as assignments are generated.
    def visit(node : EnumDef)
      node.members.each do |member|
        accept member if member.is_a?(Assign)
      end

      @last = llvm_nil
      false
    end

    # External var declarations produce no code.
    def visit(node : ExternalVar)
      @last = llvm_nil
      false
    end

    # Typedef declarations produce no code.
    def visit(node : TypeDef)
      @last = llvm_nil
      false
    end

    # Alias declarations produce no code.
    def visit(node : Alias)
      @last = llvm_nil
      false
    end

    # `typeof(...)` evaluates to a type id. Only non-virtual type ids are
    # needed, so virtual (meta)types are devirtualized first.
    def visit(node : TypeOf)
      concrete_type = node.type.devirtualize
      @last = type_id(concrete_type)
      false
    end

    # `sizeof(T)`: the LLVM size truncated to an Int32.
    def visit(node : SizeOf)
      size = llvm_size(node.exp.type.sizeof_type)
      @last = trunc(size, llvm_context.int32)
      false
    end

    # `instance_sizeof(T)`: the instance struct size truncated to an Int32.
    def visit(node : InstanceSizeOf)
      size = llvm_struct_size(node.exp.type.sizeof_type)
      @last = trunc(size, llvm_context.int32)
      false
    end

    # `include` only generates its hook expansions; it evaluates to nil.
    def visit(node : Include)
      if hooks = node.hook_expansions
        hooks.each { |hook| accept hook }
      end

      @last = llvm_nil
      false
    end

    # `extend` only generates its hook expansions; it evaluates to nil.
    def visit(node : Extend)
      if hooks = node.hook_expansions
        hooks.each { |hook| accept hook }
      end

      @last = llvm_nil
      false
    end

    # Generates an `if`. When the node is flagged `truthy?` (or `falsey?`),
    # only the condition and the `then` (or `else`) branch are generated, and
    # the value is upcast to the `if`'s type when needed. Otherwise a
    # conditional branch plus a phi joining both branches is built.
    def visit(node : If)
      if node.truthy?
        accept node.cond
        accept node.then
        if @needs_value && (node_type = node.type?) && (then_type = node.then.type?)
          @last = upcast(@last, node_type, then_type)
        end
        return false
      end

      if node.falsey?
        accept node.cond
        accept node.else
        if @needs_value && (node_type = node.type?) && (else_type = node.else.type?)
          @last = upcast(@last, node_type, else_type)
        end
        return false
      end

      then_block, else_block = new_blocks "then", "else"

      request_value do
        set_current_debug_location(node) if @debug.line_numbers?
        codegen_cond_branch node.cond, then_block, else_block
      end

      Phi.open(self, node, @needs_value) do |phi|
        codegen_if_branch phi, node.then, then_block, false
        codegen_if_branch phi, node.else, else_block, true
      end

      false
    end

    # Generates one branch of an `if` in *branch_block* and adds its value to
    # *phi*; *last* tells the phi whether this is the final branch.
    def codegen_if_branch(phi, node, branch_block, last)
      position_at_end branch_block
      accept node
      phi.add @last, node.type?, last
    end

    # Generates a `while` loop. `break` values flow into a phi. A `while true`
    # loop is a single self-looping block. Otherwise the "while" block tests
    # the condition and jumps to "body" or "fail", and falling out of the loop
    # yields nil.
    def visit(node : While)
      set_ensure_exception_handler(node)

      with_cloned_context do
        cond = node.cond.single_expression
        endless_while = cond.true_literal?

        if endless_while
          while_block = new_block "while"

          Phi.open(self, node, @needs_value) do |phi|
            context.while_block = while_block
            context.break_phi = phi
            context.next_phi = nil

            br while_block

            position_at_end while_block

            discard_value(node.body)
            br while_block
          end
        else
          while_block, body_block, fail_block = new_blocks "while", "body", "fail"

          Phi.open(self, node, @needs_value) do |phi|
            context.while_block = while_block
            context.break_phi = phi
            context.next_phi = nil

            br while_block

            position_at_end while_block

            request_value do
              set_current_debug_location node.cond if @debug.line_numbers?
              codegen_cond_branch node.cond, body_block, fail_block
            end

            position_at_end body_block

            discard_value(node.body)
            br while_block

            position_at_end fail_block

            # A loop that ends because its condition failed evaluates to nil
            phi.add llvm_nil, @program.nil, last: true
          end
        end
      end

      false
    end

    # Branches to *then_block* if *node_cond* is truthy, else to *else_block*.
    def codegen_cond_branch(node_cond, then_block, else_block)
      condition = codegen_cond(node_cond)
      cond condition, then_block, else_block

      nil
    end

    # Generates *node*, then computes its truthiness from its (indirection-free) type.
    def codegen_cond(node : ASTNode)
      accept node
      cond_type = node.type.remove_indirection
      codegen_cond cond_type
    end

    # `!exp`: the negation of the expression's truthiness.
    def visit(node : Not)
      request_value(node.exp)
      truthiness = codegen_cond node.exp.type.remove_indirection
      @last = not truthiness
      false
    end

    # `break`: generates the optional value, runs pending `ensure` blocks up
    # to the target (a block's call or a `while`), and adds the value to the
    # context's break phi. Any other target/context combination is a compiler bug.
    def visit(node : Break)
      set_current_debug_location(node) if @debug.line_numbers?
      node_type = accept_control_expression(node)

      # Both target kinds are handled identically.
      case target = node.target
      when Call, While
        if break_phi = context.break_phi
          # `ensure` code may clobber @last; keep the break value
          value = @last
          execute_ensures_until(target)
          @last = value

          break_phi.add @last, node_type
          return false
        end
      end

      node.raise "BUG: unknown exit for break"
    end

    # `next`: for a block target the value goes to the context's next phi;
    # for a `while` target control jumps to the loop's "while" block; any other
    # target can only be a captured block, where `next` acts like `return`.
    def visit(node : Next)
      set_current_debug_location(node) if @debug.line_numbers?
      node_type = accept_control_expression(node)

      target = node.target
      if target.is_a?(Block)
        if next_phi = context.next_phi
          value = @last
          execute_ensures_until(target)
          @last = value

          next_phi.add @last, node_type
          return false
        end
      elsif target.is_a?(While)
        if while_block = context.while_block
          execute_ensures_until(target)
          br while_block
          return false
        end
      else
        # The only possibility is a captured block: same as a return
        codegen_return_node(node, node_type)
        return false
      end

      node.raise "BUG: unknown exit for next"
    end

    # Generates the optional value of a control expression (return, break,
    # next) into @last and returns its type. With no value, @last is nil and
    # the type is Nil.
    def accept_control_expression(node)
      exp = node.exp
      unless exp
        @last = llvm_nil
        return @program.nil
      end

      request_value(exp)
      exp.type? || @program.nil
    end

    # Assignments marked as discarded generate nothing; others dispatch on
    # the target kind.
    def visit(node : Assign)
      return false if node.discarded?

      codegen_assign(node.target, node.value, node)
    end

    # `_ = value`: only the value is generated.
    def codegen_assign(target : Underscore, value, node)
      accept value
      false
    end

    # Constant assignment: only used, non-simple constants without a
    # compile-time value are initialized here.
    def codegen_assign(target : Path, value, node)
      const = target.target_const.not_nil!
      needs_initialization = const.used? && !const.simple? && !const.compile_time_value
      initialize_const(const) if needs_initialization
      @last = llvm_nil
      false
    end

    # Assignment to an instance variable, class variable or local variable.
    # Generates the value, computes the target's storage pointer and stores
    # the value into it. Every exit returns false, like the other
    # `codegen_assign` overloads (the caller's visitor must not descend into
    # children).
    def codegen_assign(target, value, node)
      target_type = target.type?

      # This means it's an instance variable initializer of a generic type,
      # or a class variable initializer
      unless target_type
        if target.is_a?(ClassVar)
          # This is the case of a class var initializer
          initialize_class_var(target)
        end
        return false
      end

      # This is the case of an instance variable initializer
      if target.is_a?(InstanceVar) && !context.type.is_a?(InstanceVarContainer)
        return false
      end

      request_value(value)

      # A value that never returns leaves nothing to store
      return false if value.no_returns?

      last = @last

      set_current_debug_location node if @debug.line_numbers?
      ptr = case target
            when InstanceVar
              instance_var_ptr context.type, target.name, llvm_self_ptr
            when Global
              node.raise "BUG: there should be no use of global variables other than $~ and $?"
            when ClassVar
              read_class_var_ptr(target)
            when Var
              # Can't assign void
              return false if target.type.void?

              # If assigning to a special variable in a method that yields,
              # assign to that variable too.
              check_assign_to_special_var_in_block(target, value)

              var = context.vars[target.name]?
              if var
                target_type = var.type
                var.pointer
              else
                target.raise "BUG: missing var #{target}"
              end
            else
              node.raise "Unknown assign target in codegen: #{target}"
            end

      # Computing the pointer may have changed @last; restore the value
      @last = last
      llvm_value = last

      # When setting an instance variable of an extern type, if it's a Proc
      # type we need to check that the value is not a closure and just get
      # the function pointer
      if target.is_a?(InstanceVar) && context.type.extern? && target.type.proc?
        llvm_value = check_proc_is_not_closure(llvm_value, target.type)
      end

      assign ptr, target_type, value.type, llvm_value

      false
    end

    # When a special variable is assigned while a block context exists (a
    # method that yields), also store @last into the block context's variable
    # of the same name.
    def check_assign_to_special_var_in_block(target, value)
      block_context = context.block_context?
      if block_context && target.special_var?
        block_var = block_context.vars[target.name]
        assign block_var.pointer, block_var.type, value.type, @last
      end
    end

    # Pointer to the global *name*: thread-local globals go through
    # `get_thread_local` (which ignores *initial_value*), the rest through
    # `get_global_var`.
    def get_global(name, type, real_var, initial_value = nil)
      return get_thread_local(name, type, real_var) if real_var.thread_local?

      get_global_var(name, type, real_var, initial_value)
    end

    # Returns the LLVM global named `name` in the current module, declaring it
    # on first use.
    #
    # In the main module the global is defined with `initial_value` (or a null
    # initializer). In any other module it is declared external, and a matching
    # definition is added to the main module if one doesn't exist yet.
    def get_global_var(name, type, real_var, initial_value = nil)
      if existing = @llvm_mod.globals[name]?
        return existing
      end

      global_type = llvm_type(type)
      thread_local = real_var.thread_local?

      global_ptr = @llvm_mod.globals.add(global_type, name)
      global_ptr.thread_local = true if thread_local

      if @llvm_mod == @main_mod
        global_ptr.initializer = initial_value || global_type.null
        return global_ptr
      end

      # Only a declaration here: the definition lives in the main module
      global_ptr.linkage = LLVM::Linkage::External

      unless @main_mod.globals[name]?
        main_global_type = @main_llvm_typer.llvm_type(type)
        main_global_ptr = @main_mod.globals.add(main_global_type, name)
        main_global_ptr.initializer = initial_value || main_global_type.null
        main_global_ptr.thread_local = true if thread_local
      end

      global_ptr
    end

    # Returns a pointer to the thread-local global `name`, obtained by calling
    # a NoInline accessor function that is defined (once) in the main module.
    def get_thread_local(name, type, real_var)
      # If it's thread local, we use a NoInline function to access it
      # because of http://lists.llvm.org/pipermail/llvm-dev/2016-February/094736.html
      #
      # So, we basically make a function like this (assuming the global is a i32):
      #
      # define void @"*$foo"(i32**) noinline {
      #   store i32* @"$foo", i32** %0
      #   ret void
      # }
      #
      # And then in the caller we alloca an i32*, pass it, and then load the pointer,
      # which is the same as the global, but through a non-inlined function.
      #
      # Making a function that just returns the pointer doesn't work: LLVM inlines it.
      fun_name = "*#{name}"
      # Reuse the accessor if an earlier access already defined it
      thread_local_fun = typed_fun?(@main_mod, fun_name)
      unless thread_local_fun
        thread_local_fun = in_main do
          define_main_function(fun_name, [llvm_type(type).pointer.pointer], llvm_context.void) do |func|
            set_internal_fun_debug_location(func, fun_name, real_var.location)
            builder.store get_global_var(name, type, real_var), func.params[0]
            builder.ret
          end
        end
        thread_local_fun.func.add_attribute LLVM::Attribute::NoInline
      end
      # NOTE(review): presumably makes the main-module function usable from the
      # current module — confirm in `check_main_fun`
      thread_local_fun = check_main_fun(fun_name, thread_local_fun)
      pointer_type = llvm_type(type).pointer
      # Out-parameter slot the accessor stores the global's address into
      indirection_ptr = alloca pointer_type
      call thread_local_fun, indirection_ptr
      load pointer_type, indirection_ptr
    end

    # Generates code for `var : Type [= value]`: declares local variables
    # (assigning the optional value) and runs class variable initializers.
    def visit(node : TypeDeclaration)
      return false if node.discarded?

      var = node.var
      if var.is_a?(Var)
        declare_var var

        if value = node.value
          codegen_assign(var, value, node)
          return false
        end
      elsif var.is_a?(Global)
        node.raise "BUG: there should be no use of global variables other than $~ and $?"
      elsif var.is_a?(ClassVar)
        # This is the case of a class var initializer
        initialize_class_var(var)
      end

      @last = llvm_nil
      false
    end

    # Generates code for `x = uninitialized T`: declares the local variable and,
    # when a non-nil value is needed, yields its (uninitialized) storage.
    def visit(node : UninitializedVar)
      var = node.var
      llvm_var = var.is_a?(Var) ? declare_var(var) : nil

      @last =
        if llvm_var && !node.type.nil_type? && @needs_value
          to_lhs(llvm_var.pointer, node.type)
        else
          llvm_nil
        end

      false
    end

    # Loads the value of a local variable (or `self`) into `@last`.
    def visit(node : Var)
      # It can happen that a variable ends up with no type, as in:
      #
      #     i = 0
      #     i.is_a?(Int32) ? 1 : i # here
      #
      # In that case we treat it as NoReturn.
      return unreachable unless node.type?

      if var = context.vars[node.name]?
        return unreachable if var.type.no_return?

        # Special variables always have an extra pointer, so they are never
        # considered already loaded
        loaded = node.special_var? ? false : var.already_loaded
        return @last = downcast(var.pointer, node.type, var.type, loaded, extern: false)
      end

      unless node.name == "self"
        node.raise "BUG: missing context var: #{node.name}"
      end

      @last = node.type.metaclass? ? type_id(node.type) : downcast(llvm_self_ptr, node.type, context.type, true)
    end

    # Globals other than `$~` and `$?` never reach codegen.
    def visit(node : Global)
      node.raise("BUG: there should be no use of global variables other than $~ and $?")
    end

    # Reads a class variable's value into `@last`.
    def visit(node : ClassVar)
      @last = read_class_var(node)
    end

    # Reads an instance variable of `self` into `@last`.
    def visit(node : InstanceVar)
      read_instance_var(node.type, context.type, node.name, llvm_self_ptr)
    end

    # Reads the instance variable `node.name` of the object already evaluated
    # into `@last` (the `obj.@ivar` syntax).
    #
    # When the object is a union, each member may lay the ivar out
    # differently. We therefore branch on the union's runtime type id, read the
    # ivar through the matching member's layout, and merge the results with a
    # Phi.
    def end_visit(node : ReadInstanceVar)
      obj_type = node.obj.type
      if obj_type.is_a?(UnionType)
        union_ptr = @last
        union_type_id = type_id(union_ptr, obj_type)

        Phi.open(self, node, @needs_value) do |phi|
          obj_type.union_types.each do |union_type|
            # The type id was extracted from `obj_type`, so it is matched
            # against `obj_type`, as in `visit(Cast)` and `codegen_type_filter`.
            # `node.type` is the ivar's type, which is unrelated to the id.
            # NOTE(review): with `node.type` the match may not discriminate
            # between members when the ivar type is not itself a union.
            id_matches = match_type_id(obj_type, union_type, union_type_id)

            current_match_label, next_match_label = new_blocks "current_match", "next_match"
            cond id_matches, current_match_label, next_match_label
            position_at_end current_match_label

            value_ptr = downcast union_ptr, union_type, obj_type, true
            ivar_type = union_type.lookup_instance_var(node.name).type
            read_instance_var ivar_type, union_type, node.name, value_ptr

            phi.add @last, ivar_type

            position_at_end next_match_label
          end
          # No member matched: the type id was not one of the union's types
          unreachable
        end
      else
        read_instance_var node.type, node.obj.type, node.name, @last
      end
    end

    # Loads instance variable `name` of `type` from the object pointer `value`
    # into `@last`, downcast to `node_type`. Always returns false.
    def read_instance_var(node_type, type, name, value)
      owner = type.remove_typedef
      ivar = owner.lookup_instance_var(name)
      ivar_ptr = instance_var_ptr(owner, name, value)
      extern = owner.extern?

      @last = downcast(ivar_ptr, node_type, ivar.type, false, extern: extern)

      # Fields of C structs/unions holding C functions must be converted to
      # Crystal procs (e.g. Struct#to_s inspects every field)
      @last = check_c_fun(node_type, @last) if extern

      false
    end

    # Generates code for `obj.as(Type)`.
    #
    # Casts involving pointers are plain casts with no runtime check. Otherwise
    # an upcast just widens the value, and a downcast checks the runtime type
    # id, calling `raise TypeCastError` when it doesn't match. When the types
    # are equal and it's not an upcast, `@last` is left as the object's value.
    def visit(node : Cast)
      request_value(node.obj)

      last_value = @last

      obj_type = node.obj.type
      to_type = node.to.type.virtual_type

      if to_type.pointer?
        # `nil.as(Pointer)` becomes a null pointer
        if obj_type.nil_type?
          @last = llvm_type(to_type).null
        else
          @last = cast_to last_value, to_type
        end
      elsif obj_type.pointer?
        # Special case: for `ptr.as(Nil)` there's no bitcast involved
        if to_type.nil_type?
          @last = llvm_nil
        else
          @last = cast_to last_value, to_type
        end
      else
        resulting_type = node.type
        if node.upcast?
          @last = upcast last_value, resulting_type, obj_type
        elsif obj_type != resulting_type
          type_id = type_id last_value, obj_type
          cmp = match_type_id obj_type, resulting_type, type_id

          matches_block, doesnt_match_block = new_blocks "matches", "doesnt_match"
          cond cmp, matches_block, doesnt_match_block

          position_at_end doesnt_match_block

          # Expose the value through a temporary variable so the generated
          # `raise` call can print its runtime class
          temp_var_name = @program.new_temp_var_name
          context.vars[temp_var_name] = LLVMVar.new(last_value, obj_type, already_loaded: true)
          accept type_cast_exception_call(obj_type, to_type, node, temp_var_name)
          context.vars.delete temp_var_name

          position_at_end matches_block
          @last = downcast last_value, resulting_type, obj_type, true
        end
      end

      false
    end

    # Generates code for `obj.as?(Type)`. The result is the value narrowed to
    # `node.non_nilable_type` and widened to `node.type` when the runtime type
    # matches, and `nil` otherwise.
    def visit(node : NilableCast)
      request_value(node.obj)

      last_value = @last

      obj_type = node.obj.type
      to_type = node.to.type

      resulting_type = node.type

      filtered_type = obj_type.filter_by(to_type)

      # The cast can never succeed: the result is always nil
      unless filtered_type
        @last = upcast llvm_nil, resulting_type, @program.nil
        # Like every other `visit`, return false so children aren't visited
        return false
      end

      non_nilable_type = node.non_nilable_type

      if node.upcast?
        @last = upcast last_value, non_nilable_type, obj_type
        @last = upcast @last, resulting_type, non_nilable_type
      elsif obj_type != non_nilable_type
        type_id = type_id last_value, obj_type
        cmp = match_type_id obj_type, non_nilable_type, type_id

        # Merge the nil result and the narrowed value
        Phi.open(self, node, @needs_value) do |phi|
          matches_block, doesnt_match_block = new_blocks "matches", "doesnt_match"
          cond cmp, matches_block, doesnt_match_block

          position_at_end doesnt_match_block
          @last = upcast llvm_nil, resulting_type, @program.nil
          phi.add @last, resulting_type

          position_at_end matches_block
          @last = downcast last_value, non_nilable_type, obj_type, true
          @last = upcast @last, resulting_type, non_nilable_type
          phi.add @last, resulting_type, last: true
        end
      else
        @last = upcast last_value, resulting_type, obj_type
      end

      false
    end

    # Builds and types the AST for
    # `::raise TypeCastError.new("cast from #{var.class} to T failed[, at file:line]")`.
    # `var_name` names a temporary variable of type `from_type` that holds the
    # value being cast.
    def type_cast_exception_call(from_type, to_type, node, var_name)
      message_pieces = [] of ASTNode
      message_pieces << StringLiteral.new("cast from ").at(node)
      message_pieces << Call.new(Var.new(var_name).at(node), "class").at(node)
      message_pieces << StringLiteral.new(" to #{to_type} failed").at(node)

      if location = node.location
        message_pieces << StringLiteral.new(", at #{location.expanded_location}:#{location.line_number}").at(node)
      end

      message = StringInterpolation.new(message_pieces).at(node)
      exception = Call.new(Path.global("TypeCastError").at(node), "new", message).at(node)
      raise_call = @program.normalize(Call.global("raise", exception).at(node))

      temp_meta_vars = MetaVars.new
      temp_meta_vars[var_name] = MetaVar.new(var_name, type: from_type)
      @program.visit_main raise_call, visitor: MainVisitor.new(@program, temp_meta_vars)

      raise_call
    end

    # Lazily builds (and memoizes) the typed `::raise "passing a closure to C
    # is not allowed"` call, checking that `::raise` is NoReturn.
    def cant_pass_closure_to_c_exception_call
      if cached = @cant_pass_closure_to_c_exception_call
        return cached
      end

      raise_call = Call.global("raise", StringLiteral.new("passing a closure to C is not allowed")).at(UNKNOWN_LOCATION)
      @program.visit_main raise_call
      unless raise_call.type.is_a?(NoReturnType)
        raise_call.raise "::raise must be of NoReturn return type!"
      end

      @cant_pass_closure_to_c_exception_call = raise_call
    end

    # `obj.is_a?(T)`: compares the runtime type id against T.
    def visit(node : IsA)
      codegen_type_filter(node) { |type| type.filter_by(node.const.type) }
    end

    # `obj.responds_to?(:name)`: compares the runtime type id against the
    # types that respond to `name`.
    def visit(node : RespondsTo)
      codegen_type_filter(node) { |type| type.filter_by_responds_to(node.name) }
    end

    # Evaluates `node.obj` and sets `@last` to a boolean telling whether its
    # runtime type matches the type returned by the block (which receives the
    # object's static type and must not return nil).
    def codegen_type_filter(node, &)
      accept node.obj
      obj_type = node.obj.type

      obj_type_id = type_id(@last, obj_type)
      narrowed_type = yield(obj_type)

      @last = match_type_id(obj_type, narrowed_type.not_nil!, obj_type_id)
      false
    end

    # Returns the LLVMVar for `var` in the current context. On first use it
    # allocates stack storage (none for NoReturn vars) and emits debug info.
    def declare_var(var)
      if existing = context.vars[var.name]?
        return existing
      end

      storage = var.no_returns? ? llvm_nil : alloca(llvm_type(var.type), var.name)

      # Naked functions must not have debug info associated with them
      debug_variable_created = false
      unless context.fun.naked?
        debug_variable_created = declare_variable(var.name, var.type, storage, var.location)
      end

      context.vars[var.name] = LLVMVar.new(storage, var.type, debug_variable_created: debug_variable_created)
    end

    # Returns the external global for a C library variable, declaring it in the
    # current module on first use (as dllimport on win32 with `preview_dll`).
    def declare_lib_var(name, type, thread_local)
      if existing = @llvm_mod.globals[name]?
        return existing
      end

      lib_global = llvm_mod.globals.add(llvm_c_return_type(type), name)
      lib_global.linkage = LLVM::Linkage::External
      if @program.has_flag?("win32") && @program.has_flag?("preview_dll")
        lib_global.dll_storage_class = LLVM::DLLStorageClass::DLLImport
      end
      lib_global.thread_local = thread_local
      lib_global
    end

    # A method definition produces no code itself; only its macro hook
    # expansions (if any) are generated.
    def visit(node : Def)
      if hooks = node.hook_expansions
        hooks.each { |hook| accept hook }
      end

      @last = llvm_nil
      false
    end

    # Macro definitions produce no code; their value is nil.
    def visit(node : Macro)
      @last = llvm_nil
      return false
    end

    # A path is either a constant read, a syntax replacement, a tuple of types
    # (materialized as a tuple of type ids) or a single type id.
    def visit(node : Path)
      if const = node.target_const
        read_const(const, node)
      elsif replacement = node.syntax_replacement
        accept replacement
      elsif (node_type = node.type).is_a?(TupleInstanceType)
        # Special case: a type tuple needs an actual tuple value
        @last = allocate_tuple(node_type) { |tuple_type| {tuple_type, type_id(tuple_type)} }
      else
        @last = type_id(node.type)
      end

      false
    end

    # A generic type reference evaluates to its type id.
    def visit(node : Generic)
      @last = type_id(node.type)
      return false
    end

    # Generates code for a `yield`. It evaluates the yielded expressions,
    # assigns them to the block's arguments and then emits the block body at
    # this point, in the block's context.
    def visit(node : Yield)
      if node.expanded
        raise "BUG: #{node} at #{node.location} should have been expanded"
      end

      block_context = context.block_context.not_nil!
      block = context.block
      splat_index = block.splat_index

      closured_vars = closured_vars(block.vars, block)

      # Set up the closure data (if any) of the block's captured vars
      malloc_closure closured_vars, block_context, block_context.closure_parent_context

      # Saved so it can be restored once the block body has been generated
      old_scope = block_context.vars["%scope"]?

      if node_scope = node.scope
        request_value(node_scope)
        block_context.vars["%scope"] = LLVMVar.new(@last, node_scope.type)
      end

      # First accept all yield expressions and assign them to block vars
      unless node.exps.empty?
        exp_values = Array({LLVM::Value, Type}).new(node.exps.size)

        # We first accept the expressions and store the values, without
        # assigning them to the block vars yet because we might have
        # a nested yield that would override a block argument's value
        node.exps.each_with_index do |exp, i|
          request_value(exp)

          if exp.is_a?(Splat)
            # A splatted tuple contributes one value per element
            tuple_type = exp.type.as(TupleInstanceType)
            tuple_type.tuple_types.each_with_index do |subtype, j|
              exp_values << {codegen_tuple_indexer(tuple_type, @last, j), subtype}
            end
          else
            exp_values << {@last, exp.type}
          end
        end

        # Now assign exp values to block arguments
        if splat_index
          # `j` walks exp_values; the splat argument consumes as many values
          # as its tuple type has elements
          j = 0
          block.args.each_with_index do |arg, i|
            block_var = block_context.vars[arg.name]
            if i == splat_index
              exp_value = allocate_tuple(arg.type.as(TupleInstanceType)) do
                exp_value2, exp_type = exp_values[j]
                j += 1
                {exp_type, exp_value2}
              end
              exp_type = arg.type
            else
              exp_value, exp_type = exp_values[j]
              j += 1
            end
            assign block_var.pointer, block_var.type, exp_type, exp_value
          end
        else
          # Check if tuple unpacking is needed
          if exp_values.size == 1 &&
             (exp_type = exp_values.first[1]).is_a?(TupleInstanceType) &&
             block.args.size > 1
            exp_value = exp_values.first[0]
            exp_type.tuple_types.each_with_index do |tuple_type, i|
              arg = block.args[i]?
              # Elements without a matching argument, or bound to `_`, are skipped
              if arg && arg.name != "_"
                t_type = tuple_type
                t_value = codegen_tuple_indexer(exp_type, exp_value, i)
                block_var = block_context.vars[arg.name]
                assign block_var.pointer, block_var.type, t_type, t_value
              end
            end
          else
            exp_values.each_with_index do |(exp_value, exp_type), i|
              if (arg = block.args[i]?) && arg.name != "_"
                block_var = block_context.vars[arg.name]
                assign block_var.pointer, block_var.type, exp_type, exp_value
              end
            end
          end
        end
      end

      Phi.open(self, block, @needs_value) do |phi|
        with_cloned_context(block_context) do |old|
          # Reset vars that are declared inside the block and are nilable
          reset_nilable_vars block

          # `break` exits the yielding method (its return phi); `next` ends
          # the block with a value (this phi)
          context.break_phi = old.return_phi
          context.next_phi = phi
          context.closure_parent_context = block_context.closure_parent_context

          set_ensure_exception_handler(block)

          request_value(block.body)
        end

        phi.add @last, block.body.type?, last: true
      end

      if old_scope
        block_context.vars["%scope"] = old_scope
      end

      false
    end

    # Emits an LLVM `unreachable` instruction.
    def visit(node : Unreachable)
      builder.unreachable()
    end

    # Converts the proc `value` to a bare function pointer of `type`. The
    # helper it calls raises at runtime if the proc carries a closure context.
    def check_proc_is_not_closure(value, type)
      helper_name = "~check_proc_is_not_closure"
      helper = typed_fun?(@main_mod, helper_name) || create_check_proc_is_not_closure_fun(helper_name)
      helper = check_main_fun(helper_name, helper)
      raw_fun_ptr = call(helper, [value] of LLVM::Value)
      pointer_cast(raw_fun_ptr, llvm_proc_type(type).pointer)
    end

    # Defines, in the main module, a function `fun_name` that takes a proc
    # (function pointer + context pointer). It returns the function pointer
    # when the context is null, and otherwise raises
    # "passing a closure to C is not allowed".
    def create_check_proc_is_not_closure_fun(fun_name)
      in_main do
        define_main_function(fun_name, [llvm_typer.proc_type], llvm_context.void_pointer) do |func|
          set_internal_fun_debug_location(func, fun_name)

          param = func.params.first

          # Field 0 is the function pointer, field 1 the closure context
          fun_ptr = extract_value param, 0
          ctx_ptr = extract_value param, 1

          ctx_is_null_block = new_block "ctx_is_null"
          ctx_is_not_null_block = new_block "ctx_is_not_null"

          ctx_is_null = equal? ctx_ptr, llvm_context.void_pointer.null
          cond ctx_is_null, ctx_is_null_block, ctx_is_not_null_block

          position_at_end ctx_is_null_block
          ret fun_ptr

          position_at_end ctx_is_not_null_block
          accept cant_pass_closure_to_c_exception_call
        end
      end
    end

    # Builds a proc value of `type` from a function pointer and a closure
    # context pointer, assembling it in a stack slot and loading it back.
    def make_fun(type, fun_ptr, ctx_ptr)
      proc_struct = llvm_type(type)
      slot = alloca(proc_struct)
      store(fun_ptr, aggregate_index(proc_struct, slot, 0))
      store(ctx_ptr, aggregate_index(proc_struct, slot, 1))
      load(proc_struct, slot)
    end

    # Builds the nil value of a nilable proc type: null function and context.
    def make_nilable_fun(type)
      null_ptr = llvm_context.void_pointer.null
      make_fun(type, null_ptr, null_ptr)
    end

    # Runs the block with the builder, module, LLVM context and typer switched
    # to the main module, then restores the previous codegen state and returns
    # the block's value.
    #
    # NOTE(review): there is no `ensure`, so the state is not restored if the
    # block raises.
    def in_main(&)
      # Snapshot everything the block may change
      old_builder = self.builder
      old_position = old_builder.insert_block
      old_llvm_mod = @llvm_mod
      old_llvm_context = @llvm_context
      old_llvm_typer = @llvm_typer
      old_fun = context.fun
      old_fun_type = context.fun_type
      old_ensure_exception_handlers = @ensure_exception_handlers
      old_rescue_block = @rescue_block
      old_catch_pad = @catch_pad
      old_entry_block = @entry_block
      old_alloca_block = @alloca_block
      old_needs_value = @needs_value
      old_debug_location = @current_debug_location

      @llvm_mod = @main_mod
      @llvm_context = @main_llvm_context
      @llvm_typer = @main_llvm_typer
      @builder = @main_builder

      # Exception handling of the caller must not leak into the new function
      @ensure_exception_handlers = nil
      @rescue_block = nil
      @catch_pad = nil

      clear_current_debug_location if @debug.line_numbers?

      block_value = yield

      @builder = old_builder
      position_at_end old_position

      @llvm_mod = old_llvm_mod
      @llvm_context = old_llvm_context
      @llvm_typer = old_llvm_typer
      @ensure_exception_handlers = old_ensure_exception_handlers
      @rescue_block = old_rescue_block
      @catch_pad = old_catch_pad
      @entry_block = old_entry_block
      @alloca_block = old_alloca_block
      @needs_value = old_needs_value
      context.fun = old_fun
      context.fun_type = old_fun_type
      set_current_debug_location old_debug_location if @debug.line_numbers?

      block_value
    end

    # Convenience overload that builds the function type from argument and
    # return types.
    def define_main_function(name, arg_types : Array(LLVM::Type), return_type : LLVM::Type, needs_alloca : Bool = false, &)
      fun_type = LLVM::Type.function(arg_types, return_type)
      define_main_function(name, fun_type, needs_alloca) { |func| yield func }
    end

    # Adds function `name` of `type` to the main module, makes it the current
    # function and yields it with the builder positioned in its body. Must be
    # called inside `in_main`. The function is internal in single-module mode.
    # With `needs_alloca`, an alloca block precedes the entry block.
    def define_main_function(name, type : LLVM::Type, needs_alloca : Bool = false, &)
      unless @llvm_mod == @main_mod
        raise "wrong usage of define_main_function: you must put it inside an `in_main` block"
      end

      main_fun = add_typed_fun(@main_mod, name, type)
      context.fun = main_fun.func
      context.fun_type = type
      context.fun.linkage = LLVM::Linkage::Internal if @single_module

      unless needs_alloca
        position_at_end main_fun.func.basic_blocks.append("entry")
        yield main_fun.func
        return main_fun
      end

      new_entry_block
      yield main_fun.func
      br_from_alloca_to_entry
      main_fun
    end

    # Attaches debug metadata to a generated internal function (like
    # `~metaclass` and `~match`), defaulting to an unknown source location.
    def set_internal_fun_debug_location(func, name, location = nil)
      return if @debug.none?

      debug_location = location || UNKNOWN_LOCATION
      emit_fun_debug_metadata(func, name, debug_location)
      set_current_debug_location(debug_location) if @debug.line_numbers?
    end

    # Placeholder location for generated code that has no real source location
    # (internal functions and synthesized calls).
    private UNKNOWN_LOCATION = Location.new("??", 0, 0)

    # Returns the value of `self` downcast to `type`. Without a `self` variable
    # (e.g. at class level), returns the type id of `type`.
    def llvm_self(type = context.type)
      self_var = context.vars["self"]?
      return type_id(type.not_nil!) unless self_var

      downcast self_var.pointer, type, self_var.type, true
    end

    # Returns `self` as a pointer usable for instance variable access. Virtual
    # class types are cast to their base type. A virtual struct needs no cast
    # because it is already the union.
    def llvm_self_ptr
      self_type = context.type
      if self_type.is_a?(VirtualType) && !self_type.struct?
        cast_to llvm_self, self_type.base_type
      else
        llvm_self
      end
    end

    # Creates the "alloca" and "entry" blocks of the current function and
    # records them in @alloca_block and @entry_block. The builder is left at
    # the end of "entry".
    def new_entry_block
      @alloca_block, @entry_block = new_entry_block_chain "alloca", "entry"
    end

    # Appends one block per name and positions the builder at the last one.
    def new_entry_block_chain(*names)
      new_blocks(*names).tap { |chain| position_at_end chain.last }
    end

    # Terminates the alloca block with a branch to the entry block. An empty
    # alloca block is deleted instead, which produces less noise in the IR.
    def br_from_alloca_to_entry
      return alloca_block.delete if alloca_block.instructions.empty?

      br_block_chain alloca_block, entry_block
    end

    # Makes each block branch unconditionally to the next one, then restores
    # the builder's original insertion block.
    def br_block_chain(*blocks)
      saved_block = insert_block

      (blocks.size - 1).times do |index|
        position_at_end blocks[index]
        clear_current_debug_location if @debug.line_numbers?
        br blocks[index + 1]
      end

      position_at_end saved_block
    end

    # Appends a new basic block to the current function.
    def new_block(name = "")
      context.fun.basic_blocks.append(name)
    end

    # Appends one new basic block per name, in order.
    def new_blocks(*names)
      names.map { |block_name| new_block(block_name) }
    end

    # Allocates storage for `vars`: stack slots for ordinary variables, and a
    # heap closure for the closured ones (and `self` when `obj` is a
    # self-closured Def).
    def alloca_vars(vars, obj = nil, args = nil, parent_context = nil, reset_nilable_vars = true)
      captures_self = obj.is_a?(Def) && obj.self_closured?
      captured = closured_vars(vars, obj)
      alloca_non_closured_vars(vars, obj, args, reset_nilable_vars)
      malloc_closure(captured, context, parent_context, captures_self)
    end

    # Emits stack allocations (in the alloca block) for the variables of
    # `vars` that belong to `obj` and are neither closured, arguments, `self`,
    # nor already present in the context. Optionally initializes nilable
    # variables to nil.
    def alloca_non_closured_vars(vars, obj = nil, args = nil, reset_nilable_vars = true)
      return unless vars

      in_alloca_block do
        # Allocate all variables which are not closured and don't belong to an outer closure
        vars.each do |name, var|
          next if name == "self" || context.vars[name]?

          # Untyped vars are treated as Nil
          var_type = var.type? || @program.nil

          if var_type.void?
            context.vars[name] = LLVMVar.new(llvm_nil, @program.void)
          elsif var_type.no_return?
            # No alloca for NoReturn
          elsif var.closure_in?(obj)
            # We deal with closured vars later
          elsif !obj || var.belongs_to?(obj)
            # We deal with arguments later
            is_arg = args.try &.any? { |arg| arg.name == var.name }
            next if is_arg

            ptr = alloca llvm_type(var_type), name

            # Fall back to the owner's location for debug info
            location = var.location
            if location.nil? && obj.is_a?(ASTNode)
              location = obj.location
            end

            debug_variable_created =
              if location && !context.fun.naked?
                declare_variable name, var_type, ptr, location, alloca_block
              else
                false
              end
            context.vars[name] = LLVMVar.new(ptr, var_type, debug_variable_created: debug_variable_created)

            # Assign default nil for variables that are bound to the nil variable
            if reset_nilable_vars && bound_to_mod_nil?(var)
              assign ptr, var_type, @program.nil, llvm_nil
            end
          else
            # The variable belong to an outer closure
          end
        end
      end
    end

    # Returns the typed variables of `vars` that are closured in `obj`, or nil
    # when there are none.
    #
    # A closured variable can end up without a type, as in #2196, because a
    # branch can't be typed and is finally removed before codegen. Such
    # variables are skipped here.
    def closured_vars(vars, obj = nil)
      return unless vars

      captured = [] of MetaVar
      vars.each_value do |var|
        captured << var if var.closure_in?(obj) && var.type?
      end

      captured.empty? ? nil : captured
    end

    # Sets up the closure of `current_context`.
    #
    # If there are closured vars (or `self` is closured), heap-allocates a
    # closure struct laid out as: the vars, then the parent closure pointer (if
    # any), then `self` (if closured). Each closured var is bound to its field.
    # Otherwise, the parent context's closure (if any) is reused with
    # `closure_skip_parent` set.
    def malloc_closure(closure_vars, current_context, parent_context = nil, self_closured = false)
      parent_closure_type = parent_context.try &.closure_type

      if closure_vars || self_closured
        closure_vars ||= [] of MetaVar
        closure_type = @llvm_typer.closure_context_type(closure_vars, parent_closure_type, (self_closured ? current_context.type : nil))
        closure_ptr = malloc closure_type
        # Closured vars live in the closure's fields rather than on the stack
        closure_vars.each_with_index do |var, i|
          current_context.vars[var.name] = LLVMVar.new(gep(closure_type, closure_ptr, 0, i, var.name), var.type)
        end
        closure_skip_parent = false

        # Link to the parent closure, stored right after the vars
        if parent_closure_type
          store parent_context.not_nil!.closure_ptr.not_nil!, gep(closure_type, closure_ptr, 0, closure_vars.size, "parent")
        end

        if self_closured
          # `self` comes after the vars and, if present, the parent pointer
          offset = parent_closure_type ? 1 : 0
          self_value = to_rhs(llvm_self, current_context.type)

          store self_value, gep(closure_type, closure_ptr, 0, closure_vars.size + offset, "self")

          current_context.closure_self = current_context.type
        end
      elsif parent_context && parent_context.closure_type
        closure_vars = parent_context.closure_vars
        closure_type = parent_context.closure_type
        closure_ptr = parent_context.closure_ptr
        closure_skip_parent = true
      else
        closure_skip_parent = false
      end

      current_context.closure_vars = closure_vars
      current_context.closure_type = closure_type
      current_context.closure_ptr = closure_ptr
      current_context.closure_skip_parent = closure_skip_parent
    end

    # Removes from the context the variables of `vars` that belong to `obj`.
    # Special vars are kept because they are local to the entire method.
    def undef_vars(vars, obj)
      return unless vars

      vars.each do |name, var|
        next unless var.belongs_to?(obj)
        next if var.special_var?

        context.vars.delete(name)
      end
    end

    # Assigns nil to every variable declared directly in `node` that is bound
    # to the program's nil variable.
    def reset_nilable_vars(node)
      node_vars = node.vars
      return unless node_vars

      node_vars.each do |name, var|
        next unless var.context == node && bound_to_mod_nil?(var)

        llvm_var = context.vars[name]
        assign llvm_var.pointer, llvm_var.type, @program.nil, llvm_nil
      end
    end

    # True if `var` depends on the program's global nil variable, which makes
    # it nilable.
    def bound_to_mod_nil?(var)
      var.dependencies.any? { |dependency| dependency.same?(@program.nil_var) }
    end

    # Emits a stack allocation of `type` in the current function's alloca block.
    def alloca(type, name = "")
      in_alloca_block do
        builder.alloca(type, name)
      end
    end

    # Runs the block with the builder positioned at the end of the alloca
    # block, then returns to the previous insertion block. Returns the block's
    # value.
    def in_alloca_block(&)
      previous_block = insert_block
      position_at_end alloca_block
      result = yield
      position_at_end previous_block
      result
    end

    # Emits a call to C's printf with `format` as a global string constant.
    def printf(format, args = [] of LLVM::Value)
      printf_fun = c_printf_fun
      format_ptr = builder.global_string_pointer(format)
      call printf_fun, [format_ptr] + args
    end

    # Emits a debug message that shows the current LLVM basic block name and
    # the location within the codegen that was used to emit this log.
    #
    # The message is only generated if `CRYSTAL_DEBUG_CODEGEN` is set.
    #
    # The block given to this method should return `printf` arguments to show
    # additional information. The following forms are all valid and help to
    # allocate the arguments only if the message is to be generated.
    #
    # ```
    # debug_codegen_log
    # debug_codegen_log { }
    # debug_codegen_log { "Lorem" }
    # debug_codegen_log { {"Lorem"} }
    # debug_codegen_log { {"Lorem %d", [an_int_llvm_value] of LLVM::Value} }
    # ```
    #
    def debug_codegen_log(file = __FILE__, line = __LINE__, &)
      return unless ENV["CRYSTAL_DEBUG_CODEGEN"]?
      # Normalize every accepted form to {format, args}
      printf_args = yield || ""
      printf_args = {printf_args, [] of LLVM::Value} if printf_args.is_a?(String)
      printf_args = {printf_args[0], [] of LLVM::Value} if printf_args.is_a?({String})
      msg, args = printf_args
      printf("<block: #{insert_block.name || "???"} @ #{Crystal.relative_filename(file)}:#{line}> #{msg}\n", args)
    end

    # :ditto:
    def debug_codegen_log(file = __FILE__, line = __LINE__)
      debug_codegen_log(file, line) { nil }
    end

    # Emits an LLVM `unreachable`, preceded by a runtime log line when
    # `CRYSTAL_DEBUG_CODEGEN` is set.
    def unreachable(file = __FILE__, line = __LINE__)
      debug_codegen_log(file, line) do
        "Reached the unreachable!"
      end
      builder.unreachable
    end

    # Allocates and zero-fills storage for an instance of `type`, then runs its
    # instance var initializers and (for non-structs) stores the type id.
    #
    # Values passed by value live on the stack. Classes whose ivars may hold
    # pointers use the regular GC allocator. Everything else uses the atomic
    # (pointer-free) allocator.
    def allocate_aggregate(type)
      struct_type = llvm_struct_type(type)

      instance_ptr =
        if type.passed_by_value?
          alloca struct_type
        elsif type.is_a?(InstanceVarContainer) && !type.struct? &&
              type.all_instance_vars.each_value.any? &.type.has_inner_pointers?
          malloc struct_type
        else
          malloc_atomic struct_type
        end

      memset instance_ptr, int8(0), struct_type.size
      run_instance_vars_initializers(type, type, instance_ptr)

      store(type_id(type), aggregate_index(struct_type, instance_ptr, 0)) unless type.struct?

      @last = instance_ptr
    end

    # Stack-allocates a tuple of `type`. For each element it yields
    # (element_type, index) and expects {value_type, value} back, assigning
    # that value into the element's slot. Returns the tuple pointer.
    def allocate_tuple(type, &)
      tuple_llvm_type = llvm_type(type)
      tuple_ptr = alloca tuple_llvm_type

      type.tuple_types.each_with_index do |element_type, index|
        value_type, value = yield element_type, index
        assign aggregate_index(tuple_llvm_type, tuple_ptr, index), element_type, value_type, value
      end

      tuple_ptr
    end

    # Runs instance var initializers for `real_type`, starting from the
    # top-most superclass so ancestors' initializers run first.
    def run_instance_vars_initializers(real_type, type : ClassType | GenericClassInstanceType, type_ptr)
      parent = type.superclass
      run_instance_vars_initializers(real_type, parent, type_ptr) if parent

      run_instance_vars_initializers_non_recursive real_type, type, type_ptr
    end

    # Fallback overload for types that are not classes (ClassType /
    # GenericClassInstanceType): it emits nothing.
    def run_instance_vars_initializers(real_type, type : Type, type_ptr)
      # Nothing to do
    end

    # Evaluates each instance var initializer declared directly in `type`
    # (superclasses are not visited) and stores its value into the matching
    # ivar of the object at `type_ptr`, laid out as `real_type`.
    def run_instance_vars_initializers_non_recursive(real_type, type, type_ptr)
      initializers = type.instance_vars_initializers
      return unless initializers

      initializers.each do |init|
        ivar = real_type.lookup_instance_var(init.name)

        # Each initializer runs in a fresh context with its own local vars
        with_cloned_context do
          # Instance var initializers must run with "self"
          # properly set up to the type being allocated
          context.type = real_type.metaclass
          context.vars = LLVMVars.new
          alloca_vars init.meta_vars

          accept init.value

          ivar_ptr = instance_var_ptr real_type, init.name, type_ptr
          assign ivar_ptr, ivar.type, init.value.type, @last
        end
      end
    end

    # Heap-allocates one `type` with `__crystal_malloc64`, or C malloc when absent.
    def malloc(type)
      generic_malloc(type) do
        crystal_malloc_fun
      end
    end

    # Heap-allocates one `type` with `__crystal_malloc_atomic64`, or C malloc
    # when absent.
    def malloc_atomic(type)
      generic_malloc(type) do
        crystal_malloc_atomic_fun
      end
    end

    # Allocates `type.size` bytes with the function returned by the block,
    # falling back to C malloc when it returns nil. Returns a `type*` pointer.
    def generic_malloc(type, &)
      byte_size = type.size
      allocator = yield
      raw_ptr = allocator ? call(allocator, byte_size) : call_c_malloc(byte_size)
      pointer_cast raw_ptr, type.pointer
    end

    # Heap-allocates `count` zeroed elements of `type` (GC-scanned).
    def array_malloc(type, count)
      generic_array_malloc(type, count) do
        crystal_malloc_fun
      end
    end

    # Heap-allocates `count` zeroed elements of `type` (atomic allocator).
    def array_malloc_atomic(type, count)
      generic_array_malloc(type, count) do
        crystal_malloc_atomic_fun
      end
    end

    # Allocates `type.size * count` bytes with the function returned by the
    # block (C malloc when it returns nil), zero-fills them and returns a
    # `type*` pointer.
    def generic_array_malloc(type, count, &)
      byte_size = builder.mul type.size, count
      allocator = yield
      raw_ptr = allocator ? call(allocator, byte_size) : call_c_malloc(byte_size)

      memset raw_ptr, int8(0), byte_size
      pointer_cast raw_ptr, type.pointer
    end

    # Returns `__crystal_malloc64` from the main module, routed through
    # `check_main_fun`, or nil when the program does not define it.
    # A successful lookup is memoized in `@malloc_fun`.
    def crystal_malloc_fun
      malloc_fn = (@malloc_fun ||= typed_fun?(@main_mod, MALLOC_NAME))
      return nil unless malloc_fn

      check_main_fun MALLOC_NAME, malloc_fn
    end

    # Returns `__crystal_malloc_atomic64` from the main module, routed
    # through `check_main_fun`, or nil when it is not defined.
    # A successful lookup is memoized in `@malloc_atomic_fun`.
    def crystal_malloc_atomic_fun
      atomic_fn = (@malloc_atomic_fun ||= typed_fun?(@main_mod, MALLOC_ATOMIC_NAME))
      return nil unless atomic_fn

      check_main_fun MALLOC_ATOMIC_NAME, atomic_fn
    end

    # Returns `__crystal_realloc64` from the main module, routed through
    # `check_main_fun`, or nil when it is not defined.
    # A successful lookup is memoized in `@realloc_fun`.
    def crystal_realloc_fun
      realloc_fn = (@realloc_fun ||= typed_fun?(@main_mod, REALLOC_NAME))
      return nil unless realloc_fn

      check_main_fun REALLOC_NAME, realloc_fn
    end

    # Returns `__crystal_raise_overflow` from the main module, routed through
    # `check_main_fun`. Unlike the malloc helpers there is no fallback, so a
    # missing definition raises `Error`.
    def crystal_raise_overflow_fun
      overflow_fn = (@raise_overflow_fun ||= typed_fun?(@main_mod, RAISE_OVERFLOW_NAME))

      unless overflow_fn
        raise Error.new("Missing __crystal_raise_overflow function, either use std-lib's prelude or define it")
      end

      check_main_fun RAISE_OVERFLOW_NAME, overflow_fn
    end

    # We only use C's malloc in tests that don't require the prelude,
    # so they don't require the GC. Outside tests these are not used,
    # and __crystal_* functions are invoked instead.

    # Emits a call to C's `malloc`. On non-64-bit targets *size* is first
    # truncated to i32 to match the declared parameter type.
    def call_c_malloc(size)
      byte_count = @program.bits64? ? size : trunc(size, llvm_context.int32)
      call c_malloc_fun, byte_count
    end

    # Declares (or reuses) `void* malloc(size)` in the main module. The size
    # parameter is i64 on 64-bit targets and i32 otherwise. Records the
    # result in `@c_malloc_fun` and returns it through `check_main_fun`.
    def c_malloc_fun
      declared = fetch_typed_fun(@main_mod, "malloc") do
        size_type = @program.bits64? ? @main_llvm_context.int64 : @main_llvm_context.int32
        LLVM::Type.function([size_type], @main_llvm_context.void_pointer)
      end
      @c_malloc_fun = declared

      check_main_fun "malloc", declared
    end

    # Emits a call to C's `realloc(buffer, size)`. On non-64-bit targets
    # *size* is first truncated to i32.
    def call_c_realloc(buffer, size)
      byte_count = @program.bits64? ? size : trunc(size, llvm_context.int32)
      call c_realloc_fun, [buffer, byte_count]
    end

    # Declares (or reuses) `void* realloc(void*, size)` in the main module.
    # The size parameter is i64 on 64-bit targets and i32 otherwise. Records
    # the result in `@c_realloc_fun` and returns it through `check_main_fun`.
    def c_realloc_fun
      declared = fetch_typed_fun(@main_mod, "realloc") do
        ctx = @main_llvm_context
        size_type = @program.bits64? ? ctx.int64 : ctx.int32
        LLVM::Type.function([ctx.void_pointer, size_type], ctx.void_pointer)
      end
      @c_realloc_fun = declared

      check_main_fun "realloc", declared
    end

    # Emits an `llvm.memset` filling *size* bytes at *pointer* with the i8
    # *value*, and returns the call instruction.
    #
    # On non-64-bit targets the length is truncated to i32 to match the
    # intrinsic variant picked by `c_memset_fun`. The trailing `int1(0)` is
    # the intrinsic's is-volatile flag. Parameter-attribute index 1
    # (presumably the destination pointer) is annotated with alignment 4.
    def memset(pointer, value, size)
      length = @program.bits64? ? size : trunc(size, llvm_context.int32)
      dest = cast_to_void_pointer pointer

      call(c_memset_fun, [dest, value, length, int1(0)]).tap do |instr|
        LibLLVM.set_instr_param_alignment(instr, 1, 4)
      end
    end

    # Emits an `llvm.memcpy` of *len* bytes from *src* to *dest* and returns
    # the call instruction. *volatile* is forwarded as the intrinsic's
    # is-volatile operand, and *align* is applied to both pointer
    # parameters (attribute indices 1 and 2).
    def memcpy(dest, src, len, align, volatile)
      copy_instr = call(c_memcpy_fun, [dest, src, len, volatile])

      LibLLVM.set_instr_param_alignment(copy_instr, 1, align)
      LibLLVM.set_instr_param_alignment(copy_instr, 2, align)

      copy_instr
    end

    # Emits a reallocation of *buffer* to *size* bytes. Uses
    # `__crystal_realloc64` when the program defines it, otherwise C's
    # `realloc`.
    def realloc(buffer, size)
      crystal_realloc = crystal_realloc_fun
      return call(crystal_realloc, [buffer, size]) if crystal_realloc

      call_c_realloc buffer, size
    end

    # Declares (or reuses) `printf` in the current module. It takes a void
    # pointer, returns i32, and is declared variadic (the trailing `true`).
    private def c_printf_fun
      fetch_typed_fun(@llvm_mod, "printf") do
        ctx = @llvm_context
        LLVM::Type.function([ctx.void_pointer], ctx.int32, true)
      end
    end

    # Declares (or reuses) the `llvm.memset` intrinsic in the current module,
    # with signature `void (ptr, i8, len, i1)`. The length width follows the
    # target (i64 on 64-bit, i32 otherwise).
    #
    # LLVM versions before 15 name the intrinsic with typed pointers
    # (`p0i8`); newer versions use opaque pointers (`p0`).
    private def c_memset_fun
      bits64 = @program.bits64?

      {% if LibLLVM::IS_LT_150 %}
        name = bits64 ? "llvm.memset.p0i8.i64" : "llvm.memset.p0i8.i32"
      {% else %}
        name = bits64 ? "llvm.memset.p0.i64" : "llvm.memset.p0.i32"
      {% end %}

      fetch_typed_fun(@llvm_mod, name) do
        ctx = @llvm_context
        len_type = bits64 ? ctx.int64 : ctx.int32
        LLVM::Type.function([ctx.void_pointer, ctx.int8, len_type, ctx.int1], ctx.void)
      end
    end

    # Declares (or reuses) the `llvm.memcpy` intrinsic in the current module,
    # with signature `void (ptr, ptr, len, i1)`. The length width follows the
    # target (i64 on 64-bit, i32 otherwise).
    #
    # LLVM versions before 15 name the intrinsic with typed pointers
    # (`p0i8`); newer versions use opaque pointers (`p0`).
    private def c_memcpy_fun
      bits64 = @program.bits64?

      {% if LibLLVM::IS_LT_150 %}
        name = bits64 ? "llvm.memcpy.p0i8.p0i8.i64" : "llvm.memcpy.p0i8.p0i8.i32"
      {% else %}
        name = bits64 ? "llvm.memcpy.p0.p0.i64" : "llvm.memcpy.p0.p0.i32"
      {% end %}

      fetch_typed_fun(@llvm_mod, name) do
        ctx = @llvm_context
        len_type = bits64 ? ctx.int64 : ctx.int32
        LLVM::Type.function([ctx.void_pointer, ctx.void_pointer, len_type, ctx.int1], ctx.void)
      end
    end

    # For a *type* passed by value, returns *value* untouched. Otherwise
    # loads from *value* using the embedded LLVM representation of *type*.
    def to_lhs(value, type)
      return value if type.passed_by_value?

      # `llvm_embedded_type` is needed for void-like types
      load(llvm_embedded_type(type), value)
    end

    # Mirror of `to_lhs`: loads from *value* (via `llvm_embedded_type`) only
    # when *type* is passed by value, and returns *value* untouched
    # otherwise.
    def to_rhs(value, type)
      return value unless type.passed_by_value?

      load(llvm_embedded_type(type), value)
    end

    # Same as `to_lhs`, but loads through the C-ABI embedded type
    # (`llvm_embedded_c_type`).
    def extern_to_lhs(value, type)
      return value if type.passed_by_value?

      load(llvm_embedded_c_type(type), value)
    end

    # Same as `to_rhs`, but loads through the C-ABI embedded type
    # (`llvm_embedded_c_type`).
    def extern_to_rhs(value, type)
      return value unless type.passed_by_value?

      load(llvm_embedded_c_type(type), value)
    end

    # Returns a pointer to field *index* of the aggregate at *ptr*, emitted
    # as `gep type, ptr, 0, index`.
    #
    # NOTE: *type* is the pointee type of *ptr*, not the type of the
    # returned element.
    def aggregate_index(type : LLVM::Type, ptr : LLVM::Value, index : Int32)
      gep(type, ptr, 0, index)
    end

    # Returns a pointer to the instance variable *name* inside the object at
    # *pointer*, whose Crystal type is *type*.
    #
    # Extern unions have no per-field index: the pointer comes from
    # `union_field_ptr` for the ivar's type. Every other type resolves the
    # ivar's field index and emits a GEP into the LLVM struct of the
    # (possibly de-virtualized) target type.
    def instance_var_ptr(type, name, pointer)
      if type.extern_union?
        return union_field_ptr(type, type.instance_vars[name].type, pointer)
      end

      # `not_nil!` raises if *type* has no ivar named *name*.
      index = type.index_of_instance_var(name).not_nil!

      # Non-struct types reserve field 0 of their LLVM struct (presumably
      # the type id header — confirm against the LLVM typer), so ivar fields
      # are shifted by one.
      unless type.struct?
        index += 1
      end

      target_type = type
      if type.is_a?(VirtualType)
        if type.struct?
          if (_type = type.remove_indirection).is_a?(UnionType)
            # For a struct we need to cast the second part of the union to the base type
            _, value_ptr = union_type_and_value_pointer(pointer, _type)
            target_type = type.base_type
            pointer = cast_to_pointer value_ptr, target_type
          else
            # Nothing, there's only one subclass so it's the struct already
          end
        else
          # Virtual reference type: view the object through its base type's
          # struct layout.
          target_type = type.base_type
          pointer = cast_to pointer, target_type
        end
      end

      aggregate_index llvm_struct_type(target_type), pointer, index
    end

    # Runs the program's finished hooks with this visitor. `@last` is saved
    # beforehand and restored afterwards, since generating the hooks' code
    # may overwrite it.
    def process_finished_hooks
      saved_last = @last
      @program.process_finished_hooks(self)
      @last = saved_last
    end

    # Returns a pointer, cast to the `String` type, to a private constant
    # global whose initializer lays out *str* as
    # `{type_id(String), bytesize : i32, size : i32, bytes}`.
    #
    # Globals are cached in `@strings` per `(current module, str)` pair, so
    # repeated requests reuse one global. *name* is not part of the key, so
    # it only shapes the LLVM symbol name of the first global created:
    # truncated to 18 characters plus "...", with `@` replaced by `.`, and
    # wrapped in single quotes.
    def build_string_constant(str, name = "str")
      # Measure and slice in the same unit (characters). The old
      # `name[0..18]` kept 19 characters after a `bytesize > 18` check, so a
      # 19-character name grew instead of shrinking, and multibyte names got
      # "..." appended without being shortened at all.
      name = "#{name[0, 18]}..." if name.size > 18
      name = name.gsub '@', '.'
      name = "'#{name}'"
      key = StringKey.new(@llvm_mod, str)
      @strings[key] ||= begin
        global = @llvm_mod.globals.add(@llvm_typer.llvm_string_type(str.bytesize), name)
        global.linkage = LLVM::Linkage::Private
        global.global_constant = true
        global.initializer = llvm_context.const_struct [
          type_id(@program.string),
          int32(str.bytesize),
          int32(str.size),
          llvm_context.const_string(str),
        ]
        cast_to global, @program.string
      end
    end

    # Sets `@needs_value` to *request* for the duration of the block. The
    # previous value is restored even if the block raises. Returns the
    # block's value.
    def request_value(request : Bool = true, &)
      previous = @needs_value
      begin
        @needs_value = request
        yield
      ensure
        @needs_value = previous
      end
    end

    # Generates code for *node* with `@needs_value` set to true.
    def request_value(node : ASTNode)
      request_value { accept node }
    end

    # Generates code for *node* with `@needs_value` set to false.
    def discard_value(node : ASTNode)
      request_value(false) { accept node }
    end

    # Dispatches *node* to this visitor via `node.accept`.
    def accept(node)
      node.accept(self)
    end

    # Expandable nodes must be rewritten before codegen. Reaching one here
    # is a compiler bug, so raise with the node, its class and its location.
    def visit(node : ExpandableNode)
      raise("BUG: #{node} (#{node.class}) at #{node.location} should have been expanded")
    end

    # Catch-all visit for nodes without a dedicated overload. It returns
    # true (presumably letting the visitor descend into the node's
    # children — confirm against the Visitor base class).
    def visit(node : ASTNode)
      return true
    end
  end

  # Returns a symbol-safe version of *name* for the target of *program*.
  #
  # When *program* has the `windows` flag, every character other than an
  # ASCII alphanumeric or `_` is replaced by `.XX.`, where `XX` is its
  # codepoint in uppercase hexadecimal (unpadded). Other targets get *name*
  # back unchanged.
  def self.safe_mangling(program, name)
    return name unless program.has_flag?("windows")

    String.build do |io|
      name.each_char do |ch|
        if ch.ascii_alphanumeric? || ch == '_'
          io << ch
        else
          io << '.'
          ch.ord.to_s(io, 16, upcase: true)
          io << '.'
        end
      end
    end
  end
end

require "./*"
