diff --git a/src/hotspot/cpu/x86/assembler_x86.cpp b/src/hotspot/cpu/x86/assembler_x86.cpp index c2fcbcea71e24..c287f349383fb 100644 --- a/src/hotspot/cpu/x86/assembler_x86.cpp +++ b/src/hotspot/cpu/x86/assembler_x86.cpp @@ -3476,6 +3476,22 @@ void Assembler::vmovdqu(XMMRegister dst, XMMRegister src) { emit_int16(0x6F, (0xC0 | encode)); } +void Assembler::vmovw(XMMRegister dst, Register src) { + assert(VM_Version::supports_avx512_fp16(), "requires AVX512-FP16"); + InstructionAttr attributes(AVX_128bit, /* vex_w */ false, /* legacy_mode */ false, /* no_mask_reg */ true, /* uses_vl */ false); + attributes.set_is_evex_instruction(); + int encode = vex_prefix_and_encode(dst->encoding(), 0, src->encoding(), VEX_SIMD_66, VEX_OPCODE_MAP5, &attributes, true); + emit_int16(0x6E, (0xC0 | encode)); +} + +void Assembler::vmovw(Register dst, XMMRegister src) { + assert(VM_Version::supports_avx512_fp16(), "requires AVX512-FP16"); + InstructionAttr attributes(AVX_128bit, /* vex_w */ false, /* legacy_mode */ false, /* no_mask_reg */ true, /* uses_vl */ false); + attributes.set_is_evex_instruction(); + int encode = vex_prefix_and_encode(src->encoding(), 0, dst->encoding(), VEX_SIMD_66, VEX_OPCODE_MAP5, &attributes, true); + emit_int16(0x7E, (0xC0 | encode)); +} + void Assembler::vmovdqu(XMMRegister dst, Address src) { assert(UseAVX > 0, ""); InstructionMark im(this); @@ -8443,6 +8459,238 @@ void Assembler::vpaddq(XMMRegister dst, XMMRegister nds, Address src, int vector emit_operand(dst, src, 0); } +void Assembler::evaddph(XMMRegister dst, XMMRegister nds, XMMRegister src, int vector_len) { + assert(VM_Version::supports_avx512_fp16(), "requires AVX512-FP16"); + assert(vector_len == Assembler::AVX_512bit || VM_Version::supports_avx512vl(), ""); + InstructionAttr attributes(vector_len, /* vex_w */ false, /* legacy_mode */ false, /* no_mask_reg */ true, /* uses_vl */ true); + attributes.set_is_evex_instruction(); + int encode = vex_prefix_and_encode(dst->encoding(), nds->encoding(), src->encoding(), VEX_SIMD_NONE, VEX_OPCODE_MAP5, &attributes); + emit_int16(0x58, (0xC0 | encode)); +} + +void Assembler::evaddph(XMMRegister dst, XMMRegister nds, Address src, int vector_len) { + assert(VM_Version::supports_avx512_fp16(), "requires AVX512-FP16"); + assert(vector_len == Assembler::AVX_512bit || VM_Version::supports_avx512vl(), ""); + InstructionMark im(this); + InstructionAttr attributes(vector_len, /* vex_w */ false, /* legacy_mode */ false, /* no_mask_reg */ true, /* uses_vl */ true); + attributes.set_is_evex_instruction(); + attributes.set_address_attributes(/* tuple_type */ EVEX_FVM, /* input_size_in_bits */ EVEX_NObit); + vex_prefix(src, nds->encoding(), dst->encoding(), VEX_SIMD_NONE, VEX_OPCODE_MAP5, &attributes); + emit_int8(0x58); + emit_operand(dst, src, 0); +} + +void Assembler::evsubph(XMMRegister dst, XMMRegister nds, XMMRegister src, int vector_len) { + assert(VM_Version::supports_avx512_fp16(), "requires AVX512-FP16"); + assert(vector_len == Assembler::AVX_512bit || VM_Version::supports_avx512vl(), ""); + InstructionAttr attributes(vector_len, /* vex_w */ false, /* legacy_mode */ false, /* no_mask_reg */ true, /* uses_vl */ true); + attributes.set_is_evex_instruction(); + int encode = vex_prefix_and_encode(dst->encoding(), nds->encoding(), src->encoding(), VEX_SIMD_NONE, VEX_OPCODE_MAP5, &attributes); + emit_int16(0x5C, (0xC0 | encode)); +} + +void Assembler::evsubph(XMMRegister dst, XMMRegister nds, Address src, int vector_len) { + assert(VM_Version::supports_avx512_fp16(), "requires AVX512-FP16"); + assert(vector_len == Assembler::AVX_512bit || VM_Version::supports_avx512vl(), ""); + InstructionMark im(this); + InstructionAttr attributes(vector_len, /* vex_w */ false, /* legacy_mode */ false, /* no_mask_reg */ true, /* uses_vl */ true); + attributes.set_is_evex_instruction(); + attributes.set_address_attributes(/* tuple_type */ EVEX_FVM, /* input_size_in_bits */ EVEX_NObit); + vex_prefix(src, nds->encoding(), dst->encoding(), VEX_SIMD_NONE, VEX_OPCODE_MAP5, &attributes); + emit_int8(0x5C); + emit_operand(dst, src, 0); +} + +void Assembler::evmulph(XMMRegister dst, XMMRegister nds, XMMRegister src, int vector_len) { + assert(VM_Version::supports_avx512_fp16(), "requires AVX512-FP16"); + assert(vector_len == Assembler::AVX_512bit || VM_Version::supports_avx512vl(), ""); + InstructionAttr attributes(vector_len, /* vex_w */ false, /* legacy_mode */ false, /* no_mask_reg */ true, /* uses_vl */ true); + attributes.set_is_evex_instruction(); + int encode = vex_prefix_and_encode(dst->encoding(), nds->encoding(), src->encoding(), VEX_SIMD_NONE, VEX_OPCODE_MAP5, &attributes); + emit_int16(0x59, (0xC0 | encode)); +} + +void Assembler::evmulph(XMMRegister dst, XMMRegister nds, Address src, int vector_len) { + assert(VM_Version::supports_avx512_fp16(), "requires AVX512-FP16"); + assert(vector_len == Assembler::AVX_512bit || VM_Version::supports_avx512vl(), ""); + InstructionMark im(this); + InstructionAttr attributes(vector_len, /* vex_w */ false, /* legacy_mode */ false, /* no_mask_reg */ true, /* uses_vl */ true); + attributes.set_is_evex_instruction(); + attributes.set_address_attributes(/* tuple_type */ EVEX_FVM, /* input_size_in_bits */ EVEX_NObit); + vex_prefix(src, nds->encoding(), dst->encoding(), VEX_SIMD_NONE, VEX_OPCODE_MAP5, &attributes); + emit_int8(0x59); + emit_operand(dst, src, 0); +} + +void Assembler::evminph(XMMRegister dst, XMMRegister nds, XMMRegister src, int vector_len) { + assert(VM_Version::supports_avx512_fp16(), "requires AVX512-FP16"); + assert(vector_len == Assembler::AVX_512bit || VM_Version::supports_avx512vl(), ""); + InstructionAttr attributes(vector_len, /* vex_w */ false, /* legacy_mode */ false, /* no_mask_reg */ true, /* uses_vl */ true); + attributes.set_is_evex_instruction(); + int encode = vex_prefix_and_encode(dst->encoding(), nds->encoding(), src->encoding(), VEX_SIMD_NONE, VEX_OPCODE_MAP5, &attributes); + emit_int16(0x5D, (0xC0 | encode)); +} + +void Assembler::evminph(XMMRegister dst, XMMRegister nds, Address src, int vector_len) { + assert(VM_Version::supports_avx512_fp16(), "requires AVX512-FP16"); + assert(vector_len == Assembler::AVX_512bit || VM_Version::supports_avx512vl(), ""); + InstructionMark im(this); + InstructionAttr attributes(vector_len, /* vex_w */ false, /* legacy_mode */ false, /* no_mask_reg */ true, /* uses_vl */ true); + attributes.set_is_evex_instruction(); + attributes.set_address_attributes(/* tuple_type */ EVEX_FVM, /* input_size_in_bits */ EVEX_NObit); + vex_prefix(src, nds->encoding(), dst->encoding(), VEX_SIMD_NONE, VEX_OPCODE_MAP5, &attributes); + emit_int8(0x5D); + emit_operand(dst, src, 0); +} + +void Assembler::evmaxph(XMMRegister dst, XMMRegister nds, XMMRegister src, int vector_len) { + assert(VM_Version::supports_avx512_fp16(), "requires AVX512-FP16"); + assert(vector_len == Assembler::AVX_512bit || VM_Version::supports_avx512vl(), ""); + InstructionAttr attributes(vector_len, /* vex_w */ false, /* legacy_mode */ false, /* no_mask_reg */ true, /* uses_vl */ true); + attributes.set_is_evex_instruction(); + int encode = vex_prefix_and_encode(dst->encoding(), nds->encoding(), src->encoding(), VEX_SIMD_NONE, VEX_OPCODE_MAP5, &attributes); + emit_int16(0x5F, (0xC0 | encode)); +} + +void Assembler::evmaxph(XMMRegister dst, XMMRegister nds, Address src, int vector_len) { + assert(VM_Version::supports_avx512_fp16(), "requires AVX512-FP16"); + assert(vector_len == Assembler::AVX_512bit || VM_Version::supports_avx512vl(), ""); + InstructionMark im(this); + InstructionAttr attributes(vector_len, /* vex_w */ false, /* legacy_mode */ false, /* no_mask_reg */ true, /* uses_vl */ true); + attributes.set_is_evex_instruction(); + attributes.set_address_attributes(/* tuple_type */ EVEX_FVM, /* input_size_in_bits */ EVEX_NObit); + vex_prefix(src, nds->encoding(), dst->encoding(), VEX_SIMD_NONE, VEX_OPCODE_MAP5, &attributes); + emit_int8(0x5F); + emit_operand(dst, src, 0); +} + +void Assembler::evdivph(XMMRegister dst, XMMRegister nds, XMMRegister src, int vector_len) { + assert(VM_Version::supports_avx512_fp16(), "requires AVX512-FP16"); + assert(vector_len == Assembler::AVX_512bit || VM_Version::supports_avx512vl(), ""); + InstructionAttr attributes(vector_len, /* vex_w */ false, /* legacy_mode */ false, /* no_mask_reg */ true, /* uses_vl */ true); + attributes.set_is_evex_instruction(); + int encode = vex_prefix_and_encode(dst->encoding(), nds->encoding(), src->encoding(), VEX_SIMD_NONE, VEX_OPCODE_MAP5, &attributes); + emit_int16(0x5E, (0xC0 | encode)); +} + +void Assembler::evdivph(XMMRegister dst, XMMRegister nds, Address src, int vector_len) { + assert(VM_Version::supports_avx512_fp16(), "requires AVX512-FP16"); + assert(vector_len == Assembler::AVX_512bit || VM_Version::supports_avx512vl(), ""); + InstructionMark im(this); + InstructionAttr attributes(vector_len, /* vex_w */ false, /* legacy_mode */ false, /* no_mask_reg */ true, /* uses_vl */ true); + attributes.set_is_evex_instruction(); + attributes.set_address_attributes(/* tuple_type */ EVEX_FVM, /* input_size_in_bits */ EVEX_NObit); + vex_prefix(src, nds->encoding(), dst->encoding(), VEX_SIMD_NONE, VEX_OPCODE_MAP5, &attributes); + emit_int8(0x5E); + emit_operand(dst, src, 0); +} + +void Assembler::evsqrtph(XMMRegister dst, XMMRegister src, int vector_len) { + assert(VM_Version::supports_avx512_fp16(), ""); + assert(vector_len == Assembler::AVX_512bit || VM_Version::supports_avx512vl(), ""); + InstructionAttr attributes(vector_len, /* vex_w */ false, /* legacy_mode */ false, /* no_mask_reg */ true, /* uses_vl */ true); + attributes.set_is_evex_instruction(); + int encode = vex_prefix_and_encode(dst->encoding(), 0, src->encoding(), VEX_SIMD_NONE, VEX_OPCODE_MAP5, &attributes); + emit_int16(0x51, (0xC0 | encode)); +} + +void Assembler::evsqrtph(XMMRegister dst, Address src, int vector_len) { + assert(VM_Version::supports_avx512_fp16(), ""); + assert(vector_len == Assembler::AVX_512bit || VM_Version::supports_avx512vl(), ""); + InstructionMark im(this); + InstructionAttr attributes(vector_len, /* vex_w */ false, /* legacy_mode */ false, /* no_mask_reg */ true, /* uses_vl */ true); + attributes.set_is_evex_instruction(); + attributes.set_address_attributes(/* tuple_type */ EVEX_FV, /* input_size_in_bits */ EVEX_NObit); + vex_prefix(src, 0, dst->encoding(), VEX_SIMD_NONE, VEX_OPCODE_MAP5, &attributes); + emit_int8(0x51); + emit_operand(dst, src, 0); +} + +void Assembler::evfmadd132ph(XMMRegister dst, XMMRegister nds, XMMRegister src, int vector_len) { + assert(VM_Version::supports_avx512_fp16(), ""); + assert(vector_len == Assembler::AVX_512bit || VM_Version::supports_avx512vl(), ""); + InstructionAttr attributes(vector_len, /* vex_w */ false, /* legacy_mode */ false, /* no_mask_reg */ true, /* uses_vl */ true); + attributes.set_is_evex_instruction(); + int encode = vex_prefix_and_encode(dst->encoding(), nds->encoding(), src->encoding(), VEX_SIMD_66, VEX_OPCODE_MAP6, &attributes); + emit_int16(0x98, (0xC0 | encode)); +} + +void Assembler::evfmadd132ph(XMMRegister dst, XMMRegister nds, Address src, int vector_len) { + assert(VM_Version::supports_avx512_fp16(), ""); + assert(vector_len == Assembler::AVX_512bit || VM_Version::supports_avx512vl(), ""); + InstructionMark im(this); + InstructionAttr attributes(vector_len, /* vex_w */ false, /* legacy_mode */ false, /* no_mask_reg */ true, /* uses_vl */ true); + attributes.set_is_evex_instruction(); + attributes.set_address_attributes(/* tuple_type */ EVEX_FV, /* input_size_in_bits */ EVEX_NObit); + vex_prefix(src, nds->encoding(), dst->encoding(), VEX_SIMD_66, VEX_OPCODE_MAP6, &attributes); + emit_int8(0x98); + emit_operand(dst, src, 0); +} + +void Assembler::vaddsh(XMMRegister dst, XMMRegister nds, XMMRegister src) { + assert(VM_Version::supports_avx512_fp16(), "requires AVX512-FP16"); + InstructionAttr attributes(AVX_128bit, /* vex_w */ false, /* legacy_mode */ false, /* no_mask_reg */ true, /* uses_vl */ false); + attributes.set_is_evex_instruction(); + int encode = vex_prefix_and_encode(dst->encoding(), nds->encoding(), src->encoding(), VEX_SIMD_F3, VEX_OPCODE_MAP5, &attributes); + emit_int16(0x58, (0xC0 | encode)); +} + +void Assembler::vsubsh(XMMRegister dst, XMMRegister nds, XMMRegister src) { + assert(VM_Version::supports_avx512_fp16(), "requires AVX512-FP16"); + InstructionAttr attributes(AVX_128bit, /* vex_w */ false, /* legacy_mode */ false, /* no_mask_reg */ true, /* uses_vl */ false); + attributes.set_is_evex_instruction(); + int encode = vex_prefix_and_encode(dst->encoding(), nds->encoding(), src->encoding(), VEX_SIMD_F3, VEX_OPCODE_MAP5, &attributes); + emit_int16(0x5C, (0xC0 | encode)); +} + +void Assembler::vdivsh(XMMRegister dst, XMMRegister nds, XMMRegister src) { + assert(VM_Version::supports_avx512_fp16(), "requires AVX512-FP16"); + InstructionAttr attributes(AVX_128bit, /* vex_w */ false, /* legacy_mode */ false, /* no_mask_reg */ true, /* uses_vl */ false); + attributes.set_is_evex_instruction(); + int encode = vex_prefix_and_encode(dst->encoding(), nds->encoding(), src->encoding(), VEX_SIMD_F3, VEX_OPCODE_MAP5, &attributes); + emit_int16(0x5E, (0xC0 | encode)); +} + +void Assembler::vmulsh(XMMRegister dst, XMMRegister nds, XMMRegister src) { + assert(VM_Version::supports_avx512_fp16(), "requires AVX512-FP16"); + InstructionAttr attributes(AVX_128bit, /* vex_w */ false, /* legacy_mode */ false, /* no_mask_reg */ true, /* uses_vl */ false); + attributes.set_is_evex_instruction(); + int encode = vex_prefix_and_encode(dst->encoding(), nds->encoding(), src->encoding(), VEX_SIMD_F3, VEX_OPCODE_MAP5, &attributes); + emit_int16(0x59, (0xC0 | encode)); +} + +void Assembler::vmaxsh(XMMRegister dst, XMMRegister nds, XMMRegister src) { + assert(VM_Version::supports_avx512_fp16(), "requires AVX512-FP16"); + InstructionAttr attributes(AVX_128bit, /* vex_w */ false, /* legacy_mode */ false, /* no_mask_reg */ true, /* uses_vl */ false); + attributes.set_is_evex_instruction(); + int encode = vex_prefix_and_encode(dst->encoding(), nds->encoding(), src->encoding(), VEX_SIMD_F3, VEX_OPCODE_MAP5, &attributes); + emit_int16(0x5F, (0xC0 | encode)); +} + +void Assembler::vminsh(XMMRegister dst, XMMRegister nds, XMMRegister src) { + assert(VM_Version::supports_avx512_fp16(), "requires AVX512-FP16"); + InstructionAttr attributes(AVX_128bit, /* vex_w */ false, /* legacy_mode */ false, /* no_mask_reg */ true, /* uses_vl */ false); + attributes.set_is_evex_instruction(); + int encode = vex_prefix_and_encode(dst->encoding(), nds->encoding(), src->encoding(), VEX_SIMD_F3, VEX_OPCODE_MAP5, &attributes); + emit_int16(0x5D, (0xC0 | encode)); +} + +void Assembler::vsqrtsh(XMMRegister dst, XMMRegister src) { + assert(VM_Version::supports_avx512_fp16(), "requires AVX512-FP16"); + InstructionAttr attributes(AVX_128bit, /* vex_w */ false, /* legacy_mode */ false, /* no_mask_reg */ true, /* uses_vl */ false); + attributes.set_is_evex_instruction(); + int encode = vex_prefix_and_encode(dst->encoding(), 0, src->encoding(), VEX_SIMD_F3, VEX_OPCODE_MAP5, &attributes); + emit_int16(0x51, (0xC0 | encode)); +} + +void Assembler::vfmadd132sh(XMMRegister dst, XMMRegister src1, XMMRegister src2) { + assert(VM_Version::supports_avx512_fp16(), ""); + InstructionAttr attributes(AVX_128bit, /* vex_w */ false, /* legacy_mode */ false, /* no_mask_reg */ true, /* uses_vl */ false); + attributes.set_is_evex_instruction(); + int encode = vex_prefix_and_encode(dst->encoding(), src1->encoding(), src2->encoding(), VEX_SIMD_66, VEX_OPCODE_MAP6, &attributes); + emit_int16((unsigned char)0x99, (0xC0 | encode)); +} + void Assembler::vpaddsb(XMMRegister dst, XMMRegister nds, XMMRegister src, int vector_len) { assert(UseAVX > 0 && (vector_len == Assembler::AVX_512bit || (!needs_evex(dst, nds, src) || VM_Version::supports_avx512vl())), ""); assert(!needs_evex(dst, nds, src) || VM_Version::supports_avx512bw(), ""); diff --git a/src/hotspot/cpu/x86/assembler_x86.hpp b/src/hotspot/cpu/x86/assembler_x86.hpp index 25be0d6a48d32..1397a70d65219 100644 --- a/src/hotspot/cpu/x86/assembler_x86.hpp +++ b/src/hotspot/cpu/x86/assembler_x86.hpp @@ -585,6 +585,8 @@ class Assembler : public AbstractAssembler { VEX_OPCODE_0F_38 = 0x2, VEX_OPCODE_0F_3A = 0x3, VEX_OPCODE_0F_3C = 0x4, + VEX_OPCODE_MAP5 = 0x5, + VEX_OPCODE_MAP6 = 0x6, VEX_OPCODE_MASK = 0x1F }; @@ -1815,6 +1817,9 @@ class Assembler : public AbstractAssembler { void movsbl(Register dst, Address src); void movsbl(Register dst, Register src); + void vmovw(XMMRegister dst, Register src); + void vmovw(Register dst, XMMRegister src); + #ifdef _LP64 void movsbq(Register dst, Address src); void movsbq(Register dst, Register src); @@ -2691,6 +2696,33 @@ class Assembler : public AbstractAssembler { void vpaddd(XMMRegister dst, XMMRegister nds, Address src, int vector_len); void vpaddq(XMMRegister dst, XMMRegister nds, Address src, int vector_len); + // FP16 instructions + void vaddsh(XMMRegister dst, XMMRegister nds, XMMRegister src); + void vsubsh(XMMRegister dst, XMMRegister nds, XMMRegister src); + void vmulsh(XMMRegister dst, XMMRegister nds, XMMRegister src); + void vdivsh(XMMRegister dst, XMMRegister nds, XMMRegister src); + void vmaxsh(XMMRegister dst, XMMRegister nds, XMMRegister src); + void vminsh(XMMRegister dst, XMMRegister nds, XMMRegister src); + void vsqrtsh(XMMRegister dst, XMMRegister src); + void vfmadd132sh(XMMRegister dst, XMMRegister src1, XMMRegister src2); + + void evaddph(XMMRegister dst, XMMRegister nds, XMMRegister src, int vector_len); + void evaddph(XMMRegister dst, XMMRegister nds, Address src, int vector_len); + void evsubph(XMMRegister dst, XMMRegister nds, XMMRegister src, int vector_len); + void evsubph(XMMRegister dst, XMMRegister nds, Address src, int vector_len); + void evdivph(XMMRegister dst, XMMRegister nds, XMMRegister src, int vector_len); + void evdivph(XMMRegister dst, XMMRegister nds, Address src, int vector_len); + void evmulph(XMMRegister dst, XMMRegister nds, XMMRegister src, int vector_len); + void evmulph(XMMRegister dst, XMMRegister nds, Address src, int vector_len); + void evminph(XMMRegister dst, XMMRegister nds, XMMRegister src, int vector_len); + void evminph(XMMRegister dst, XMMRegister nds, Address src, int vector_len); + void evmaxph(XMMRegister dst, XMMRegister nds, XMMRegister src, int vector_len); + void evmaxph(XMMRegister dst, XMMRegister nds, Address src, int vector_len); + void evfmadd132ph(XMMRegister dst, XMMRegister nds, XMMRegister src, int vector_len); + void evfmadd132ph(XMMRegister dst, XMMRegister nds, Address src, int vector_len); + void evsqrtph(XMMRegister dst, XMMRegister src1, int vector_len); + void evsqrtph(XMMRegister dst, Address src1, int vector_len); + // Saturating packed insturctions. void vpaddsb(XMMRegister dst, XMMRegister nds, XMMRegister src, int vector_len); void vpaddsw(XMMRegister dst, XMMRegister nds, XMMRegister src, int vector_len); diff --git a/src/hotspot/cpu/x86/c2_MacroAssembler_x86.cpp b/src/hotspot/cpu/x86/c2_MacroAssembler_x86.cpp index 008d52dcb1ce9..555aabab7a11a 100644 --- a/src/hotspot/cpu/x86/c2_MacroAssembler_x86.cpp +++ b/src/hotspot/cpu/x86/c2_MacroAssembler_x86.cpp @@ -6668,6 +6668,42 @@ void C2_MacroAssembler::vector_rearrange_int_float(BasicType bt, XMMRegister dst } } +void C2_MacroAssembler::efp16sh(int opcode, XMMRegister dst, XMMRegister src1, XMMRegister src2) { + switch(opcode) { + case Op_AddHF: vaddsh(dst, src1, src2); break; + case Op_SubHF: vsubsh(dst, src1, src2); break; + case Op_MulHF: vmulsh(dst, src1, src2); break; + case Op_DivHF: vdivsh(dst, src1, src2); break; + case Op_MaxHF: vmaxsh(dst, src1, src2); break; + case Op_MinHF: vminsh(dst, src1, src2); break; + default: assert(false, "%s", NodeClassNames[opcode]); break; + } +} + +void C2_MacroAssembler::evfp16ph(int opcode, XMMRegister dst, XMMRegister src1, XMMRegister src2, int vlen_enc) { + switch(opcode) { + case Op_AddVHF: evaddph(dst, src1, src2, vlen_enc); break; + case Op_SubVHF: evsubph(dst, src1, src2, vlen_enc); break; + case Op_MulVHF: evmulph(dst, src1, src2, vlen_enc); break; + case Op_DivVHF: evdivph(dst, src1, src2, vlen_enc); break; + case Op_MaxVHF: evmaxph(dst, src1, src2, vlen_enc); break; + case Op_MinVHF: evminph(dst, src1, src2, vlen_enc); break; + default: assert(false, "%s", NodeClassNames[opcode]); break; + } +} + +void C2_MacroAssembler::evfp16ph(int opcode, XMMRegister dst, XMMRegister src1, Address src2, int vlen_enc) { + switch(opcode) { + case Op_AddVHF: evaddph(dst, src1, src2, vlen_enc); break; + case Op_SubVHF: evsubph(dst, src1, src2, vlen_enc); break; + case Op_MulVHF: evmulph(dst, src1, src2, vlen_enc); break; + case Op_DivVHF: evdivph(dst, src1, src2, vlen_enc); break; + case Op_MaxVHF: evmaxph(dst, src1, src2, vlen_enc); break; + case Op_MinVHF: evminph(dst, src1, src2, vlen_enc); break; + default: assert(false, "%s", NodeClassNames[opcode]); break; + } +} + void C2_MacroAssembler::vector_saturating_op(int ideal_opc, BasicType elem_bt, XMMRegister dst, XMMRegister src1, XMMRegister src2, int vlen_enc) { switch(elem_bt) { case T_BYTE: diff --git a/src/hotspot/cpu/x86/c2_MacroAssembler_x86.hpp b/src/hotspot/cpu/x86/c2_MacroAssembler_x86.hpp index 3a36fd75e3f86..fbfcdbd190ae1 100644 --- a/src/hotspot/cpu/x86/c2_MacroAssembler_x86.hpp +++ b/src/hotspot/cpu/x86/c2_MacroAssembler_x86.hpp @@ -505,6 +505,11 @@ void vector_rearrange_int_float(BasicType bt, XMMRegister dst, XMMRegister shuffle, XMMRegister src, int vlen_enc); + void efp16sh(int opcode, XMMRegister dst, XMMRegister src1, XMMRegister src2); + + void evfp16ph(int opcode, XMMRegister dst, XMMRegister src1, XMMRegister src2, int vlen_enc); + + void evfp16ph(int opcode, XMMRegister dst, XMMRegister src1, Address src2, int vlen_enc); void vgather_subword(BasicType elem_ty, XMMRegister dst, Register base, Register idx_base, Register offset, Register mask, XMMRegister xtmp1, XMMRegister xtmp2, XMMRegister xtmp3, Register rtmp, diff --git a/src/hotspot/cpu/x86/vm_version_x86.cpp b/src/hotspot/cpu/x86/vm_version_x86.cpp index 688cd4fa5a6d1..a167b9feb64d7 100644 --- a/src/hotspot/cpu/x86/vm_version_x86.cpp +++ b/src/hotspot/cpu/x86/vm_version_x86.cpp @@ -1028,6 +1028,7 @@ void VM_Version::get_processor_features() { _features &= ~CPU_AVX512_BITALG; _features &= ~CPU_AVX512_IFMA; _features &= ~CPU_APX_F; + _features &= ~CPU_AVX512_FP16; } // Currently APX support is only enabled for targets supporting AVX512VL feature. @@ -1078,6 +1079,7 @@ void VM_Version::get_processor_features() { _features &= ~CPU_AVX512_BITALG; _features &= ~CPU_AVX512_IFMA; _features &= ~CPU_AVX_IFMA; + _features &= ~CPU_AVX512_FP16; } } @@ -3110,6 +3112,9 @@ uint64_t VM_Version::CpuidInfo::feature_flags() const { } if (sef_cpuid7_edx.bits.serialize != 0) result |= CPU_SERIALIZE; + + if (_cpuid_info.sef_cpuid7_edx.bits.avx512_fp16 != 0) + result |= CPU_AVX512_FP16; } // ZX features. diff --git a/src/hotspot/cpu/x86/vm_version_x86.hpp b/src/hotspot/cpu/x86/vm_version_x86.hpp index 004b64ebe6eb1..d4c7e83bc20bc 100644 --- a/src/hotspot/cpu/x86/vm_version_x86.hpp +++ b/src/hotspot/cpu/x86/vm_version_x86.hpp @@ -276,7 +276,9 @@ class VM_Version : public Abstract_VM_Version { serialize : 1, : 5, cet_ibt : 1, - : 11; + : 2, + avx512_fp16 : 1, + : 8; } bits; }; @@ -416,8 +418,9 @@ class VM_Version : public Abstract_VM_Version { decl(CET_SS, "cet_ss", 57) /* Control Flow Enforcement - Shadow Stack */ \ decl(AVX512_IFMA, "avx512_ifma", 58) /* Integer Vector FMA instructions*/ \ decl(AVX_IFMA, "avx_ifma", 59) /* 256-bit VEX-coded variant of AVX512-IFMA*/ \ - decl(APX_F, "apx_f", 60) /* Intel Advanced Performance Extensions*/\ - decl(SHA512, "sha512", 61) /* SHA512 instructions*/ + decl(APX_F, "apx_f", 60) /* Intel Advanced Performance Extensions*/ \ + decl(SHA512, "sha512", 61) /* SHA512 instructions*/ \ + decl(AVX512_FP16, "avx512_fp16", 62) /* AVX512 FP16 ISA support*/ #define DECLARE_CPU_FEATURE_FLAG(id, name, bit) CPU_##id = (1ULL << bit), CPU_FEATURE_FLAGS(DECLARE_CPU_FEATURE_FLAG) @@ -753,6 +756,7 @@ class VM_Version : public Abstract_VM_Version { static bool supports_avx512_bitalg() { return (_features & CPU_AVX512_BITALG) != 0; } static bool supports_avx512_vbmi() { return (_features & CPU_AVX512_VBMI) != 0; } static bool supports_avx512_vbmi2() { return (_features & CPU_AVX512_VBMI2) != 0; } + static bool supports_avx512_fp16() { return (_features & CPU_AVX512_FP16) != 0; } static bool supports_hv() { return (_features & CPU_HV) != 0; } static bool supports_serialize() { return (_features & CPU_SERIALIZE) != 0; } static bool supports_f16c() { return (_features & CPU_F16C) != 0; } @@ -840,7 +844,7 @@ class VM_Version : public Abstract_VM_Version { // For AVX CPUs only. f16c support is disabled if UseAVX == 0. static bool supports_float16() { - return supports_f16c() || supports_avx512vl(); + return supports_f16c() || supports_avx512vl() || supports_avx512_fp16(); } // Check intrinsic support diff --git a/src/hotspot/cpu/x86/x86.ad b/src/hotspot/cpu/x86/x86.ad index 09023562dd63c..fc030edc8c60a 100644 --- a/src/hotspot/cpu/x86/x86.ad +++ b/src/hotspot/cpu/x86/x86.ad @@ -1461,6 +1461,20 @@ bool Matcher::match_rule_supported(int opcode) { return false; } break; + case Op_AddHF: + case Op_DivHF: + case Op_FmaHF: + case Op_MaxHF: + case Op_MinHF: + case Op_MulHF: + case Op_ReinterpretS2HF: + case Op_ReinterpretHF2S: + case Op_SubHF: + case Op_SqrtHF: + if (!VM_Version::supports_avx512_fp16()) { + return false; + } + break; case Op_VectorLoadShuffle: case Op_VectorRearrange: case Op_MulReductionVI: @@ -1726,6 +1740,18 @@ bool Matcher::match_rule_supported_vector(int opcode, int vlen, BasicType bt) { // * 128bit vroundpd instruction is present only in AVX1 int size_in_bits = vlen * type2aelembytes(bt) * BitsPerByte; switch (opcode) { + case Op_AddVHF: + case Op_DivVHF: + case Op_FmaVHF: + case Op_MaxVHF: + case Op_MinVHF: + case Op_MulVHF: + case Op_SubVHF: + case Op_SqrtVHF: + if (!VM_Version::supports_avx512_fp16()) { + return false; + } + break; case Op_AbsVF: case Op_NegVF: if ((vlen == 16) && (VM_Version::supports_avx512dq() == false)) { @@ -4487,6 +4513,35 @@ instruct vReplS_reg(vec dst, rRegI src) %{ ins_pipe( pipe_slow ); %} +#ifdef _LP64 +instruct ReplHF_imm(vec dst, immH con, rRegI rtmp) %{ + match(Set dst (Replicate con)); + effect(TEMP rtmp); + format %{ "replicateHF $dst, $con \t! using $rtmp as TEMP" %} + ins_encode %{ + int vlen_enc = vector_length_encoding(this); + BasicType bt = Matcher::vector_element_basic_type(this); + assert(VM_Version::supports_avx512_fp16() && bt == T_SHORT, ""); + __ movl($rtmp$$Register, $con$$constant); + __ evpbroadcastw($dst$$XMMRegister, $rtmp$$Register, vlen_enc); + %} + ins_pipe( pipe_slow ); +%} + +instruct ReplHF_reg(vec dst, regF src, rRegI rtmp) %{ + predicate(VM_Version::supports_avx512_fp16() && Matcher::vector_element_basic_type(n) == T_SHORT); + match(Set dst (Replicate src)); + effect(TEMP rtmp); + format %{ "replicateHF $dst, $src \t! using $rtmp as TEMP" %} + ins_encode %{ + int vlen_enc = vector_length_encoding(this); + __ vmovw($rtmp$$Register, $src$$XMMRegister); + __ evpbroadcastw($dst$$XMMRegister, $rtmp$$Register, vlen_enc); + %} + ins_pipe( pipe_slow ); +%} +#endif + instruct ReplS_mem(vec dst, memory mem) %{ predicate(UseAVX >= 2 && Matcher::vector_element_basic_type(n) == T_SHORT); match(Set dst (Replicate (LoadS mem))); @@ -10858,3 +10913,151 @@ instruct vector_selectfrom_twovectors_reg_evex(vec index, vec src1, vec src2) %} ins_pipe(pipe_slow); %} + +instruct reinterpretS2HF(regF dst, rRegI src) +%{ + match(Set dst (ReinterpretS2HF src)); + format %{ "vmovw $dst, $src" %} + ins_encode %{ + __ vmovw($dst$$XMMRegister, $src$$Register); + %} + ins_pipe(pipe_slow); +%} + +instruct convF2HFAndS2HF(regF dst, regF src) +%{ + match(Set dst (ReinterpretS2HF (ConvF2HF src))); + format %{ "convF2HFAndS2HF $dst, $src" %} + ins_encode %{ + __ vcvtps2ph($dst$$XMMRegister, $src$$XMMRegister, 0x04, Assembler::AVX_128bit); + %} + ins_pipe(pipe_slow); +%} + +instruct reinterpretHF2S(rRegI dst, regF src) +%{ + match(Set dst (ReinterpretHF2S src)); + format %{ "vmovw $dst, $src" %} + ins_encode %{ + __ vmovw($dst$$Register, $src$$XMMRegister); + %} + ins_pipe(pipe_slow); +%} + +instruct scalar_sqrt_HF_reg(regF dst, regF src) +%{ + match(Set dst (SqrtHF src)); + format %{ "scalar_sqrt_fp16 $dst, $src" %} + ins_encode %{ + __ vsqrtsh($dst$$XMMRegister, $src$$XMMRegister); + %} + ins_pipe(pipe_slow); +%} + +instruct scalar_binOps_HF_reg(regF dst, regF src1, regF src2) +%{ + match(Set dst (AddHF src1 src2)); + match(Set dst (DivHF src1 src2)); + match(Set dst (MaxHF src1 src2)); + match(Set dst (MinHF src1 src2)); + match(Set dst (MulHF src1 src2)); + match(Set dst (SubHF src1 src2)); + format %{ "scalar_binop_fp16 $dst, $src1, $src2" %} + ins_encode %{ + int opcode = this->ideal_Opcode(); + __ efp16sh(opcode, $dst$$XMMRegister, $src1$$XMMRegister, $src2$$XMMRegister); + %} + ins_pipe(pipe_slow); +%} + +instruct scalar_fma_HF_reg(regF dst, regF src1, regF src2) +%{ + match(Set dst (FmaHF src2 (Binary dst src1))); + effect(DEF dst); + format %{ "scalar_fma_fp16 $dst, $src1, $src2\t# $dst = $dst * $src1 + $src2 fma packedH" %} + ins_encode %{ + __ vfmadd132sh($dst$$XMMRegister, $src2$$XMMRegister, $src1$$XMMRegister); + %} + ins_pipe( pipe_slow ); +%} + +instruct vector_sqrt_HF_reg(vec dst, vec src) +%{ + match(Set dst (SqrtVHF src)); + format %{ "vector_sqrt_fp16 $dst, $src" %} + ins_encode %{ + int vlen_enc = vector_length_encoding(this); + __ evsqrtph($dst$$XMMRegister, $src$$XMMRegister, vlen_enc); + %} + ins_pipe(pipe_slow); +%} + +instruct vector_sqrt_HF_mem(vec dst, memory src) +%{ + match(Set dst (SqrtVHF (VectorReinterpret (LoadVector src)))); + format %{ "vector_sqrt_fp16_mem $dst, $src" %} + ins_encode %{ + int vlen_enc = vector_length_encoding(this); + __ evsqrtph($dst$$XMMRegister, $src$$Address, vlen_enc); + %} + ins_pipe(pipe_slow); +%} + +instruct vector_binOps_HF_reg(vec dst, vec src1, vec src2) +%{ + match(Set dst (AddVHF src1 src2)); + match(Set dst (DivVHF src1 src2)); + match(Set dst (MaxVHF src1 src2)); + match(Set dst (MinVHF src1 src2)); + match(Set dst (MulVHF src1 src2)); + match(Set dst (SubVHF src1 src2)); + format %{ "vector_binop_fp16 $dst, $src1, $src2" %} + ins_encode %{ + int vlen_enc = vector_length_encoding(this); + int opcode = this->ideal_Opcode(); + __ evfp16ph(opcode, $dst$$XMMRegister, $src1$$XMMRegister, $src2$$XMMRegister, vlen_enc); + %} + ins_pipe(pipe_slow); +%} + +instruct vector_binOps_HF_mem(vec dst, vec src1, memory src2) +%{ + match(Set dst (AddVHF src1 (VectorReinterpret (LoadVector src2)))); + match(Set dst (DivVHF src1 (VectorReinterpret (LoadVector src2)))); + match(Set dst (MaxVHF src1 (VectorReinterpret (LoadVector src2)))); + match(Set dst (MinVHF src1 (VectorReinterpret (LoadVector src2)))); + match(Set dst (MulVHF src1 (VectorReinterpret (LoadVector src2)))); + match(Set dst (SubVHF src1 (VectorReinterpret (LoadVector src2)))); + format %{ "vector_binop_fp16_mem $dst, $src1, $src2" %} + ins_encode %{ + int vlen_enc = vector_length_encoding(this); + int opcode = this->ideal_Opcode(); + __ evfp16ph(opcode, $dst$$XMMRegister, $src1$$XMMRegister, $src2$$Address, vlen_enc); + %} + ins_pipe(pipe_slow); +%} + + +instruct vector_fma_HF_reg(vec dst, vec src1, vec src2) +%{ + match(Set dst (FmaVHF src2 (Binary dst src1))); + effect(DEF dst); + format %{ "vector_fma_fp16 $dst, $src1, $src2\t# $dst = $dst * $src1 + $src2 fma packedH" %} + ins_encode %{ + int vlen_enc = vector_length_encoding(this); + __ evfmadd132ph($dst$$XMMRegister, $src2$$XMMRegister, $src1$$XMMRegister, vlen_enc); + %} + ins_pipe( pipe_slow ); +%} + +instruct vector_fma_HF_mem(vec dst, memory src1, vec src2) +%{ + match(Set dst (FmaVHF src2 (Binary dst (VectorReinterpret (LoadVector src1))))); + effect(DEF dst); + format %{ "vector_fma_fp16_mem $dst, $src1, $src2\t# $dst = $dst * $src1 + $src2 fma packedH" %} + ins_encode %{ + int vlen_enc = vector_length_encoding(this); + __ evfmadd132ph($dst$$XMMRegister, $src2$$XMMRegister, $src1$$Address, vlen_enc); + %} + ins_pipe( pipe_slow ); +%} diff --git a/src/hotspot/cpu/x86/x86_64.ad b/src/hotspot/cpu/x86/x86_64.ad index 4667922505cbc..3d1a311eac19f 100644 --- a/src/hotspot/cpu/x86/x86_64.ad +++ b/src/hotspot/cpu/x86/x86_64.ad @@ -2382,6 +2382,16 @@ operand immF() interface(CONST_INTER); %} +// Half Float Immediate +operand immH() +%{ + match(ConH); + + op_cost(15); + format %{ %} + interface(CONST_INTER); +%} + // Double Immediate zero operand immD0() %{ @@ -4840,6 +4850,16 @@ instruct loadConF(regF dst, immF con) %{ ins_pipe(pipe_slow); %} +instruct loadConH(regF dst, immH con) %{ + match(Set dst con); + ins_cost(125); + format %{ "movss $dst, [$constantaddress]\t# load from constant table: halffloat=$con" %} + ins_encode %{ + __ movflt($dst$$XMMRegister, $constantaddress($con)); + %} + ins_pipe(pipe_slow); +%} + instruct loadConN0(rRegN dst, immN0 src, rFlagsReg cr) %{ match(Set dst src); effect(KILL cr); @@ -7022,6 +7042,17 @@ instruct castFF(regF dst) ins_pipe(empty); %} +instruct castHH(regF dst) +%{ + match(Set dst (CastHH dst)); + + size(0); + format %{ "# castHH of $dst" %} + ins_encode(/* empty encoding */); + ins_cost(0); + ins_pipe(empty); +%} + instruct castDD(regD dst) %{ match(Set dst (CastDD dst)); diff --git a/src/hotspot/share/adlc/archDesc.cpp b/src/hotspot/share/adlc/archDesc.cpp index f084f506bf587..237e0731db362 100644 --- a/src/hotspot/share/adlc/archDesc.cpp +++ b/src/hotspot/share/adlc/archDesc.cpp @@ -1053,6 +1053,7 @@ const char *ArchDesc::getIdealType(const char *idealOp) { case 'P': return "TypePtr::BOTTOM"; case 'N': return "TypeNarrowOop::BOTTOM"; case 'F': return "Type::FLOAT"; + case 'H': return "Type::HALF_FLOAT"; case 'D': return "Type::DOUBLE"; case 'L': return "TypeLong::LONG"; case 's': return "TypeInt::CC /*flags*/"; @@ -1090,7 +1091,7 @@ void ArchDesc::initBaseOpTypes() { char *ident = (char *)NodeClassNames[j]; if (!strcmp(ident, "ConI") || !strcmp(ident, "ConP") || !strcmp(ident, "ConN") || !strcmp(ident, "ConNKlass") || - !strcmp(ident, "ConF") || !strcmp(ident, "ConD") || + !strcmp(ident, "ConH") || !strcmp(ident, "ConF") || !strcmp(ident, "ConD") || !strcmp(ident, "ConL") || !strcmp(ident, "Con" ) || !strcmp(ident, "Bool")) { constructOperand(ident, true); diff --git a/src/hotspot/share/adlc/forms.cpp b/src/hotspot/share/adlc/forms.cpp index c34a73ea1e13f..16b96d89e464b 100644 --- a/src/hotspot/share/adlc/forms.cpp +++ b/src/hotspot/share/adlc/forms.cpp @@ -220,6 +220,7 @@ Form::DataType Form::ideal_to_const_type(const char *name) const { if (strcmp(name,"ConNKlass")==0) return Form::idealNKlass; if (strcmp(name,"ConL")==0) return Form::idealL; if (strcmp(name,"ConF")==0) return Form::idealF; + if (strcmp(name,"ConH")==0) return Form::idealH; if (strcmp(name,"ConD")==0) return Form::idealD; if (strcmp(name,"Bool")==0) return Form::idealI; diff --git a/src/hotspot/share/adlc/forms.hpp b/src/hotspot/share/adlc/forms.hpp index a82b9bbb3382d..77479c7fb5228 100644 --- a/src/hotspot/share/adlc/forms.hpp +++ b/src/hotspot/share/adlc/forms.hpp @@ -183,7 +183,8 @@ class Form { idealS = 8, // String type idealN = 9, // Narrow oop types idealNKlass = 10, // Narrow klass types - idealV = 11 // Vector type + idealV = 11, // Vector type + idealH = 12 // HalfFloat type }; // Convert ideal name to a DataType, return DataType::none if not a 'ConX' Form::DataType ideal_to_const_type(const char *ideal_type_name) const; diff --git a/src/hotspot/share/adlc/formssel.cpp b/src/hotspot/share/adlc/formssel.cpp index dfa414ef56484..35ecd3dfb5dda 100644 --- a/src/hotspot/share/adlc/formssel.cpp +++ b/src/hotspot/share/adlc/formssel.cpp @@ -1088,7 +1088,7 @@ uint InstructForm::reloc(FormDict &globals) { } else if ( oper ) { // floats and doubles loaded out of method's constant pool require reloc info Form::DataType type = oper->is_base_constant(globals); - if ( (type == Form::idealF) || (type == Form::idealD) ) { + if ( (type == Form::idealH) || (type == Form::idealF) || (type == Form::idealD) ) { ++reloc_entries; } } @@ -1099,7 +1099,7 @@ uint InstructForm::reloc(FormDict &globals) { // !!!!! // Check for any component being an immediate float or double. Form::DataType data_type = is_chain_of_constant(globals); - if( data_type==idealD || data_type==idealF ) { + if( data_type==idealH || data_type==idealD || data_type==idealF ) { reloc_entries++; } @@ -2662,6 +2662,7 @@ void OperandForm::format_constant(FILE *fp, uint const_index, uint const_type) { case Form::idealN: fprintf(fp," if (_c%d) _c%d->dump_on(st);\n", const_index, const_index); break; case Form::idealL: fprintf(fp," st->print(\"#\" INT64_FORMAT, (int64_t)_c%d);\n", const_index); break; case Form::idealF: fprintf(fp," st->print(\"#%%f\", _c%d);\n", const_index); break; + case Form::idealH: fprintf(fp," st->print(\"#%%d\", _c%d);\n", const_index); break; case Form::idealD: fprintf(fp," st->print(\"#%%f\", _c%d);\n", const_index); break; default: assert( false, "ShouldNotReachHere()"); @@ -2743,6 +2744,7 @@ void OperandForm::access_constant(FILE *fp, FormDict &globals, case idealP: fprintf(fp,"_c%d->get_con()",const_index); break; case idealL: fprintf(fp,"_c%d", const_index); break; case idealF: fprintf(fp,"_c%d", const_index); break; + case idealH: fprintf(fp,"_c%d", const_index); break; case idealD: fprintf(fp,"_c%d", const_index); break; default: assert( false, "ShouldNotReachHere()"); @@ -3953,19 +3955,20 @@ bool MatchNode::equivalent(FormDict &globals, MatchNode *mNode2) { // which could be swapped. void MatchNode::count_commutative_op(int& count) { static const char *commut_op_list[] = { - "AddI","AddL","AddF","AddD", + "AddI","AddL","AddHF","AddF","AddD", "AndI","AndL", - "MaxI","MinI","MaxF","MinF","MaxD","MinD", - "MulI","MulL","MulF","MulD", - "OrI","OrL", "XorI","XorL", + "MaxI","MinI","MaxHF","MinHF","MaxF","MinF","MaxD","MinD", + "MulI","MulL","MulHF","MulF","MulD", + "OrI","OrL", + "XorI","XorL" "UMax","UMin" }; static const char *commut_vector_op_list[] = { - "AddVB", "AddVS", "AddVI", "AddVL", "AddVF", "AddVD", - "MulVB", "MulVS", "MulVI", "MulVL", "MulVF", "MulVD", + "AddVB", "AddVS", "AddVI", "AddVL", "AddVHF", "AddVF", "AddVD", + "MulVB", "MulVS", "MulVI", "MulVL", "MulVHF", "MulVF", "MulVD", "AndV", "OrV", "XorV", - "MaxV", "MinV", "UMax","UMin" + "MaxVHF", "MinVHF", "MaxV", "MinV", "UMax","UMin" }; if (_lChild && _rChild && (_lChild->_lChild || _rChild->_lChild)) { @@ -4192,6 +4195,7 @@ int MatchRule::is_expensive() const { const char *opType = _rChild->_opType; if( strcmp(opType,"AtanD")==0 || strcmp(opType,"DivD")==0 || + strcmp(opType,"DivHF")==0 || strcmp(opType,"DivF")==0 || strcmp(opType,"DivI")==0 || strcmp(opType,"Log10D")==0 || @@ -4200,6 +4204,7 @@ int MatchRule::is_expensive() const { strcmp(opType,"ModI")==0 || strcmp(opType,"SqrtD")==0 || strcmp(opType,"SqrtF")==0 || + strcmp(opType,"SqrtHF")==0 || strcmp(opType,"TanD")==0 || strcmp(opType,"ConvD2F")==0 || strcmp(opType,"ConvD2I")==0 || @@ -4219,6 +4224,7 @@ int MatchRule::is_expensive() const { strcmp(opType,"DecodeNKlass")==0 || strcmp(opType,"FmaD") == 0 || strcmp(opType,"FmaF") == 0 || + strcmp(opType,"FmaHF") == 0 || strcmp(opType,"RoundDouble")==0 || strcmp(opType,"RoundDoubleMode")==0 || strcmp(opType,"RoundFloat")==0 || @@ -4331,15 +4337,15 @@ Form::DataType MatchRule::is_ideal_load() const { bool MatchRule::is_vector() const { static const char *vector_list[] = { - "AddVB","AddVS","AddVI","AddVL","AddVF","AddVD", - "SubVB","SubVS","SubVI","SubVL","SubVF","SubVD", - "MulVB","MulVS","MulVI","MulVL","MulVF","MulVD", - "DivVF","DivVD", + "AddVB","AddVHF", "AddVS","AddVI","AddVL","AddVF","AddVD", + "SubVB","SubVS","SubVI","SubVL", "SubVHF", "SubVF","SubVD", + "MulVB","MulVS","MulVI","MulVL", "MulVHF", "MulVF","MulVD", + "DivVHF","DivVF","DivVD", "AbsVB","AbsVS","AbsVI","AbsVL","AbsVF","AbsVD", "NegVF","NegVD","NegVI","NegVL", - "SqrtVD","SqrtVF", + "SqrtVD","SqrtVF","SqrtVHF", "AndV" ,"XorV" ,"OrV", - "MaxV", "MinV", "UMinV", "UMaxV", + "MaxV", "MinV", "MinVHF", "MaxVHF", "UMinV", "UMaxV", "CompressV", "ExpandV", "CompressM", "CompressBitsV", "ExpandBitsV", "AddReductionVI", "AddReductionVL", "AddReductionVF", "AddReductionVD", @@ -4360,8 +4366,8 @@ bool MatchRule::is_vector() const { "VectorCastB2X", "VectorCastS2X", "VectorCastI2X", "VectorCastL2X", "VectorCastF2X", "VectorCastD2X", "VectorCastF2HF", "VectorCastHF2F", "VectorUCastB2X", "VectorUCastS2X", "VectorUCastI2X", - "VectorMaskWrapper","VectorMaskCmp","VectorReinterpret","LoadVectorMasked","StoreVectorMasked", - "FmaVD","FmaVF","PopCountVI","PopCountVL","PopulateIndex","VectorLongToMask", + "VectorMaskWrapper", "VectorMaskCmp", "VectorReinterpret", "LoadVectorMasked", "StoreVectorMasked", + "FmaVD", "FmaVF", "FmaVHF", "PopCountVI", "PopCountVL", "PopulateIndex", "VectorLongToMask", "CountLeadingZerosV", "CountTrailingZerosV", "SignumVF", "SignumVD", "SaturatingAddV", "SaturatingSubV", // Next are vector mask ops. "MaskAll", "AndVMask", "OrVMask", "XorVMask", "VectorMaskCast", diff --git a/src/hotspot/share/adlc/output_c.cpp b/src/hotspot/share/adlc/output_c.cpp index cc6ed278b4901..f1f4398025513 100644 --- a/src/hotspot/share/adlc/output_c.cpp +++ b/src/hotspot/share/adlc/output_c.cpp @@ -2421,6 +2421,8 @@ class DefineEmitState { if( _constant_status == LITERAL_NOT_SEEN ) { if ( _constant_type == Form::idealD ) { fprintf(_fp,"->constantD()"); + } else if ( _constant_type == Form::idealH ) { + fprintf(_fp,"->constantH()"); } else if ( _constant_type == Form::idealF ) { fprintf(_fp,"->constantF()"); } else if ( _constant_type == Form::idealL ) { @@ -3789,6 +3791,8 @@ static void path_to_constant(FILE *fp, FormDict &globals, fprintf(fp, "_leaf->bottom_type()->is_narrowoop()"); } else if ( (strcmp(optype,"ConNKlass") == 0) ) { fprintf(fp, "_leaf->bottom_type()->is_narrowklass()"); + } else if ( (strcmp(optype,"ConH") == 0) ) { + fprintf(fp, "_leaf->geth()"); } else if ( (strcmp(optype,"ConF") == 0) ) { fprintf(fp, "_leaf->getf()"); } else if ( (strcmp(optype,"ConD") == 0) ) { diff --git a/src/hotspot/share/adlc/output_h.cpp b/src/hotspot/share/adlc/output_h.cpp index d6767bc1f7efc..486534875c570 100644 --- a/src/hotspot/share/adlc/output_h.cpp +++ b/src/hotspot/share/adlc/output_h.cpp @@ -233,6 +233,10 @@ static void declareConstStorage(FILE *fp, FormDict &globals, OperandForm *oper) if (i > 0) fprintf(fp,", "); fprintf(fp," jfloat _c%d;\n", i); } + else if (!strcmp(type, "ConH")) { + if (i > 0) fprintf(fp,", "); + fprintf(fp," jshort _c%d;\n", i); + } else if (!strcmp(type, "ConD")) { if (i > 0) fprintf(fp,", "); fprintf(fp," jdouble _c%d;\n", i); @@ -269,6 +273,10 @@ static void declareConstStorage(FILE *fp, FormDict &globals, OperandForm *oper) fprintf(fp," jlong _c%d;\n", i); i++; } + else if (!strcmp(comp->base_type(globals), "ConH")) { + fprintf(fp," jshort _c%d;\n", i); + i++; + } else if (!strcmp(comp->base_type(globals), "ConF")) { fprintf(fp," jfloat _c%d;\n", i); i++; @@ -314,6 +322,7 @@ static void defineConstructor(FILE *fp, const char *name, uint num_consts, case Form::idealNKlass : { fprintf(fp,"const TypeNarrowKlass *c%d", i); break; } case Form::idealP : { fprintf(fp,"const TypePtr *c%d", i); break; } case Form::idealL : { fprintf(fp,"jlong c%d", i); break; } + case Form::idealH : { fprintf(fp,"jshort c%d", i); break; } case Form::idealF : { fprintf(fp,"jfloat c%d", i); break; } case Form::idealD : { fprintf(fp,"jdouble c%d", i); break; } default: @@ -403,6 +412,11 @@ static uint dump_spec_constant(FILE *fp, const char *ideal_type, uint i, Operand fprintf(fp," st->print(\"/0x%%08x\", _c%d);\n", i); ++i; } + else if (!strcmp(ideal_type, "ConH")) { + fprintf(fp," st->print(\"#%%d\", _c%d);\n", i); + fprintf(fp," st->print(\"/0x%%08x\", _c%d);\n", i); + ++i; + } else if (!strcmp(ideal_type, "ConP")) { fprintf(fp," _c%d->dump_on(st);\n", i); ++i; @@ -1281,6 +1295,7 @@ void ArchDesc::declareClasses(FILE *fp) { case Form::idealF: type = "Type::FLOAT"; break; case Form::idealD: type = "Type::DOUBLE"; break; case Form::idealL: type = "TypeLong::LONG"; break; + case Form::idealH: type = "Type::HALF_FLOAT"; break; case Form::none: // fall through default: assert( false, "No support for this type of stackSlot"); @@ -1425,6 +1440,14 @@ void ArchDesc::declareClasses(FILE *fp) { fprintf(fp, " return _c0;"); fprintf(fp, " }\n"); } + else if (!strcmp(oper->ideal_type(_globalNames), "ConH")) { + fprintf(fp," virtual intptr_t constant() const {"); + fprintf(fp, " ShouldNotReachHere(); return 0; "); + fprintf(fp, " }\n"); + fprintf(fp," virtual jshort constantH() const {"); + fprintf(fp, " return (jshort)_c0;"); + fprintf(fp, " }\n"); + } else if (!strcmp(oper->ideal_type(_globalNames), "ConF")) { fprintf(fp," virtual intptr_t constant() const {"); fprintf(fp, " ShouldNotReachHere(); return 0; "); @@ -1897,6 +1920,9 @@ void ArchDesc::declareClasses(FILE *fp) { case Form::idealD: fprintf(fp," return TypeD::make(opnd_array(1)->constantD());\n"); break; + case Form::idealH: + fprintf(fp," return TypeH::make(opnd_array(1)->constantH());\n"); + break; case Form::idealF: fprintf(fp," return TypeF::make(opnd_array(1)->constantF());\n"); break; diff --git a/src/hotspot/share/classfile/vmIntrinsics.hpp b/src/hotspot/share/classfile/vmIntrinsics.hpp index 68121c56c328f..270d86c5f4a6b 100644 --- a/src/hotspot/share/classfile/vmIntrinsics.hpp +++ b/src/hotspot/share/classfile/vmIntrinsics.hpp @@ -926,19 +926,26 @@ class methodHandle; do_signature(getAndAddShort_signature, "(Ljava/lang/Object;JS)S" ) \ do_intrinsic(_getAndSetInt, jdk_internal_misc_Unsafe, getAndSetInt_name, getAndSetInt_signature, F_R) \ do_name( getAndSetInt_name, "getAndSetInt") \ - do_alias( getAndSetInt_signature, /*"(Ljava/lang/Object;JI)I"*/ getAndAddInt_signature) \ + do_alias( getAndSetInt_signature, /*"(Ljava/lang/Object;JI)I"*/ getAndAddInt_signature) \ do_intrinsic(_getAndSetLong, jdk_internal_misc_Unsafe, getAndSetLong_name, getAndSetLong_signature, F_R) \ do_name( getAndSetLong_name, "getAndSetLong") \ - do_alias( getAndSetLong_signature, /*"(Ljava/lang/Object;JJ)J"*/ getAndAddLong_signature) \ + do_alias( getAndSetLong_signature, /*"(Ljava/lang/Object;JJ)J"*/ getAndAddLong_signature)\ do_intrinsic(_getAndSetByte, jdk_internal_misc_Unsafe, getAndSetByte_name, getAndSetByte_signature, F_R) \ do_name( getAndSetByte_name, "getAndSetByte") \ - do_alias( getAndSetByte_signature, /*"(Ljava/lang/Object;JB)B"*/ getAndAddByte_signature) \ + do_alias( getAndSetByte_signature, /*"(Ljava/lang/Object;JB)B"*/ getAndAddByte_signature)\ do_intrinsic(_getAndSetShort, jdk_internal_misc_Unsafe, getAndSetShort_name, getAndSetShort_signature, F_R) \ - do_name( getAndSetShort_name, "getAndSetShort") \ - do_alias( getAndSetShort_signature, /*"(Ljava/lang/Object;JS)S"*/ getAndAddShort_signature) \ - do_intrinsic(_getAndSetReference, jdk_internal_misc_Unsafe, getAndSetReference_name, getAndSetReference_signature, F_R) \ - do_name( getAndSetReference_name, "getAndSetReference") \ + do_name( getAndSetShort_name, "getAndSetShort") \ + do_alias( getAndSetShort_signature, /*"(Ljava/lang/Object;JS)S"*/ getAndAddShort_signature) \ + do_intrinsic(_getAndSetReference, jdk_internal_misc_Unsafe, getAndSetReference_name, getAndSetReference_signature, F_R) \ + do_name( getAndSetReference_name, "getAndSetReference") \ do_signature(getAndSetReference_signature, "(Ljava/lang/Object;JLjava/lang/Object;)Ljava/lang/Object;" ) \ + \ + /* Float16Math API intrinsification support */ \ + /* Float16 signatures */ \ + do_signature(float16_unary_math_op_sig, "(SLjdk/internal/vm/vector/Float16Math$Float16UnaryMathOp;)S") \ + do_signature(float16_ternary_math_op_sig, "(SSSLjdk/internal/vm/vector/Float16Math$Float16TernaryMathOp;)S") \ + do_intrinsic(_sqrt_float16, jdk_internal_vm_vector_Float16Math, sqrt_name, float16_unary_math_op_sig, F_S) \ + do_intrinsic(_fma_float16, jdk_internal_vm_vector_Float16Math, fma_name, float16_ternary_math_op_sig, F_S) \ \ /* Vector API intrinsification support */ \ \ diff --git a/src/hotspot/share/classfile/vmSymbols.hpp b/src/hotspot/share/classfile/vmSymbols.hpp index 46c156a54452a..799ca76a863a1 100644 --- a/src/hotspot/share/classfile/vmSymbols.hpp +++ b/src/hotspot/share/classfile/vmSymbols.hpp @@ -91,7 +91,8 @@ class SerializeClosure; template(java_lang_Long_LongCache, "java/lang/Long$LongCache") \ template(java_lang_Void, "java/lang/Void") \ \ - template(jdk_internal_vm_vector_VectorSupport, "jdk/internal/vm/vector/VectorSupport") \ + template(jdk_internal_vm_vector_VectorSupport, "jdk/internal/vm/vector/VectorSupport") \ + template(jdk_internal_vm_vector_Float16Math, "jdk/internal/vm/vector/Float16Math") \ template(jdk_internal_vm_vector_VectorPayload, "jdk/internal/vm/vector/VectorSupport$VectorPayload") \ template(jdk_internal_vm_vector_Vector, "jdk/internal/vm/vector/VectorSupport$Vector") \ template(jdk_internal_vm_vector_VectorMask, "jdk/internal/vm/vector/VectorSupport$VectorMask") \ diff --git a/src/hotspot/share/opto/addnode.cpp b/src/hotspot/share/opto/addnode.cpp index 4215124b0d472..40a43a0ee738b 100644 --- a/src/hotspot/share/opto/addnode.cpp +++ b/src/hotspot/share/opto/addnode.cpp @@ -33,6 +33,7 @@ #include "opto/mulnode.hpp" #include "opto/phaseX.hpp" #include "opto/subnode.hpp" +#include "runtime/stubRoutines.hpp" // Portions of code courtesy of Clifford Click @@ -706,6 +707,22 @@ Node *AddFNode::Ideal(PhaseGVN *phase, bool can_reshape) { return commute(phase, this) ? this : nullptr; } +//============================================================================= +//------------------------------add_of_identity-------------------------------- +// Check for addition of the identity +const Type *AddHFNode::add_of_identity(const Type* t1, const Type* t2) const { + return nullptr; +} + +// Supplied function returns the sum of the inputs. +// This also type-checks the inputs for sanity. Guaranteed never to +// be passed a TOP or BOTTOM type, these are filtered out by pre-check. +const Type* AddHFNode::add_ring(const Type* t0, const Type* t1) const { + if (!t0->isa_half_float_constant() || !t1->isa_half_float_constant()) { + return bottom_type(); + } + return TypeH::make(t0->getf() + t1->getf()); +} //============================================================================= //------------------------------add_of_identity-------------------------------- @@ -1582,6 +1599,33 @@ Node* MaxNode::Identity(PhaseGVN* phase) { return AddNode::Identity(phase); } +//------------------------------add_ring--------------------------------------- +const Type* MinHFNode::add_ring(const Type* t0, const Type* t1) const { + const TypeH* r0 = t0->isa_half_float_constant(); + const TypeH* r1 = t1->isa_half_float_constant(); + if (r0 == nullptr || r1 == nullptr) { + return bottom_type(); + } + + if (r0->is_nan()) { + return r0; + } + if (r1->is_nan()) { + return r1; + } + + float f0 = r0->getf(); + float f1 = r1->getf(); + if (f0 != 0.0f || f1 != 0.0f) { + return f0 < f1 ? r0 : r1; + } + + // As per IEEE 754 specification, floating point comparison consider +ve and -ve + // zeros as equals. Thus, performing signed integral comparison for max value + // detection. + return (jint_cast(f0) < jint_cast(f1)) ? r0 : r1; +} + //------------------------------add_ring--------------------------------------- const Type* MinFNode::add_ring(const Type* t0, const Type* t1 ) const { const TypeF* r0 = t0->isa_float_constant(); @@ -1603,7 +1647,6 @@ const Type* MinFNode::add_ring(const Type* t0, const Type* t1 ) const { return f0 < f1 ? r0 : r1; } - // handle min of 0.0, -0.0 case. return (jint_cast(f0) < jint_cast(f1)) ? r0 : r1; } @@ -1632,6 +1675,34 @@ const Type* MinDNode::add_ring(const Type* t0, const Type* t1) const { return (jlong_cast(d0) < jlong_cast(d1)) ? r0 : r1; } +//------------------------------add_ring--------------------------------------- +const Type* MaxHFNode::add_ring(const Type* t0, const Type* t1) const { + const TypeH* r0 = t0->isa_half_float_constant(); + const TypeH* r1 = t1->isa_half_float_constant(); + if (r0 == nullptr || r1 == nullptr) { + return bottom_type(); + } + + if (r0->is_nan()) { + return r0; + } + if (r1->is_nan()) { + return r1; + } + + float f0 = r0->getf(); + float f1 = r1->getf(); + if (f0 != 0.0f || f1 != 0.0f) { + return f0 > f1 ? r0 : r1; + } + + // As per IEEE 754 specification, floating point comparison consider +ve and -ve + // zeros as equals. Thus, performing signed integral comparison for min value + // detection. + return (jint_cast(f0) > jint_cast(f1)) ? r0 : r1; +} + + //------------------------------add_ring--------------------------------------- const Type* MaxFNode::add_ring(const Type* t0, const Type* t1) const { const TypeF* r0 = t0->isa_float_constant(); diff --git a/src/hotspot/share/opto/addnode.hpp b/src/hotspot/share/opto/addnode.hpp index 8afbb440572bf..b7626fbc89977 100644 --- a/src/hotspot/share/opto/addnode.hpp +++ b/src/hotspot/share/opto/addnode.hpp @@ -162,6 +162,22 @@ class AddDNode : public AddNode { virtual uint ideal_reg() const { return Op_RegD; } }; +//------------------------------AddHFNode--------------------------------------- +// Add 2 half-precision floats +class AddHFNode : public AddNode { +public: + AddHFNode(Node* in1, Node* in2) : AddNode(in1,in2) {} + virtual int Opcode() const; + virtual const Type* add_of_identity(const Type* t1, const Type* t2) const; + virtual const Type* add_ring(const Type*, const Type*) const; + virtual const Type* add_id() const { return TypeH::ZERO; } + virtual const Type* bottom_type() const { return Type::HALF_FLOAT; } + int max_opcode() const { return Op_MaxHF; } + int min_opcode() const { return Op_MinHF; } + virtual Node* Identity(PhaseGVN* phase) { return this; } + virtual uint ideal_reg() const { return Op_RegF; } +}; + //------------------------------AddPNode--------------------------------------- // Add pointer plus integer to get pointer. NOT commutative, really. // So not really an AddNode. Lives here, because people associate it with @@ -396,16 +412,44 @@ class MaxFNode : public MaxNode { // Minimum of 2 floats. class MinFNode : public MaxNode { public: - MinFNode(Node *in1, Node *in2) : MaxNode(in1, in2) {} + MinFNode(Node* in1, Node* in2) : MaxNode(in1, in2) {} virtual int Opcode() const; - virtual const Type *add_ring(const Type*, const Type*) const; - virtual const Type *add_id() const { return TypeF::POS_INF; } - virtual const Type *bottom_type() const { return Type::FLOAT; } + virtual const Type* add_ring(const Type*, const Type*) const; + virtual const Type* add_id() const { return TypeF::POS_INF; } + virtual const Type* bottom_type() const { return Type::FLOAT; } virtual uint ideal_reg() const { return Op_RegF; } int max_opcode() const { return Op_MaxF; } int min_opcode() const { return Op_MinF; } }; +//------------------------------MaxHFNode-------------------------------------- +// Maximum of 2 half floats. +class MaxHFNode : public MaxNode { +public: + MaxHFNode(Node* in1, Node* in2) : MaxNode(in1, in2) {} + virtual int Opcode() const; + virtual const Type* add_ring(const Type*, const Type*) const; + virtual const Type* add_id() const { return TypeH::NEG_INF; } + virtual const Type* bottom_type() const { return Type::HALF_FLOAT; } + virtual uint ideal_reg() const { return Op_RegF; } + int max_opcode() const { return Op_MaxHF; } + int min_opcode() const { return Op_MinHF; } +}; + +//------------------------------MinHFNode--------------------------------------- +// Minimum of 2 half floats. +class MinHFNode : public MaxNode { +public: + MinHFNode(Node* in1, Node* in2) : MaxNode(in1, in2) {} + virtual int Opcode() const; + virtual const Type* add_ring(const Type*, const Type*) const; + virtual const Type* add_id() const { return TypeH::POS_INF; } + virtual const Type* bottom_type() const { return Type::HALF_FLOAT; } + virtual uint ideal_reg() const { return Op_RegF; } + int max_opcode() const { return Op_MaxHF; } + int min_opcode() const { return Op_MinHF; } +}; + //------------------------------MaxDNode--------------------------------------- // Maximum of 2 doubles. class MaxDNode : public MaxNode { diff --git a/src/hotspot/share/opto/c2compiler.cpp b/src/hotspot/share/opto/c2compiler.cpp index 138fd38bfa4eb..7b448ce725111 100644 --- a/src/hotspot/share/opto/c2compiler.cpp +++ b/src/hotspot/share/opto/c2compiler.cpp @@ -353,6 +353,12 @@ bool C2Compiler::is_intrinsic_supported(vmIntrinsics::ID id) { case vmIntrinsics::_floatToFloat16: if (!Matcher::match_rule_supported(Op_ConvF2HF)) return false; break; + case vmIntrinsics::_sqrt_float16: + if (!Matcher::match_rule_supported(Op_SqrtHF)) return false; + break; + case vmIntrinsics::_fma_float16: + if (!Matcher::match_rule_supported(Op_FmaHF)) return false; + break; /* CompareAndSet, Object: */ case vmIntrinsics::_compareAndSetReference: diff --git a/src/hotspot/share/opto/castnode.cpp b/src/hotspot/share/opto/castnode.cpp index 63ed54df4fd9d..77e7fa88175b0 100644 --- a/src/hotspot/share/opto/castnode.cpp +++ b/src/hotspot/share/opto/castnode.cpp @@ -458,6 +458,8 @@ Node* ConstraintCastNode::make_cast_for_type(Node* c, Node* in, const Type* type return new CastIINode(c, in, type, dependency, false, types); } else if (type->isa_long()) { return new CastLLNode(c, in, type, dependency, types); + } else if (type->isa_half_float()) { + return new CastHHNode(c, in, type, dependency, types); } else if (type->isa_float()) { return new CastFFNode(c, in, type, dependency, types); } else if (type->isa_double()) { diff --git a/src/hotspot/share/opto/castnode.hpp b/src/hotspot/share/opto/castnode.hpp index c193d406f939f..0f62bb0164d48 100644 --- a/src/hotspot/share/opto/castnode.hpp +++ b/src/hotspot/share/opto/castnode.hpp @@ -141,6 +141,17 @@ class CastLLNode: public ConstraintCastNode { virtual uint ideal_reg() const { return Op_RegL; } }; +class CastHHNode: public ConstraintCastNode { +public: + CastHHNode(Node* ctrl, Node* n, const Type* t, DependencyType dependency = RegularDependency, const TypeTuple* types = nullptr) + : ConstraintCastNode(ctrl, n, t, dependency, types) { + assert(ctrl != nullptr, "control must be set"); + init_class_id(Class_CastHH); + } + virtual int Opcode() const; + virtual uint ideal_reg() const { return in(1)->ideal_reg(); } +}; + class CastFFNode: public ConstraintCastNode { public: CastFFNode(Node* ctrl, Node* n, const Type* t, DependencyType dependency = RegularDependency, const TypeTuple* types = nullptr) diff --git a/src/hotspot/share/opto/classes.hpp b/src/hotspot/share/opto/classes.hpp index 60ee3e01137b0..790d0a9f863a5 100644 --- a/src/hotspot/share/opto/classes.hpp +++ b/src/hotspot/share/opto/classes.hpp @@ -36,6 +36,7 @@ macro(AddF) macro(AddI) macro(AddL) macro(AddP) +macro(AddHF) macro(Allocate) macro(AllocateArray) macro(AndI) @@ -64,6 +65,7 @@ macro(CallLeafVector) macro(CallRuntime) macro(CallStaticJava) macro(CastDD) +macro(CastHH) macro(CastFF) macro(CastII) macro(CastLL) @@ -132,6 +134,7 @@ macro(Con) macro(ConN) macro(ConNKlass) macro(ConD) +macro(ConH) macro(ConF) macro(ConI) macro(ConL) @@ -166,6 +169,7 @@ macro(CountTrailingZerosV) macro(CreateEx) macro(DecodeN) macro(DecodeNKlass) +macro(DivHF) macro(DivD) macro(DivF) macro(DivI) @@ -184,6 +188,7 @@ macro(FastLock) macro(FastUnlock) macro(FmaD) macro(FmaF) +macro(FmaHF) macro(ForwardException) macro(Goto) macro(Halt) @@ -222,6 +227,7 @@ macro(MachProj) macro(MulAddS2I) macro(MaxI) macro(MaxL) +macro(MaxHF) macro(MaxD) macro(MaxF) macro(MemBarAcquire) @@ -237,6 +243,7 @@ macro(MemBarStoreStore) macro(MergeMem) macro(MinI) macro(MinL) +macro(MinHF) macro(MinF) macro(MinD) macro(ModD) @@ -253,6 +260,7 @@ macro(IsInfiniteF) macro(IsFiniteF) macro(IsInfiniteD) macro(IsFiniteD) +macro(MulHF) macro(MulD) macro(MulF) macro(MulHiL) @@ -338,6 +346,7 @@ macro(SignumVF) macro(SignumVD) macro(SqrtD) macro(SqrtF) +macro(SqrtHF) macro(RoundF) macro(RoundD) macro(Start) @@ -357,6 +366,7 @@ macro(StrEquals) macro(StrIndexOf) macro(StrIndexOfChar) macro(StrInflatedCopy) +macro(SubHF) macro(SubD) macro(SubF) macro(SubI) @@ -379,6 +389,7 @@ macro(AddVI) macro(AddReductionVI) macro(AddVL) macro(AddReductionVL) +macro(AddVHF) macro(AddVF) macro(AddReductionVF) macro(AddVD) @@ -387,6 +398,7 @@ macro(SubVB) macro(SubVS) macro(SubVI) macro(SubVL) +macro(SubVHF) macro(SubVF) macro(SubVD) macro(MulVB) @@ -395,6 +407,7 @@ macro(MulVI) macro(MulReductionVI) macro(MulVL) macro(MulReductionVL) +macro(MulVHF) macro(MulVF) macro(MulReductionVF) macro(MulVD) @@ -402,6 +415,8 @@ macro(MulReductionVD) macro(MulAddVS2VI) macro(FmaVD) macro(FmaVF) +macro(FmaVHF) +macro(DivVHF) macro(DivVF) macro(DivVD) macro(AbsVB) @@ -416,6 +431,7 @@ macro(NegVF) macro(NegVD) macro(SqrtVD) macro(SqrtVF) +macro(SqrtVHF) macro(LShiftCntV) macro(RShiftCntV) macro(LShiftVB) @@ -436,6 +452,8 @@ macro(OrV) macro(OrReductionV) macro(XorV) macro(XorReductionV) +macro(MinVHF) +macro(MaxVHF) macro(MinV) macro(MaxV) macro(UMinV) @@ -485,6 +503,8 @@ macro(ExtractF) macro(ExtractD) macro(Digit) macro(LowerCase) +macro(ReinterpretS2HF) +macro(ReinterpretHF2S) macro(UpperCase) macro(Whitespace) macro(SelectFromTwoVector) diff --git a/src/hotspot/share/opto/connode.cpp b/src/hotspot/share/opto/connode.cpp index 00049c6f19e36..988f2cfc05f34 100644 --- a/src/hotspot/share/opto/connode.cpp +++ b/src/hotspot/share/opto/connode.cpp @@ -1,5 +1,5 @@ /* - * Copyright (c) 1997, 2023, Oracle and/or its affiliates. All rights reserved. + * Copyright (c) 1997, 2024, Oracle and/or its affiliates. All rights reserved. * DO NOT ALTER OR REMOVE COPYRIGHT NOTICES OR THIS FILE HEADER. * * This code is free software; you can redistribute it and/or modify it @@ -44,6 +44,9 @@ uint ConNode::hash() const { //------------------------------make------------------------------------------- ConNode *ConNode::make(const Type *t) { + if (t->isa_half_float_constant()) { + return new ConHNode( t->is_half_float_constant() ); + } switch( t->basic_type() ) { case T_INT: return new ConINode( t->is_int() ); case T_LONG: return new ConLNode( t->is_long() ); diff --git a/src/hotspot/share/opto/connode.hpp b/src/hotspot/share/opto/connode.hpp index 618326ec52711..77be0c2ee759d 100644 --- a/src/hotspot/share/opto/connode.hpp +++ b/src/hotspot/share/opto/connode.hpp @@ -1,5 +1,5 @@ /* - * Copyright (c) 1997, 2023, Oracle and/or its affiliates. All rights reserved. + * Copyright (c) 1997, 2024, Oracle and/or its affiliates. All rights reserved. * DO NOT ALTER OR REMOVE COPYRIGHT NOTICES OR THIS FILE HEADER. * * This code is free software; you can redistribute it and/or modify it @@ -115,6 +115,19 @@ class ConLNode : public ConNode { }; +//------------------------------ConHNode--------------------------------------- +// Simple half float constants +class ConHNode : public ConNode { +public: + ConHNode(const TypeH* t) : ConNode(t) {} + virtual int Opcode() const; + + // Factory method: + static ConHNode* make(float con) { + return new ConHNode(TypeH::make(con)); + } +}; + //------------------------------ConFNode--------------------------------------- // Simple float constants class ConFNode : public ConNode { diff --git a/src/hotspot/share/opto/constantTable.cpp b/src/hotspot/share/opto/constantTable.cpp index 345af51839d4a..9c19c32bffb1b 100644 --- a/src/hotspot/share/opto/constantTable.cpp +++ b/src/hotspot/share/opto/constantTable.cpp @@ -63,6 +63,7 @@ bool ConstantTable::Constant::operator==(const Constant& other) { } // For floating point values we compare the bit pattern. switch (type()) { + case T_SHORT: return (_v._value.i == other._v._value.i); case T_INT: return (_v._value.i == other._v._value.i); case T_FLOAT: return jint_cast(_v._value.f) == jint_cast(other._v._value.f); case T_LONG: return (_v._value.j == other._v._value.j); @@ -87,6 +88,7 @@ static int constant_size(ConstantTable::Constant* con) { return type2aelembytes(con->type()) * con->get_array()->length(); } switch (con->type()) { + case T_SHORT: return sizeof(jint ); case T_INT: return sizeof(jint ); case T_LONG: return sizeof(jlong ); case T_FLOAT: return sizeof(jfloat ); @@ -152,6 +154,7 @@ bool ConstantTable::emit(C2_MacroAssembler* masm) const { constant_addr = masm->array_constant(con.type(), con.get_array(), con.alignment()); } else { switch (con.type()) { + case T_SHORT: constant_addr = masm->int_constant( con.get_jint() ); break; case T_INT: constant_addr = masm->int_constant( con.get_jint() ); break; case T_LONG: constant_addr = masm->long_constant( con.get_jlong() ); break; case T_FLOAT: constant_addr = masm->float_constant( con.get_jfloat() ); break; @@ -268,6 +271,7 @@ ConstantTable::Constant ConstantTable::add(MachConstantNode* n, MachOper* oper) BasicType type = oper->type()->basic_type(); switch (type) { case T_LONG: value.j = oper->constantL(); break; + case T_SHORT: value.i = oper->constantH(); break; case T_INT: value.i = oper->constant(); break; case T_FLOAT: value.f = oper->constantF(); break; case T_DOUBLE: value.d = oper->constantD(); break; diff --git a/src/hotspot/share/opto/convertnode.cpp b/src/hotspot/share/opto/convertnode.cpp index 0a2131782a237..a0cdc41bdad6e 100644 --- a/src/hotspot/share/opto/convertnode.cpp +++ b/src/hotspot/share/opto/convertnode.cpp @@ -1,5 +1,5 @@ /* - * Copyright (c) 2014, 2023, Oracle and/or its affiliates. All rights reserved. + * Copyright (c) 2014, 2024, Oracle and/or its affiliates. All rights reserved. * DO NOT ALTER OR REMOVE COPYRIGHT NOTICES OR THIS FILE HEADER. * * This code is free software; you can redistribute it and/or modify it @@ -27,8 +27,10 @@ #include "opto/castnode.hpp" #include "opto/connode.hpp" #include "opto/convertnode.hpp" +#include "opto/divnode.hpp" #include "opto/matcher.hpp" #include "opto/movenode.hpp" +#include "opto/mulnode.hpp" #include "opto/phaseX.hpp" #include "opto/subnode.hpp" #include "runtime/stubRoutines.hpp" @@ -249,6 +251,37 @@ const Type* ConvF2HFNode::Value(PhaseGVN* phase) const { return TypeInt::make( StubRoutines::f2hf(tf->getf()) ); } +//------------------------------Ideal------------------------------------------ +Node* ConvF2HFNode::Ideal(PhaseGVN* phase, bool can_reshape) { + // Float16 instance encapsulates a short field holding IEEE 754 + // binary16 value. On unboxing, this short field is loaded into a + // GPR register while FP operation operates over floating point + // registers. ConvHF2F converts incoming short value to a FP32 value + // to perform operation at FP32 granularity. However, if target + // support FP16 ISA we can save this redundant up casting and + // optimize the graph pallet using following transformation. + // + // ConvF2HF(FP32BinOp(ConvHF2F(x), ConvHF2F(y))) => + // ReinterpretHF2S(FP16BinOp(ReinterpretS2HF(x), ReinterpretS2HF(y))) + // + // Please note we need to inject appropriate reinterpretation + // IR to move the values b/w GPR and floating point register + // before and after FP16 operation. + + if (Float16NodeFactory::is_float32_binary_oper(in(1)->Opcode()) && + in(1)->in(1)->Opcode() == Op_ConvHF2F && + in(1)->in(2)->Opcode() == Op_ConvHF2F) { + if (Matcher::match_rule_supported(Float16NodeFactory::get_float16_binary_oper(in(1)->Opcode())) && + Matcher::match_rule_supported(Op_ReinterpretS2HF) && + Matcher::match_rule_supported(Op_ReinterpretHF2S)) { + Node* in1 = phase->transform(new ReinterpretS2HFNode(in(1)->in(1)->in(1))); + Node* in2 = phase->transform(new ReinterpretS2HFNode(in(1)->in(2)->in(1))); + Node* binop = phase->transform(Float16NodeFactory::make(in(1)->Opcode(), in(1)->in(0), in1, in2)); + return new ReinterpretHF2SNode(binop); + } + } + return nullptr; +} //============================================================================= //------------------------------Value------------------------------------------ const Type* ConvF2INode::Value(PhaseGVN* phase) const { @@ -897,3 +930,76 @@ const Type* RoundDoubleModeNode::Value(PhaseGVN* phase) const { return Type::DOUBLE; } //============================================================================= + +const Type* ReinterpretS2HFNode::Value(PhaseGVN* phase) const { + const Type* type = phase->type(in(1)); + // Convert short constant value to a Half Float constant value + if ((type->isa_int() && type->is_int()->is_con())) { + jshort hfval = type->is_int()->get_con(); + return TypeH::make(hfval); + } + return Type::HALF_FLOAT; +} + +Node* ReinterpretS2HFNode::Identity(PhaseGVN* phase) { + if (in(1)->Opcode() == Op_ReinterpretHF2S) { + assert(in(1)->in(1)->bottom_type()->isa_half_float(), ""); + return in(1)->in(1); + } + return this; +} + +const Type* ReinterpretHF2SNode::Value(PhaseGVN* phase) const { + const Type* type = phase->type(in(1)); + // Convert Half float constant value to short constant value. + if (type->isa_half_float_constant()) { + jshort hfval = type->is_half_float_constant()->_f; + return TypeInt::make(hfval); + } + return TypeInt::SHORT; +} + +bool Float16NodeFactory::is_float32_binary_oper(int opc) { + switch(opc) { + case Op_AddF: + case Op_SubF: + case Op_MulF: + case Op_DivF: + case Op_MaxF: + case Op_MinF: + return true; + default: + return false; + } +} + +int Float16NodeFactory::get_float16_binary_oper(int opc) { + switch(opc) { + case Op_AddF: + return Op_AddHF; + case Op_SubF: + return Op_SubHF; + case Op_MulF: + return Op_MulHF; + case Op_DivF: + return Op_DivHF; + case Op_MaxF: + return Op_MaxHF; + case Op_MinF: + return Op_MinHF; + default: + return false; + } +} + +Node* Float16NodeFactory::make(int opc, Node* c, Node* in1, Node* in2) { + switch(opc) { + case Op_AddF: return new AddHFNode(in1, in2); + case Op_SubF: return new SubHFNode(in1, in2); + case Op_MulF: return new MulHFNode(in1, in2); + case Op_DivF: return new DivHFNode(c, in1, in2); + case Op_MaxF: return new MaxHFNode(in1, in2); + case Op_MinF: return new MinHFNode(in1, in2); + default: ShouldNotReachHere(); + } +} diff --git a/src/hotspot/share/opto/convertnode.hpp b/src/hotspot/share/opto/convertnode.hpp index 9438176a9f996..80d4acb96fff2 100644 --- a/src/hotspot/share/opto/convertnode.hpp +++ b/src/hotspot/share/opto/convertnode.hpp @@ -112,6 +112,7 @@ class ConvF2HFNode : public ConvertNode { virtual int Opcode() const; virtual const Type* in_type() const { return TypeInt::FLOAT; } virtual const Type* Value(PhaseGVN* phase) const; + virtual Node* Ideal(PhaseGVN* phase, bool can_reshape); }; //------------------------------ConvF2INode------------------------------------ @@ -213,6 +214,30 @@ class ConvL2INode : public ConvertNode { virtual Node* Ideal(PhaseGVN* phase, bool can_reshape); }; + +//-----------------------------ReinterpretS2HFNode --------------------------- +// Reinterpret Short to Half Float +class ReinterpretS2HFNode : public Node { + public: + ReinterpretS2HFNode(Node* in1) : Node(0, in1) {} + virtual int Opcode() const; + virtual const Type* bottom_type() const { return Type::HALF_FLOAT; } + virtual const Type* Value(PhaseGVN* phase) const; + virtual Node* Identity(PhaseGVN* phase); + virtual uint ideal_reg() const { return Op_RegF; } +}; + +//-----------------------------ReinterpretS2HFNode --------------------------- +// Reinterpret Half Float to Short +class ReinterpretHF2SNode : public Node { + public: + ReinterpretHF2SNode(Node* in1) : Node(0,in1) {} + virtual int Opcode() const; + virtual const Type* Value(PhaseGVN* phase) const; + virtual const Type* bottom_type() const { return TypeInt::SHORT; } + virtual uint ideal_reg() const { return Op_RegI; } +}; + class RoundDNode : public Node { public: RoundDNode(Node* in1) : Node(nullptr, in1) {} @@ -269,5 +294,11 @@ class RoundDoubleModeNode: public Node { virtual const Type* Value(PhaseGVN* phase) const; }; +class Float16NodeFactory { + public: + static bool is_float32_binary_oper(int opc); + static int get_float16_binary_oper(int opc); + static Node* make(int opc, Node* c, Node* in1, Node* in2); +}; #endif // SHARE_OPTO_CONVERTNODE_HPP diff --git a/src/hotspot/share/opto/divnode.cpp b/src/hotspot/share/opto/divnode.cpp index bf765578cc464..deefe17e5157b 100644 --- a/src/hotspot/share/opto/divnode.cpp +++ b/src/hotspot/share/opto/divnode.cpp @@ -747,6 +747,115 @@ Node *DivFNode::Ideal(PhaseGVN *phase, bool can_reshape) { // return multiplication by the reciprocal return (new MulFNode(in(1), phase->makecon(TypeF::make(reciprocal)))); } +//============================================================================= +//------------------------------Value------------------------------------------ +// An DivHFNode divides its inputs. The third input is a Control input, used to +// prevent hoisting the divide above an unsafe test. +const Type* DivHFNode::Value(PhaseGVN* phase) const { + // Either input is TOP ==> the result is TOP + const Type* t1 = phase->type(in(1)); + const Type* t2 = phase->type(in(2)); + if(t1 == Type::TOP) { return Type::TOP; } + if(t2 == Type::TOP) { return Type::TOP; } + + // Either input is BOTTOM ==> the result is the local BOTTOM + const Type* bot = bottom_type(); + if((t1 == bot) || (t2 == bot) || + (t1 == Type::BOTTOM) || (t2 == Type::BOTTOM)) { + return bot; + } + + // x/x == 1, we ignore 0/0. + // Note: if t1 and t2 are zero then result is NaN (JVMS page 213) + // Does not work for variables because of NaN's + if (in(1) == in(2) && t1->base() == Type::HalfFloatCon && + !g_isnan(t1->getf()) && g_isfinite(t1->getf()) && t1->getf() != 0.0) { // could be negative ZERO or NaN + return TypeH::ONE; + } + + if (t2 == TypeH::ONE) { + return t1; + } + + // If divisor is a constant and not zero, divide them numbers + if (t1->base() == Type::HalfFloatCon && + t2->base() == Type::HalfFloatCon && + t2->getf() != 0.0) { + // could be negative zero + return TypeH::make(t1->getf()/t2->getf()); + } + + // If the dividend is a constant zero + // Note: if t1 and t2 are zero then result is NaN (JVMS page 213) + // Test TypeF::ZERO is not sufficient as it could be negative zero + + if (t1 == TypeH::ZERO && !g_isnan(t2->getf()) && t2->getf() != 0.0) { + return TypeH::ZERO; + } + + // If divisor or dividend is nan then result is nan. + if (g_isnan(t1->getf()) || g_isnan(t2->getf())) { + return TypeH::make(NAN); + } + + // Otherwise we give up all hope + return Type::HALF_FLOAT; +} + +//------------------------------isA_Copy--------------------------------------- +// Dividing by self is 1. +// IF the divisor is 1, we are an identity on the dividend. +Node* DivHFNode::Identity(PhaseGVN* phase) { + return (phase->type( in(2) ) == TypeH::ONE) ? in(1) : this; +} + + +//------------------------------Idealize--------------------------------------- +Node* DivHFNode::Ideal(PhaseGVN* phase, bool can_reshape) { + if (in(0) != nullptr && remove_dead_region(phase, can_reshape)) return this; + // Don't bother trying to transform a dead node + if (in(0) != nullptr && in(0)->is_top()) { return nullptr; } + + const Type* t2 = phase->type(in(2)); + if (t2 == TypeH::ONE) { // Identity? + return nullptr; // Skip it + } + const TypeH* tf = t2->isa_half_float_constant(); + if(tf == nullptr) { return nullptr; } + if(tf->base() != Type::HalfFloatCon) { return nullptr; } + + // Check for out of range values + if(tf->is_nan() || !tf->is_finite()) { return nullptr; } + + // Get the value + float f = tf->getf(); + int exp; + + // Consider the following geometric progression series of POT(power of two) numbers. + // 0.5 x 2^0 = 0.5, 0.5 x 2^1 = 1.0, 0.5 x 2^2 = 2.0, 0.5 x 2^3 = 4.0 ... 0.5 x 2^n, + // In all the above cases, normalized mantissa returned by frexp routine will + // be exactly equal to 0.5 while exponent will be 0,1,2,3...n + // Perform division to multiplication transform only if divisor is a POT value. + if(frexp((double)f, &exp) != 0.5) { return nullptr; } + + // Limit the range of acceptable exponents + if(exp < -14 || exp > 15) { return nullptr; } + + // Since divisor is a POT number, hence its reciprocal will never + // overflow 11 bits precision range of Float16 + // value if exponent returned by frexp routine strictly lie + // within the exponent range of normal min(0x1.0P-14) and + // normal max(0x1.ffcP+15) values. + // Thus we can safely compute the reciprocal of divisor without + // any concerns about the precision loss and transform the division + // into a multiplication operation. + float reciprocal = ((float)1.0) / f; + + assert(frexp((double)reciprocal, &exp) == 0.5, "reciprocal should be power of 2"); + + // return multiplication by the reciprocal + return (new MulHFNode(in(1), phase->makecon(TypeH::make(reciprocal)))); +} //============================================================================= //------------------------------Value------------------------------------------ diff --git a/src/hotspot/share/opto/divnode.hpp b/src/hotspot/share/opto/divnode.hpp index 9c8ba2561ee5d..9d82b43fa9a81 100644 --- a/src/hotspot/share/opto/divnode.hpp +++ b/src/hotspot/share/opto/divnode.hpp @@ -77,6 +77,20 @@ class DivFNode : public Node { virtual uint ideal_reg() const { return Op_RegF; } }; + +//------------------------------DivHFNode-------------------------------------- +// Half float division +class DivHFNode : public Node { +public: + DivHFNode(Node* c, Node* dividend, Node* divisor) : Node(c, dividend, divisor) {} + virtual int Opcode() const; + virtual Node* Identity(PhaseGVN* phase); + virtual Node* Ideal(PhaseGVN* phase, bool can_reshape); + virtual const Type* Value(PhaseGVN* phase) const; + virtual const Type* bottom_type() const { return Type::HALF_FLOAT; } + virtual uint ideal_reg() const { return Op_RegF; } +}; + //------------------------------DivDNode--------------------------------------- // Double division class DivDNode : public Node { diff --git a/src/hotspot/share/opto/library_call.cpp b/src/hotspot/share/opto/library_call.cpp index 1072a5d6a24a5..09725eb3c02e6 100644 --- a/src/hotspot/share/opto/library_call.cpp +++ b/src/hotspot/share/opto/library_call.cpp @@ -542,7 +542,8 @@ bool LibraryCallKit::try_to_inline(int predicate) { case vmIntrinsics::_longBitsToDouble: case vmIntrinsics::_floatToFloat16: case vmIntrinsics::_float16ToFloat: return inline_fp_conversions(intrinsic_id()); - + case vmIntrinsics::_sqrt_float16: return inline_fp16_operations(intrinsic_id(), 1); + case vmIntrinsics::_fma_float16: return inline_fp16_operations(intrinsic_id(), 3); case vmIntrinsics::_floatIsFinite: case vmIntrinsics::_floatIsInfinite: case vmIntrinsics::_doubleIsFinite: @@ -8620,3 +8621,54 @@ bool LibraryCallKit::inline_blackhole() { return true; } + +bool LibraryCallKit::inline_fp16_operations(vmIntrinsics::ID id, int num_args) { + if (!Matcher::match_rule_supported(Op_ReinterpretS2HF) || + !Matcher::match_rule_supported(Op_ReinterpretHF2S)) { + return false; + } + + // Transformed nodes + Node* fld1 = nullptr; + Node* fld2 = nullptr; + Node* fld3 = nullptr; + switch(num_args) { + case 3: + assert((argument(2)->is_ConI() && + argument(2)->get_int() >= min_jshort && + argument(2)->get_int() <= max_jshort) || + (argument(2)->bottom_type()->array_element_basic_type() == T_SHORT), ""); + fld3 = _gvn.transform(new ReinterpretS2HFNode(argument(2))); + // fall-through + case 2: + assert((argument(1)->is_ConI() && + argument(1)->get_int() >= min_jshort && + argument(1)->get_int() <= max_jshort) || + (argument(1)->bottom_type()->array_element_basic_type() == T_SHORT), ""); + fld2 = _gvn.transform(new ReinterpretS2HFNode(argument(1))); + // fall-through + case 1: + assert((argument(0)->is_ConI() && + argument(0)->get_int() >= min_jshort && + argument(0)->get_int() <= max_jshort) || + (argument(0)->bottom_type()->array_element_basic_type() == T_SHORT), ""); + fld1 = _gvn.transform(new ReinterpretS2HFNode(argument(0))); + break; + default: fatal("Unsupported number of arguments %d", num_args); + } + + Node* result = nullptr; + switch (id) { + // Unary operations + case vmIntrinsics::_sqrt_float16: result = _gvn.transform(new SqrtHFNode(C, control(), fld1)); break; + + // Ternary operations + case vmIntrinsics::_fma_float16: result = _gvn.transform(new FmaHFNode(control(), fld1, fld2, fld3)); break; + default: + fatal_unexpected_iid(id); + break; + } + set_result(_gvn.transform(new ReinterpretHF2SNode(result))); + return true; +} + diff --git a/src/hotspot/share/opto/library_call.hpp b/src/hotspot/share/opto/library_call.hpp index c5437e3bf73f0..9e0fb7edc4661 100644 --- a/src/hotspot/share/opto/library_call.hpp +++ b/src/hotspot/share/opto/library_call.hpp @@ -291,6 +291,7 @@ class LibraryCallKit : public GraphKit { bool inline_onspinwait(); bool inline_fp_conversions(vmIntrinsics::ID id); bool inline_fp_range_check(vmIntrinsics::ID id); + bool inline_fp16_operations(vmIntrinsics::ID id, int num_args); bool inline_number_methods(vmIntrinsics::ID id); bool inline_bitshuffle_methods(vmIntrinsics::ID id); bool inline_compare_unsigned(vmIntrinsics::ID id); diff --git a/src/hotspot/share/opto/machnode.cpp b/src/hotspot/share/opto/machnode.cpp index e271637893b28..c905f5aa47804 100644 --- a/src/hotspot/share/opto/machnode.cpp +++ b/src/hotspot/share/opto/machnode.cpp @@ -47,6 +47,7 @@ intptr_t MachOper::constant() const { return 0x00; } relocInfo::relocType MachOper::constant_reloc() const { return relocInfo::none; } jdouble MachOper::constantD() const { ShouldNotReachHere(); } jfloat MachOper::constantF() const { ShouldNotReachHere(); } +jshort MachOper::constantH() const { ShouldNotReachHere(); } jlong MachOper::constantL() const { ShouldNotReachHere(); } TypeOopPtr *MachOper::oop() const { return nullptr; } int MachOper::ccode() const { return 0x00; } diff --git a/src/hotspot/share/opto/machnode.hpp b/src/hotspot/share/opto/machnode.hpp index 4ac91175f78ed..9c2d9c62a7414 100644 --- a/src/hotspot/share/opto/machnode.hpp +++ b/src/hotspot/share/opto/machnode.hpp @@ -156,6 +156,7 @@ class MachOper : public ResourceObj { virtual jdouble constantD() const; virtual jfloat constantF() const; virtual jlong constantL() const; + virtual jshort constantH() const; virtual TypeOopPtr *oop() const; virtual int ccode() const; // A zero, default, indicates this value is not needed. diff --git a/src/hotspot/share/opto/matcher.cpp b/src/hotspot/share/opto/matcher.cpp index b94fb7101430e..5c7523cd17473 100644 --- a/src/hotspot/share/opto/matcher.cpp +++ b/src/hotspot/share/opto/matcher.cpp @@ -2305,8 +2305,10 @@ bool Matcher::find_shared_visit(MStack& mstack, Node* n, uint opcode, bool& mem_ case Op_EncodeISOArray: case Op_FmaD: case Op_FmaF: + case Op_FmaHF: case Op_FmaVD: case Op_FmaVF: + case Op_FmaVHF: case Op_MacroLogicV: case Op_VectorCmpMasked: case Op_CompressV: @@ -2477,8 +2479,10 @@ void Matcher::find_shared_post_visit(Node* n, uint opcode) { } case Op_FmaD: case Op_FmaF: + case Op_FmaHF: case Op_FmaVD: - case Op_FmaVF: { + case Op_FmaVF: + case Op_FmaVHF: { // Restructure into a binary tree for Matching. Node* pair = new BinaryNode(n->in(1), n->in(2)); n->set_req(2, pair); diff --git a/src/hotspot/share/opto/mulnode.cpp b/src/hotspot/share/opto/mulnode.cpp index ad98fda025f07..e30b7a0cfb08b 100644 --- a/src/hotspot/share/opto/mulnode.cpp +++ b/src/hotspot/share/opto/mulnode.cpp @@ -67,7 +67,8 @@ Node *MulNode::Ideal(PhaseGVN *phase, bool can_reshape) { // only valid for the actual Mul nodes. uint op = Opcode(); bool real_mul = (op == Op_MulI) || (op == Op_MulL) || - (op == Op_MulF) || (op == Op_MulD); + (op == Op_MulF) || (op == Op_MulD) || + (op == Op_MulHF); // Convert "(-a)*(-b)" into "a*b". if (real_mul && in1->is_Sub() && in2->is_Sub()) { @@ -122,7 +123,8 @@ Node *MulNode::Ideal(PhaseGVN *phase, bool can_reshape) { // constant, flatten the expression tree. if( t2->singleton() && // Right input is a constant? op != Op_MulF && // Float & double cannot reassociate - op != Op_MulD ) { + op != Op_MulD && + op != Op_MulHF) { if( t2 == Type::TOP ) return nullptr; Node *mul1 = in(1); #ifdef ASSERT @@ -536,10 +538,29 @@ Node* MulFNode::Ideal(PhaseGVN* phase, bool can_reshape) { Node* base = in(1); return new AddFNode(base, base); } + return MulNode::Ideal(phase, can_reshape); +} + +//============================================================================= +//------------------------------Ideal------------------------------------------ +// Check to see if we are multiplying by a constant 2 and convert to add, then try the regular MulNode::Ideal +Node* MulHFNode::Ideal(PhaseGVN* phase, bool can_reshape) { + const TypeH* t2 = phase->type(in(2))->isa_half_float_constant(); + // x * 2 -> x + x + if (t2 != nullptr && t2->getf() == 2) { + Node* base = in(1); + return new AddHFNode(base, base); + } return MulNode::Ideal(phase, can_reshape); } +// Compute the product type of two half float ranges into this node. +const Type* MulHFNode::mul_ring(const Type* t0, const Type* t1) const { + if(t0 == Type::HALF_FLOAT || t1 == Type::HALF_FLOAT) return Type::HALF_FLOAT; + return TypeH::make(t0->getf() * t1->getf()); +} + //============================================================================= //------------------------------mul_ring--------------------------------------- // Compute the product type of two double ranges into this node. @@ -1901,6 +1922,28 @@ const Type* FmaFNode::Value(PhaseGVN* phase) const { #endif } +//============================================================================= +//------------------------------Value------------------------------------------ +const Type* FmaHFNode::Value(PhaseGVN* phase) const { + const Type* t1 = phase->type(in(1)); + if (t1 == Type::TOP) { return Type::TOP; } + if (t1->base() != Type::HalfFloatCon) { return Type::HALF_FLOAT; } + const Type* t2 = phase->type(in(2)); + if (t2 == Type::TOP) { return Type::TOP; } + if (t2->base() != Type::HalfFloatCon) { return Type::HALF_FLOAT; } + const Type* t3 = phase->type(in(3)); + if (t3 == Type::TOP) { return Type::TOP; } + if (t3->base() != Type::HalfFloatCon) { return Type::HALF_FLOAT; } +#ifndef __STDC_IEC_559__ + return Type::HALF_FLOAT; +#else + float f1 = t1->getf(); + float f2 = t2->getf(); + float f3 = t3->getf(); + return TypeH::make(fma(f1, f2, f3)); +#endif +} + //============================================================================= //------------------------------hash------------------------------------------- // Hash function for MulAddS2INode. Operation is commutative with commutative pairs. diff --git a/src/hotspot/share/opto/mulnode.hpp b/src/hotspot/share/opto/mulnode.hpp index 4c5e3e33248da..d7eeaa1621f33 100644 --- a/src/hotspot/share/opto/mulnode.hpp +++ b/src/hotspot/share/opto/mulnode.hpp @@ -143,6 +143,24 @@ class MulFNode : public MulNode { virtual uint ideal_reg() const { return Op_RegF; } }; +//------------------------------MulHFNode--------------------------------------- +// Multiply 2 half floats +class MulHFNode : public MulNode { +public: + MulHFNode(Node* in1, Node* in2) : MulNode(in1, in2) {} + virtual int Opcode() const; + virtual Node* Ideal(PhaseGVN* phase, bool can_reshape); + virtual const Type* mul_ring(const Type*, const Type*) const; + const Type* mul_id() const { return TypeH::ONE; } + const Type* add_id() const { return TypeH::ZERO; } + int add_opcode() const { return Op_AddHF; } + int mul_opcode() const { return Op_MulHF; } + int max_opcode() const { return Op_MaxHF; } + int min_opcode() const { return Op_MinHF; } + const Type* bottom_type() const { return Type::HALF_FLOAT; } + virtual uint ideal_reg() const { return Op_RegF; } +}; + //------------------------------MulDNode--------------------------------------- // Multiply 2 doubles class MulDNode : public MulNode { @@ -390,6 +408,17 @@ class FmaFNode : public FmaNode { virtual const Type* Value(PhaseGVN* phase) const; }; +//------------------------------FmaHFNode------------------------------------- +// fused-multiply-add half-precision float +class FmaHFNode : public FmaNode { +public: + FmaHFNode(Node* c, Node* in1, Node* in2, Node* in3) : FmaNode(c, in1, in2, in3) {} + virtual int Opcode() const; + const Type* bottom_type() const { return Type::HALF_FLOAT; } + virtual uint ideal_reg() const { return Op_RegF; } + virtual const Type* Value(PhaseGVN* phase) const; +}; + //------------------------------MulAddS2INode---------------------------------- // Multiply shorts into integers and add them. // Semantics: I_OUT = S1 * S2 + S3 * S4 diff --git a/src/hotspot/share/opto/node.cpp b/src/hotspot/share/opto/node.cpp index cf371bb3fff36..c6db27c94e197 100644 --- a/src/hotspot/share/opto/node.cpp +++ b/src/hotspot/share/opto/node.cpp @@ -1596,6 +1596,13 @@ jfloat Node::getf() const { return ((ConFNode*)this)->type()->is_float_constant()->getf(); } +// Get a half float constant from a ConstNode. +// Returns the constant if it is a float ConstNode +jshort Node::geth() const { + assert( Opcode() == Op_ConH, "" ); + return ((ConHNode*)this)->type()->is_half_float_constant()->geth(); +} + #ifndef PRODUCT // Call this from debugger: diff --git a/src/hotspot/share/opto/node.hpp b/src/hotspot/share/opto/node.hpp index b3515236ff056..84097981fb991 100644 --- a/src/hotspot/share/opto/node.hpp +++ b/src/hotspot/share/opto/node.hpp @@ -58,6 +58,7 @@ class CallNode; class CallRuntimeNode; class CallStaticJavaNode; class CastFFNode; +class CastHHNode; class CastDDNode; class CastVVNode; class CastIINode; @@ -725,6 +726,7 @@ class Node { DEFINE_CLASS_ID(CastDD, ConstraintCast, 4) DEFINE_CLASS_ID(CastVV, ConstraintCast, 5) DEFINE_CLASS_ID(CastPP, ConstraintCast, 6) + DEFINE_CLASS_ID(CastHH, ConstraintCast, 7) DEFINE_CLASS_ID(CMove, Type, 3) DEFINE_CLASS_ID(SafePointScalarObject, Type, 4) DEFINE_CLASS_ID(DecodeNarrowPtr, Type, 5) @@ -909,6 +911,7 @@ class Node { DEFINE_CLASS_QUERY(CheckCastPP) DEFINE_CLASS_QUERY(CastII) DEFINE_CLASS_QUERY(CastLL) + DEFINE_CLASS_QUERY(CastFF) DEFINE_CLASS_QUERY(ConI) DEFINE_CLASS_QUERY(CastPP) DEFINE_CLASS_QUERY(ConstraintCast) @@ -1257,6 +1260,7 @@ class Node { intptr_t get_narrowcon() const; jdouble getd() const; jfloat getf() const; + jshort geth() const; // Nodes which are pinned into basic blocks virtual bool pinned() const { return false; } diff --git a/src/hotspot/share/opto/subnode.cpp b/src/hotspot/share/opto/subnode.cpp index 445eb16821443..b5e2128e4384c 100644 --- a/src/hotspot/share/opto/subnode.cpp +++ b/src/hotspot/share/opto/subnode.cpp @@ -553,6 +553,24 @@ const Type* SubFPNode::Value(PhaseGVN* phase) const { //============================================================================= +//------------------------------sub-------------------------------------------- +// A subtract node differences its two inputs. +const Type* SubHFNode::sub(const Type* t1, const Type* t2) const { + // no folding if one of operands is infinity or NaN, do not do constant folding + if(g_isfinite(t1->getf()) && g_isfinite(t2->getf())) { + return TypeH::make(t1->getf() - t2->getf()); + } + else if(g_isnan(t1->getf())) { + return t1; + } + else if(g_isnan(t2->getf())) { + return t2; + } + else { + return Type::HALF_FLOAT; + } +} + //------------------------------Ideal------------------------------------------ Node *SubFNode::Ideal(PhaseGVN *phase, bool can_reshape) { const Type *t2 = phase->type( in(2) ); @@ -1947,6 +1965,15 @@ const Type* SqrtFNode::Value(PhaseGVN* phase) const { return TypeF::make( (float)sqrt( (double)f ) ); } +const Type* SqrtHFNode::Value(PhaseGVN* phase) const { + const Type* t1 = phase->type(in(1)); + if (t1 == Type::TOP) { return Type::TOP; } + if (t1->base() != Type::HalfFloatCon) { return Type::HALF_FLOAT; } + float f = t1->getf(); + if (f < 0.0f) return Type::HALF_FLOAT; + return TypeH::make((float)sqrt((double)f)); +} + const Type* ReverseINode::Value(PhaseGVN* phase) const { const Type *t1 = phase->type( in(1) ); if (t1 == Type::TOP) { diff --git a/src/hotspot/share/opto/subnode.hpp b/src/hotspot/share/opto/subnode.hpp index 6ceaa851739d8..e6ca7a44970ba 100644 --- a/src/hotspot/share/opto/subnode.hpp +++ b/src/hotspot/share/opto/subnode.hpp @@ -130,6 +130,18 @@ class SubDNode : public SubFPNode { virtual uint ideal_reg() const { return Op_RegD; } }; +//------------------------------SubHFNode-------------------------------------- +// Subtract 2 half floats +class SubHFNode : public SubFPNode { +public: + SubHFNode(Node* in1, Node* in2) : SubFPNode(in1, in2) {} + virtual int Opcode() const; + virtual const Type* sub(const Type*, const Type*) const; + const Type* add_id() const { return TypeH::ZERO; } + const Type* bottom_type() const { return Type::HALF_FLOAT; } + virtual uint ideal_reg() const { return Op_RegF; } +}; + //------------------------------CmpNode--------------------------------------- // Compare 2 values, returning condition codes (-1, 0 or 1). class CmpNode : public SubNode { @@ -528,6 +540,20 @@ class SqrtFNode : public Node { virtual const Type* Value(PhaseGVN* phase) const; }; +//------------------------------SqrtHFNode------------------------------------- +// square root of a half-precision float +class SqrtHFNode : public Node { +public: + SqrtHFNode(Compile* C, Node* c, Node* in1) : Node(c, in1) { + init_flags(Flag_is_expensive); + C->add_expensive_node(this); + } + virtual int Opcode() const; + const Type* bottom_type() const { return Type::HALF_FLOAT; } + virtual uint ideal_reg() const { return Op_RegF; } + virtual const Type* Value(PhaseGVN* phase) const; +}; + //-------------------------------ReverseBytesINode-------------------------------- // reverse bytes of an integer class ReverseBytesINode : public Node { diff --git a/src/hotspot/share/opto/superword.cpp b/src/hotspot/share/opto/superword.cpp index 8000e4fd39ed0..2c6539b9a7c6f 100644 --- a/src/hotspot/share/opto/superword.cpp +++ b/src/hotspot/share/opto/superword.cpp @@ -2543,6 +2543,7 @@ void VLoopTypes::compute_vector_element_type() { // Smallest type containing range of values const Type* VLoopTypes::container_type(Node* n) const { + int opc = n->Opcode(); if (n->is_Mem()) { BasicType bt = n->as_Mem()->memory_type(); if (n->is_Store() && (bt == T_CHAR)) { @@ -2550,7 +2551,7 @@ const Type* VLoopTypes::container_type(Node* n) const { // preceding arithmetic operation extends values to signed Int. bt = T_SHORT; } - if (n->Opcode() == Op_LoadUB) { + if (opc == Op_LoadUB) { // Adjust type for unsigned byte loads, it is important for right shifts. // T_BOOLEAN is used because there is no basic type representing type // TypeInt::UBYTE. Use of T_BOOLEAN for vectors is fine because only @@ -2559,12 +2560,13 @@ const Type* VLoopTypes::container_type(Node* n) const { } return Type::get_const_basic_type(bt); } + const Type* t = _vloop.phase()->igvn().type(n); if (t->basic_type() == T_INT) { // Float to half float conversion may be succeeded by a conversion from // half float to float, in such a case back propagation of narrow type (SHORT) // may not be possible. - if (n->Opcode() == Op_ConvF2HF) { + if (n->Opcode() == Op_ConvF2HF || n->Opcode() == Op_ReinterpretHF2S) { return TypeInt::SHORT; } // A narrow type of arithmetic operations will be determined by diff --git a/src/hotspot/share/opto/superwordVTransformBuilder.cpp b/src/hotspot/share/opto/superwordVTransformBuilder.cpp index 2e32ce28d3ccb..0efe062ff6b25 100644 --- a/src/hotspot/share/opto/superwordVTransformBuilder.cpp +++ b/src/hotspot/share/opto/superwordVTransformBuilder.cpp @@ -158,9 +158,11 @@ VTransformVectorNode* SuperWordVTransformBuilder::make_vector_vtnode_for_pack(co p0->is_CMove() || VectorNode::is_scalar_op_that_returns_int_but_vector_op_returns_long(opc) || VectorNode::is_convert_opcode(opc) || + VectorNode::is_reinterpret_opcode(opc) || VectorNode::is_scalar_unary_op_with_equal_input_and_output_types(opc) || - opc == Op_FmaD || - opc == Op_FmaF || + opc == Op_FmaD || + opc == Op_FmaF || + opc == Op_FmaHF || opc == Op_SignumF || opc == Op_SignumD, "pack type must be in this list"); diff --git a/src/hotspot/share/opto/type.cpp b/src/hotspot/share/opto/type.cpp index 407a4a20a9bda..9c72592995b9a 100644 --- a/src/hotspot/share/opto/type.cpp +++ b/src/hotspot/share/opto/type.cpp @@ -27,6 +27,7 @@ #include "ci/ciTypeFlow.hpp" #include "classfile/javaClasses.hpp" #include "classfile/symbolTable.hpp" +#include "classfile/vmSymbols.hpp" #include "compiler/compileLog.hpp" #include "libadt/dict.hpp" #include "memory/oopFactory.hpp" @@ -42,6 +43,7 @@ #include "utilities/checkedCast.hpp" #include "utilities/powerOfTwo.hpp" #include "utilities/stringUtils.hpp" +#include "runtime/stubRoutines.hpp" // Portions of code courtesy of Clifford Click @@ -102,6 +104,9 @@ const Type::TypeInfo Type::_type_info[Type::lastype] = { { Abio, T_ILLEGAL, "abIO", false, 0, relocInfo::none }, // Abio { Return_Address, T_ADDRESS, "return_address",false, Op_RegP, relocInfo::none }, // Return_Address { Memory, T_ILLEGAL, "memory", false, 0, relocInfo::none }, // Memory + { HalfFloatBot, T_SHORT, "halffloat_top", false, Op_RegF, relocInfo::none }, // HalfFloatTop + { HalfFloatCon, T_SHORT, "hfcon:", false, Op_RegF, relocInfo::none }, // HalfFloatCon + { HalfFloatTop, T_SHORT, "short", false, Op_RegF, relocInfo::none }, // HalfFloatBot { FloatBot, T_FLOAT, "float_top", false, Op_RegF, relocInfo::none }, // FloatTop { FloatCon, T_FLOAT, "ftcon:", false, Op_RegF, relocInfo::none }, // FloatCon { FloatTop, T_FLOAT, "float", false, Op_RegF, relocInfo::none }, // FloatBot @@ -131,6 +136,7 @@ const Type *Type::ABIO; // State-of-machine only const Type *Type::BOTTOM; // All values const Type *Type::CONTROL; // Control only const Type *Type::DOUBLE; // All doubles +const Type *Type::HALF_FLOAT; // All half floats const Type *Type::FLOAT; // All floats const Type *Type::HALF; // Placeholder half of doublewide type const Type *Type::MEMORY; // Abstract store only @@ -451,6 +457,7 @@ void Type::Initialize_shared(Compile* current) { ABIO = make(Abio); // State-of-machine only RETURN_ADDRESS=make(Return_Address); FLOAT = make(FloatBot); // All floats + HALF_FLOAT = make(HalfFloatBot); // All half floats DOUBLE = make(DoubleBot); // All doubles BOTTOM = make(Bottom); // Everything HALF = make(Half); // Placeholder half of doublewide type @@ -462,6 +469,13 @@ void Type::Initialize_shared(Compile* current) { TypeF::POS_INF = TypeF::make(jfloat_cast(POSITIVE_INFINITE_F)); TypeF::NEG_INF = TypeF::make(-jfloat_cast(POSITIVE_INFINITE_F)); + TypeH::MAX = TypeH::make(max_jfloat16); // HalfFloat MAX + TypeH::MIN = TypeH::make(min_jfloat16); // HalfFloat MIN + TypeH::ZERO = TypeH::make((jshort)0); // HalfFloat 0 (positive zero) + TypeH::ONE = TypeH::make(one_jfloat16); // HalfFloat 1 + TypeH::POS_INF = TypeH::make(pos_inf_jfloat16); + TypeH::NEG_INF = TypeH::make(neg_inf_jfloat16); + TypeD::MAX = TypeD::make(max_jdouble); // Double MAX TypeD::MIN = TypeD::make(min_jdouble); // Double MIN TypeD::ZERO = TypeD::make(0.0); // Double 0 (positive zero) @@ -1032,6 +1046,7 @@ const Type *Type::xmeet( const Type *t ) const { // Cut in half the number of cases I must handle. Only need cases for when // the given enum "t->type" is less than or equal to the local enum "type". + case HalfFloatCon: case FloatCon: case DoubleCon: case Int: @@ -1067,19 +1082,30 @@ const Type *Type::xmeet( const Type *t ) const { case Bottom: // Ye Olde Default return t; + case HalfFloatTop: + if (_base == HalfFloatTop) return this; + case HalfFloatBot: // Half Float + if (_base == HalfFloatBot || _base == HalfFloatTop) return HALF_FLOAT; + if (_base == FloatBot || _base == FloatTop) return Type::BOTTOM; + if (_base == DoubleTop || _base == DoubleBot) return Type::BOTTOM; + typerr(t); + return Type::BOTTOM; + case FloatTop: - if( _base == FloatTop ) return this; + if (_base == FloatTop ) return this; case FloatBot: // Float - if( _base == FloatBot || _base == FloatTop ) return FLOAT; - if( _base == DoubleTop || _base == DoubleBot ) return Type::BOTTOM; + if (_base == FloatBot || _base == FloatTop) return FLOAT; + if (_base == HalfFloatTop || _base == HalfFloatBot) return Type::BOTTOM; + if (_base == DoubleTop || _base == DoubleBot) return Type::BOTTOM; typerr(t); return Type::BOTTOM; case DoubleTop: - if( _base == DoubleTop ) return this; + if (_base == DoubleTop) return this; case DoubleBot: // Double - if( _base == DoubleBot || _base == DoubleTop ) return DOUBLE; - if( _base == FloatTop || _base == FloatBot ) return Type::BOTTOM; + if (_base == DoubleBot || _base == DoubleTop) return DOUBLE; + if (_base == HalfFloatTop || _base == HalfFloatBot) return Type::BOTTOM; + if (_base == FloatTop || _base == FloatBot) return Type::BOTTOM; typerr(t); return Type::BOTTOM; @@ -1087,7 +1113,7 @@ const Type *Type::xmeet( const Type *t ) const { case Control: // Control of code case Abio: // State of world outside of program case Memory: - if( _base == t->_base ) return this; + if (_base == t->_base) return this; typerr(t); return Type::BOTTOM; @@ -1167,6 +1193,7 @@ bool Type::empty(void) const { switch (_base) { case DoubleTop: case FloatTop: + case HalfFloatTop: case Top: return true; @@ -1175,6 +1202,7 @@ bool Type::empty(void) const { case Return_Address: case Memory: case Bottom: + case HalfFloatBot: case FloatBot: case DoubleBot: return false; // never a singleton, therefore never empty @@ -1222,6 +1250,9 @@ Type::Category Type::category() const { case Type::AryKlassPtr: case Type::Function: case Type::Return_Address: + case Type::HalfFloatTop: + case Type::HalfFloatCon: + case Type::HalfFloatBot: case Type::FloatTop: case Type::FloatCon: case Type::FloatBot: @@ -1327,6 +1358,9 @@ const Type *TypeF::xmeet( const Type *t ) const { case NarrowKlass: case Int: case Long: + case HalfFloatTop: + case HalfFloatCon: + case HalfFloatBot: case DoubleTop: case DoubleCon: case DoubleBot: @@ -1405,6 +1439,138 @@ bool TypeF::empty(void) const { return false; // always exactly a singleton } +//============================================================================= +// Convenience common pre-built types. +const TypeH *TypeH::MAX; // Half float max +const TypeH *TypeH::MIN; // Half float min +const TypeH *TypeH::ZERO; // Half float zero +const TypeH *TypeH::ONE; // Half float one +const TypeH *TypeH::POS_INF; // Half float positive infinity +const TypeH *TypeH::NEG_INF; // Half float negative infinity + +//------------------------------make------------------------------------------- +// Create a halffloat constant +const TypeH *TypeH::make(short f) { + return (TypeH*)(new TypeH(f))->hashcons(); +} + +const TypeH *TypeH::make(float f) { + assert(StubRoutines::f2hf_adr() != nullptr, ""); + short hf = StubRoutines::f2hf(f); + return (TypeH*)(new TypeH(hf))->hashcons(); +} + +//------------------------------meet------------------------------------------- +// Compute the MEET of two types. It returns a new Type object. +const Type *TypeH::xmeet( const Type *t ) const { + // Perform a fast test for common case; meeting the same types together. + if (this == t) return this; // Meeting same type-rep? + + // Current "this->_base" is FloatCon + switch (t->base()) { // Switch on original type + case AnyPtr: // Mixing with oops happens when javac + case RawPtr: // reuses local variables + case OopPtr: + case InstPtr: + case AryPtr: + case MetadataPtr: + case KlassPtr: + case InstKlassPtr: + case AryKlassPtr: + case NarrowOop: + case NarrowKlass: + case Int: + case Long: + case FloatTop: + case FloatCon: + case FloatBot: + case DoubleTop: + case DoubleCon: + case DoubleBot: + case Bottom: // Ye Olde Default + return Type::BOTTOM; + + case HalfFloatBot: + return t; + + default: // All else is a mistake + typerr(t); + + case HalfFloatCon: // Half float-constant vs Half float-constant? + if (_f != t->geth()) { // unequal constants? + // must compare bitwise as positive zero, negative zero and NaN have + // all the same representation in C++ + return HALF_FLOAT; // Return generic float + } // Equal constants + case Top: + case HalfFloatTop: + break; // Return the Half float constant + } + return this; // Return the Half float constant +} + +//------------------------------xdual------------------------------------------ +// Dual: symmetric +const Type *TypeH::xdual() const { + return this; +} + +//------------------------------eq--------------------------------------------- +// Structural equality check for Type representations +bool TypeH::eq(const Type *t) const { + // Bitwise comparison to distinguish between +/-0. These values must be treated + // as different to be consistent with C1 and the interpreter. + return (_f == t->geth()); +} + +//------------------------------hash------------------------------------------- +// Type-specific hashing function. +uint TypeH::hash(void) const { + return *(uint*)(&_f); +} + +//------------------------------is_finite-------------------------------------- +// Has a finite value +bool TypeH::is_finite() const { + assert(StubRoutines::hf2f_adr() != nullptr, ""); + float f = StubRoutines::hf2f(geth()); + return g_isfinite(f) != 0; +} + +float TypeH::getf() const { + assert(StubRoutines::hf2f_adr() != nullptr, ""); + return StubRoutines::hf2f(geth()); +} + +//------------------------------is_nan----------------------------------------- +// Is not a number (NaN) +bool TypeH::is_nan() const { + assert(StubRoutines::hf2f_adr() != nullptr, ""); + float f = StubRoutines::hf2f(geth()); + return g_isnan(f) != 0; +} + +//------------------------------dump2------------------------------------------ +// Dump float constant Type +#ifndef PRODUCT +void TypeH::dump2( Dict &d, uint depth, outputStream *st ) const { + Type::dump2(d,depth, st); + st->print("%f", getf()); +} +#endif + +//------------------------------singleton-------------------------------------- +// TRUE if Type is a singleton type, FALSE otherwise. Singletons are simple +// constants (Ldi nodes). Singletons are integer, half float, float or double constants +// or a single symbol. +bool TypeH::singleton(void) const { + return true; // Always a singleton +} + +bool TypeH::empty(void) const { + return false; // always exactly a singleton +} + //============================================================================= // Convenience common pre-built types. const TypeD *TypeD::MAX; // Floating point max @@ -1440,6 +1606,9 @@ const Type *TypeD::xmeet( const Type *t ) const { case NarrowKlass: case Int: case Long: + case HalfFloatTop: + case HalfFloatCon: + case HalfFloatBot: case FloatTop: case FloatCon: case FloatBot: @@ -1636,6 +1805,9 @@ const Type *TypeInt::xmeet( const Type *t ) const { case NarrowOop: case NarrowKlass: case Long: + case HalfFloatTop: + case HalfFloatCon: + case HalfFloatBot: case FloatTop: case FloatCon: case FloatBot: @@ -1899,6 +2071,9 @@ const Type *TypeLong::xmeet( const Type *t ) const { case NarrowOop: case NarrowKlass: case Int: + case HalfFloatTop: + case HalfFloatCon: + case HalfFloatBot: case FloatTop: case FloatCon: case FloatBot: @@ -2693,6 +2868,9 @@ const Type *TypePtr::xmeet_helper(const Type *t) const { switch (t->base()) { // switch on original type case Int: // Mixing ints & oops happens when javac case Long: // reuses local variables + case HalfFloatTop: + case HalfFloatCon: + case HalfFloatBot: case FloatTop: case FloatCon: case FloatBot: @@ -3632,6 +3810,9 @@ const Type *TypeOopPtr::xmeet_helper(const Type *t) const { case Int: // Mixing ints & oops happens when javac case Long: // reuses local variables + case HalfFloatTop: + case HalfFloatCon: + case HalfFloatBot: case FloatTop: case FloatCon: case FloatBot: @@ -4200,6 +4381,9 @@ const Type *TypeInstPtr::xmeet_helper(const Type *t) const { case Int: // Mixing ints & oops happens when javac case Long: // reuses local variables + case HalfFloatTop: + case HalfFloatCon: + case HalfFloatBot: case FloatTop: case FloatCon: case FloatBot: @@ -4876,6 +5060,9 @@ const Type *TypeAryPtr::xmeet_helper(const Type *t) const { // Mixing ints & oops happens when javac reuses local variables case Int: case Long: + case HalfFloatTop: + case HalfFloatCon: + case HalfFloatBot: case FloatTop: case FloatCon: case FloatBot: @@ -5305,6 +5492,9 @@ const Type *TypeNarrowPtr::xmeet( const Type *t ) const { case Int: // Mixing ints & oops happens when javac case Long: // reuses local variables + case HalfFloatTop: + case HalfFloatCon: + case HalfFloatBot: case FloatTop: case FloatCon: case FloatBot: @@ -5460,6 +5650,9 @@ const Type *TypeMetadataPtr::xmeet( const Type *t ) const { case Int: // Mixing ints & oops happens when javac case Long: // reuses local variables + case HalfFloatTop: + case HalfFloatCon: + case HalfFloatBot: case FloatTop: case FloatCon: case FloatBot: @@ -5834,6 +6027,9 @@ const Type *TypeInstKlassPtr::xmeet( const Type *t ) const { case Int: // Mixing ints & oops happens when javac case Long: // reuses local variables + case HalfFloatTop: + case HalfFloatCon: + case HalfFloatBot: case FloatTop: case FloatCon: case FloatBot: @@ -6258,6 +6454,9 @@ const Type *TypeAryKlassPtr::xmeet( const Type *t ) const { case Int: // Mixing ints & oops happens when javac case Long: // reuses local variables + case HalfFloatTop: + case HalfFloatCon: + case HalfFloatBot: case FloatTop: case FloatCon: case FloatBot: diff --git a/src/hotspot/share/opto/type.hpp b/src/hotspot/share/opto/type.hpp index f6b7efcae3bcc..2dc8b989bde91 100644 --- a/src/hotspot/share/opto/type.hpp +++ b/src/hotspot/share/opto/type.hpp @@ -45,6 +45,7 @@ class Dict; class Type; class TypeD; class TypeF; +class TypeH; class TypeInteger; class TypeInt; class TypeLong; @@ -120,6 +121,9 @@ class Type { Abio, // Abstract I/O Return_Address, // Subroutine return address Memory, // Abstract store + HalfFloatTop, // No float value + HalfFloatCon, // Floating point constant + HalfFloatBot, // Any float value FloatTop, // No float value FloatCon, // Floating point constant FloatBot, // Any float value @@ -277,7 +281,8 @@ class Type { bool is_ptr_to_narrowklass() const; // Convenience access - float getf() const; + short geth() const; + virtual float getf() const; double getd() const; const TypeInt *is_int() const; @@ -289,6 +294,9 @@ class Type { const TypeD *isa_double() const; // Returns null if not a Double{Top,Con,Bot} const TypeD *is_double_constant() const; // Asserts it is a DoubleCon const TypeD *isa_double_constant() const; // Returns null if not a DoubleCon + const TypeH *isa_half_float() const; // Returns null if not a Float{Top,Con,Bot} + const TypeH *is_half_float_constant() const; // Asserts it is a FloatCon + const TypeH *isa_half_float_constant() const; // Returns null if not a FloatCon const TypeF *isa_float() const; // Returns null if not a Float{Top,Con,Bot} const TypeF *is_float_constant() const; // Asserts it is a FloatCon const TypeF *isa_float_constant() const; // Returns null if not a FloatCon @@ -428,6 +436,7 @@ class Type { static const Type *CONTROL; static const Type *DOUBLE; static const Type *FLOAT; + static const Type *HALF_FLOAT; static const Type *HALF; static const Type *MEMORY; static const Type *MULTI; @@ -518,6 +527,38 @@ class TypeF : public Type { #endif }; +// Class of Half Float-Constant Types. +class TypeH : public Type { + TypeH(short f) : Type(HalfFloatCon), _f(f) {}; +public: + virtual bool eq(const Type *t) const; + virtual uint hash() const; // Type specific hashing + virtual bool singleton(void) const; // TRUE if type is a singleton + virtual bool empty(void) const; // TRUE if type is vacuous +public: + const short _f; // Half Float constant + + static const TypeH *make(float f); + static const TypeH *make(short f); + + virtual bool is_finite() const; // Has a finite value + virtual bool is_nan() const; // Is not a number (NaN) + + virtual float getf() const; + virtual const Type *xmeet(const Type *t) const; + virtual const Type *xdual() const; // Compute dual right now. + // Convenience common pre-built types. + static const TypeH *MAX; + static const TypeH *MIN; + static const TypeH *ZERO; // positive zero only + static const TypeH *ONE; + static const TypeH *POS_INF; + static const TypeH *NEG_INF; +#ifndef PRODUCT + virtual void dump2(Dict &d, uint depth, outputStream *st) const; +#endif +}; + //------------------------------TypeD------------------------------------------ // Class of Double-Constant Types. class TypeD : public Type { @@ -1939,6 +1980,11 @@ inline float Type::getf() const { return ((TypeF*)this)->_f; } +inline short Type::geth() const { + assert( _base == HalfFloatCon, "Not a HalfFloatCon" ); + return ((TypeH*)this)->_f; +} + inline double Type::getd() const { assert( _base == DoubleCon, "Not a DoubleCon" ); return ((TypeD*)this)->_d; @@ -1971,6 +2017,21 @@ inline const TypeLong *Type::isa_long() const { return ( _base == Long ? (TypeLong*)this : nullptr); } +inline const TypeH *Type::isa_half_float() const { + return ((_base == HalfFloatTop || + _base == HalfFloatCon || + _base == HalfFloatBot) ? (TypeH*)this : nullptr); +} + +inline const TypeH *Type::is_half_float_constant() const { + assert( _base == HalfFloatCon, "Not a Float" ); + return (TypeH*)this; +} + +inline const TypeH *Type::isa_half_float_constant() const { + return ( _base == HalfFloatCon ? (TypeH*)this : nullptr); +} + inline const TypeF *Type::isa_float() const { return ((_base == FloatTop || _base == FloatCon || @@ -2160,7 +2221,8 @@ inline const TypeNarrowKlass* Type::make_narrowklass() const { } inline bool Type::is_floatingpoint() const { - if( (_base == FloatCon) || (_base == FloatBot) || + if( (_base == HalfFloatCon) || (_base == HalfFloatBot) || + (_base == FloatCon) || (_base == FloatBot) || (_base == DoubleCon) || (_base == DoubleBot) ) return true; return false; diff --git a/src/hotspot/share/opto/vectornode.cpp b/src/hotspot/share/opto/vectornode.cpp index 03e1e7ef247f9..e964dcafb1c22 100644 --- a/src/hotspot/share/opto/vectornode.cpp +++ b/src/hotspot/share/opto/vectornode.cpp @@ -46,6 +46,7 @@ int VectorNode::opcode(int sopc, BasicType bt) { case T_INT: return Op_AddVI; default: return 0; } + case Op_AddHF: return (bt == T_SHORT ? Op_AddVHF : 0); case Op_AddL: return (bt == T_LONG ? Op_AddVL : 0); case Op_AddF: return (bt == T_FLOAT ? Op_AddVF : 0); case Op_AddD: return (bt == T_DOUBLE ? Op_AddVD : 0); @@ -59,6 +60,7 @@ int VectorNode::opcode(int sopc, BasicType bt) { case T_INT: return Op_SubVI; default: return 0; } + case Op_SubHF: return (bt == T_SHORT ? Op_SubVHF : 0); case Op_SubL: return (bt == T_LONG ? Op_SubVL : 0); case Op_SubF: return (bt == T_FLOAT ? Op_SubVF : 0); case Op_SubD: return (bt == T_DOUBLE ? Op_SubVD : 0); @@ -73,6 +75,8 @@ int VectorNode::opcode(int sopc, BasicType bt) { default: return 0; } case Op_MulL: return (bt == T_LONG ? Op_MulVL : 0); + case Op_MulHF: + return (bt == T_SHORT ? Op_MulVHF : 0); case Op_MulF: return (bt == T_FLOAT ? Op_MulVF : 0); case Op_MulD: @@ -81,12 +85,16 @@ int VectorNode::opcode(int sopc, BasicType bt) { return (bt == T_DOUBLE ? Op_FmaVD : 0); case Op_FmaF: return (bt == T_FLOAT ? Op_FmaVF : 0); + case Op_FmaHF: + return (bt == T_SHORT ? Op_FmaVHF : 0); case Op_CMoveF: return (bt == T_FLOAT ? Op_VectorBlend : 0); case Op_CMoveD: return (bt == T_DOUBLE ? Op_VectorBlend : 0); case Op_Bool: return Op_VectorMaskCmp; + case Op_DivHF: + return (bt == T_SHORT ? Op_DivVHF : 0); case Op_DivF: return (bt == T_FLOAT ? Op_DivVF : 0); case Op_DivD: @@ -113,6 +121,8 @@ int VectorNode::opcode(int sopc, BasicType bt) { } case Op_MinL: return (bt == T_LONG ? Op_MinV : 0); + case Op_MinHF: + return (bt == T_SHORT ? Op_MinVHF : 0); case Op_MinF: return (bt == T_FLOAT ? Op_MinV : 0); case Op_MinD: @@ -128,6 +138,8 @@ int VectorNode::opcode(int sopc, BasicType bt) { } case Op_MaxL: return (bt == T_LONG ? Op_MaxV : 0); + case Op_MaxHF: + return (bt == T_SHORT ? Op_MaxVHF : 0); case Op_MaxF: return (bt == T_FLOAT ? Op_MaxV : 0); case Op_MaxD: @@ -157,6 +169,8 @@ int VectorNode::opcode(int sopc, BasicType bt) { return (is_integral_type(bt) ? Op_RotateRightV : 0); case Op_SqrtF: return (bt == T_FLOAT ? Op_SqrtVF : 0); + case Op_SqrtHF: + return (bt == T_SHORT ? Op_SqrtVHF : 0); case Op_SqrtD: return (bt == T_DOUBLE ? Op_SqrtVD : 0); case Op_RoundF: @@ -267,6 +281,9 @@ int VectorNode::opcode(int sopc, BasicType bt) { return Op_SignumVF; case Op_SignumD: return Op_SignumVD; + case Op_ReinterpretS2HF: + case Op_ReinterpretHF2S: + return Op_VectorReinterpret; default: assert(!VectorNode::is_convert_opcode(sopc), @@ -379,6 +396,10 @@ int VectorNode::scalar_opcode(int sopc, BasicType bt) { assert(false, "basic type not handled"); return 0; } + case Op_MinVHF: + return Op_MinHF; + case Op_MaxVHF: + return Op_MaxHF; default: assert(false, "Vector node %s is not handled in VectorNode::scalar_opcode", @@ -619,10 +640,10 @@ void VectorNode::vector_operands(Node* n, uint* start, uint* end) { *start = 1; *end = (n->is_Con() && Matcher::supports_vector_constant_rotates(n->get_int())) ? 2 : 3; break; - case Op_AddI: case Op_AddL: case Op_AddF: case Op_AddD: - case Op_SubI: case Op_SubL: case Op_SubF: case Op_SubD: - case Op_MulI: case Op_MulL: case Op_MulF: case Op_MulD: - case Op_DivF: case Op_DivD: + case Op_AddI: case Op_AddHF: case Op_AddL: case Op_AddF: case Op_AddD: + case Op_SubI: case Op_SubL: case Op_SubHF: case Op_SubF: case Op_SubD: + case Op_MulI: case Op_MulL: case Op_MulHF: case Op_MulF: case Op_MulD: + case Op_DivHF: case Op_DivF: case Op_DivD: case Op_AndI: case Op_AndL: case Op_OrI: case Op_OrL: case Op_XorI: case Op_XorL: @@ -676,33 +697,41 @@ VectorNode* VectorNode::make(int vopc, Node* n1, Node* n2, const TypeVect* vt, b } switch (vopc) { - case Op_AddVB: return new AddVBNode(n1, n2, vt); - case Op_AddVS: return new AddVSNode(n1, n2, vt); - case Op_AddVI: return new AddVINode(n1, n2, vt); - case Op_AddVL: return new AddVLNode(n1, n2, vt); - case Op_AddVF: return new AddVFNode(n1, n2, vt); - case Op_AddVD: return new AddVDNode(n1, n2, vt); - - case Op_SubVB: return new SubVBNode(n1, n2, vt); - case Op_SubVS: return new SubVSNode(n1, n2, vt); - case Op_SubVI: return new SubVINode(n1, n2, vt); - case Op_SubVL: return new SubVLNode(n1, n2, vt); - case Op_SubVF: return new SubVFNode(n1, n2, vt); - case Op_SubVD: return new SubVDNode(n1, n2, vt); - - case Op_MulVB: return new MulVBNode(n1, n2, vt); - case Op_MulVS: return new MulVSNode(n1, n2, vt); - case Op_MulVI: return new MulVINode(n1, n2, vt); - case Op_MulVL: return new MulVLNode(n1, n2, vt); - case Op_MulVF: return new MulVFNode(n1, n2, vt); - case Op_MulVD: return new MulVDNode(n1, n2, vt); - - case Op_DivVF: return new DivVFNode(n1, n2, vt); - case Op_DivVD: return new DivVDNode(n1, n2, vt); + case Op_AddVB: return new AddVBNode(n1, n2, vt); + case Op_AddVHF: return new AddVHFNode(n1, n2, vt); + case Op_AddVS: return new AddVSNode(n1, n2, vt); + case Op_AddVI: return new AddVINode(n1, n2, vt); + case Op_AddVL: return new AddVLNode(n1, n2, vt); + case Op_AddVF: return new AddVFNode(n1, n2, vt); + case Op_AddVD: return new AddVDNode(n1, n2, vt); + + case Op_SubVB: return new SubVBNode(n1, n2, vt); + case Op_SubVS: return new SubVSNode(n1, n2, vt); + case Op_SubVI: return new SubVINode(n1, n2, vt); + case Op_SubVL: return new SubVLNode(n1, n2, vt); + case Op_SubVHF: return new SubVHFNode(n1, n2, vt); + case Op_SubVF: return new SubVFNode(n1, n2, vt); + case Op_SubVD: return new SubVDNode(n1, n2, vt); + + case Op_MulVB: return new MulVBNode(n1, n2, vt); + case Op_MulVS: return new MulVSNode(n1, n2, vt); + case Op_MulVI: return new MulVINode(n1, n2, vt); + case Op_MulVL: return new MulVLNode(n1, n2, vt); + case Op_MulVHF: return new MulVHFNode(n1, n2, vt); + case Op_MulVF: return new MulVFNode(n1, n2, vt); + case Op_MulVD: return new MulVDNode(n1, n2, vt); + + case Op_DivVHF: return new DivVHFNode(n1, n2, vt); + case Op_DivVF: return new DivVFNode(n1, n2, vt); + case Op_DivVD: return new DivVDNode(n1, n2, vt); case Op_MinV: return new MinVNode(n1, n2, vt); case Op_MaxV: return new MaxVNode(n1, n2, vt); + case Op_MinVHF: return new MinVHFNode(n1, n2, vt); + case Op_MaxVHF: return new MaxVHFNode(n1, n2, vt); + + case Op_AbsVF: return new AbsVFNode(n1, vt); case Op_AbsVD: return new AbsVDNode(n1, vt); case Op_AbsVB: return new AbsVBNode(n1, vt); @@ -718,8 +747,9 @@ VectorNode* VectorNode::make(int vopc, Node* n1, Node* n2, const TypeVect* vt, b case Op_ReverseV: return new ReverseVNode(n1, vt); case Op_ReverseBytesV: return new ReverseBytesVNode(n1, vt); - case Op_SqrtVF: return new SqrtVFNode(n1, vt); - case Op_SqrtVD: return new SqrtVDNode(n1, vt); + case Op_SqrtVHF : return new SqrtVHFNode(n1, vt); + case Op_SqrtVF : return new SqrtVFNode(n1, vt); + case Op_SqrtVD : return new SqrtVDNode(n1, vt); case Op_RoundVF: return new RoundVFNode(n1, vt); case Op_RoundVD: return new RoundVDNode(n1, vt); @@ -788,6 +818,7 @@ VectorNode* VectorNode::make(int vopc, Node* n1, Node* n2, Node* n3, const TypeV switch (vopc) { case Op_FmaVD: return new FmaVDNode(n1, n2, n3, vt); case Op_FmaVF: return new FmaVFNode(n1, n2, n3, vt); + case Op_FmaVHF: return new FmaVHFNode(n1, n2, n3, vt); case Op_SelectFromTwoVector: return new SelectFromTwoVectorNode(n1, n2, n3, vt); case Op_SignumVD: return new SignumVDNode(n1, n2, n3, vt); case Op_SignumVF: return new SignumVFNode(n1, n2, n3, vt); @@ -930,10 +961,21 @@ bool VectorNode::is_vector_bitwise_not_pattern(Node* n) { return false; } +bool VectorNode::is_reinterpret_opcode(int opc) { + switch (opc) { + case Op_ReinterpretHF2S: + case Op_ReinterpretS2HF: + return true; + default: + return false; + } +} + bool VectorNode::is_scalar_unary_op_with_equal_input_and_output_types(int opc) { switch (opc) { case Op_SqrtF: case Op_SqrtD: + case Op_SqrtHF: case Op_AbsF: case Op_AbsD: case Op_AbsI: @@ -1432,6 +1474,11 @@ VectorCastNode* VectorCastNode::make(int vopc, Node* n1, BasicType bt, uint vlen } } + +VectorNode* VectorReinterpretNode::make(Node* n, const TypeVect* dst_vt, const TypeVect* src_vt) { + return new VectorReinterpretNode(n, dst_vt, src_vt); +} + int VectorCastNode::opcode(int sopc, BasicType bt, bool is_signed) { assert((is_integral_type(bt) && bt != T_LONG) || is_signed, ""); diff --git a/src/hotspot/share/opto/vectornode.hpp b/src/hotspot/share/opto/vectornode.hpp index 3f737e6e88173..542065698f24a 100644 --- a/src/hotspot/share/opto/vectornode.hpp +++ b/src/hotspot/share/opto/vectornode.hpp @@ -86,6 +86,7 @@ class VectorNode : public TypeNode { static bool is_shift_opcode(int opc); static bool can_use_RShiftI_instead_of_URShiftI(Node* n, BasicType bt); static bool is_convert_opcode(int opc); + static bool is_reinterpret_opcode(int opc); static bool is_minmax_opcode(int opc); static bool is_vshift_cnt_opcode(int opc); @@ -103,6 +104,7 @@ class VectorNode : public TypeNode { static bool is_muladds2i(const Node* n); static bool is_roundopD(Node* n); static bool is_scalar_rotate(Node* n); + static bool is_vector_rotate_supported(int opc, uint vlen, BasicType bt); static bool is_vector_integral_negate_supported(int opc, uint vlen, BasicType bt, bool use_predicate); static bool is_populate_index_supported(BasicType bt); @@ -211,6 +213,14 @@ class AddVFNode : public VectorNode { virtual int Opcode() const; }; +//------------------------------AddVHFNode-------------------------------------- +// Vector add float +class AddVHFNode : public VectorNode { +public: + AddVHFNode(Node* in1, Node* in2, const TypeVect* vt) : VectorNode(in1, in2, vt) {} + virtual int Opcode() const; +}; + //------------------------------AddVDNode-------------------------------------- // Vector add double class AddVDNode : public VectorNode { @@ -381,6 +391,14 @@ class SubVLNode : public VectorNode { virtual int Opcode() const; }; +//------------------------------SubVHFNode-------------------------------------- +// Vector subtract half float +class SubVHFNode : public VectorNode { +public: + SubVHFNode(Node* in1, Node* in2, const TypeVect* vt) : VectorNode(in1, in2, vt) {} + virtual int Opcode() const; +}; + //------------------------------SaturatingAddVNode----------------------------- // Vector saturating addition. class SaturatingAddVNode : public SaturatingVectorNode { @@ -449,6 +467,14 @@ class MulVLNode : public VectorNode { bool has_uint_inputs() const; }; +//------------------------------MulVFNode-------------------------------------- +// Vector multiply half float +class MulVHFNode : public VectorNode { +public: + MulVHFNode(Node* in1, Node* in2, const TypeVect* vt) : VectorNode(in1, in2, vt) {} + virtual int Opcode() const; +}; + //------------------------------MulVFNode-------------------------------------- // Vector multiply float class MulVFNode : public VectorNode { @@ -499,6 +525,14 @@ class FmaVFNode : public FmaVNode { virtual int Opcode() const; }; +//------------------------------FmaVHFNode------------------------------------- +// Vector fused-multiply-add half-precision float +class FmaVHFNode : public FmaVNode { +public: + FmaVHFNode(Node* in1, Node* in2, Node* in3, const TypeVect* vt) : FmaVNode(in1, in2, in3, vt) {} + virtual int Opcode() const; +}; + //------------------------------MulReductionVINode-------------------------------------- // Vector multiply byte, short and int as a reduction class MulReductionVINode : public ReductionNode { @@ -583,6 +617,15 @@ class DivVDNode : public VectorNode { virtual int Opcode() const; }; +//------------------------------DivVHFNode------------------------------------- +// Vector divide half float +class DivVHFNode : public VectorNode { +public: + DivVHFNode(Node* in1, Node* in2, const TypeVect* vt) : VectorNode(in1, in2, vt) {} + virtual int Opcode() const; +}; + + //------------------------------AbsVBNode-------------------------------------- // Vector Abs byte class AbsVBNode : public VectorNode { @@ -623,6 +666,22 @@ class MaxVNode : public VectorNode { virtual int Opcode() const; }; +//------------------------------MinVHFNode------------------------------------ +// Vector Min for half floats +class MinVHFNode : public VectorNode { +public: + MinVHFNode(Node* in1, Node* in2, const TypeVect* vt) : VectorNode(in1, in2, vt) {} + virtual int Opcode() const; +}; + +//------------------------------MaxVHFNode------------------------------------ +// Vector Max for half floats +class MaxVHFNode : public VectorNode { +public: + MaxVHFNode(Node* in1, Node* in2, const TypeVect* vt) : VectorNode(in1, in2, vt) {} + virtual int Opcode() const; +}; + class UMaxVNode : public VectorNode { public: UMaxVNode(Node* in1, Node* in2, const TypeVect* vt) : VectorNode(in1, in2, vt) { @@ -734,6 +793,15 @@ class SqrtVFNode : public VectorNode { SqrtVFNode(Node* in, const TypeVect* vt) : VectorNode(in,vt) {} virtual int Opcode() const; }; + +//------------------------------SqrtVHFNode------------------------------------- +// Vector Sqrt half-precision float +class SqrtVHFNode : public VectorNode { +public: + SqrtVHFNode(Node* in, const TypeVect* vt) : VectorNode(in, vt) {} + virtual int Opcode() const; +}; + //------------------------------RoundDoubleVNode-------------------------------- // Vector round double class RoundDoubleModeVNode : public VectorNode { @@ -1754,6 +1822,7 @@ class VectorReinterpretNode : public VectorNode { virtual Node* Identity(PhaseGVN* phase); virtual int Opcode() const; + static VectorNode* make(Node* n, const TypeVect* dst_vt, const TypeVect* src_vt); }; class VectorCastNode : public VectorNode { diff --git a/src/hotspot/share/opto/vtransform.cpp b/src/hotspot/share/opto/vtransform.cpp index d09a4c899f685..f998310344101 100644 --- a/src/hotspot/share/opto/vtransform.cpp +++ b/src/hotspot/share/opto/vtransform.cpp @@ -495,6 +495,10 @@ VTransformApplyResult VTransformElementWiseVectorNode::apply(const VLoopAnalyzer assert(first->req() == 2 && req() == 2, "only one input expected"); int vopc = VectorCastNode::opcode(opc, in1->bottom_type()->is_vect()->element_basic_type()); vn = VectorCastNode::make(vopc, in1, bt, vlen); + } else if (VectorNode::is_reinterpret_opcode(opc)) { + assert(first->req() == 2 && req() == 2, "only one input expected"); + const TypeVect* vt = TypeVect::make(bt, vlen); + vn = VectorReinterpretNode::make(in1, vt, in1->bottom_type()->is_vect()); } else if (VectorNode::can_use_RShiftI_instead_of_URShiftI(first, bt)) { opc = Op_RShiftI; vn = VectorNode::make(opc, in1, in2, vlen, bt); @@ -511,8 +515,9 @@ VTransformApplyResult VTransformElementWiseVectorNode::apply(const VLoopAnalyzer vn = VectorNode::make(opc, in1, in2, vlen, bt); // unary and binary } else { assert(req() == 4, "three inputs expected"); - assert(opc == Op_FmaD || - opc == Op_FmaF || + assert(opc == Op_FmaD || + opc == Op_FmaHF || + opc == Op_FmaF || opc == Op_SignumF || opc == Op_SignumD, "element wise operation must be from this list"); diff --git a/src/hotspot/share/utilities/globalDefinitions.hpp b/src/hotspot/share/utilities/globalDefinitions.hpp index ccd3106b471a2..6ed3979dee2d1 100644 --- a/src/hotspot/share/utilities/globalDefinitions.hpp +++ b/src/hotspot/share/utilities/globalDefinitions.hpp @@ -560,6 +560,11 @@ const jfloat min_jfloat = jfloat_cast(min_jintFloat); const jint max_jintFloat = (jint)(0x7f7fffff); const jfloat max_jfloat = jfloat_cast(max_jintFloat); +const jshort max_jfloat16 = 31743; +const jshort min_jfloat16 = 1; +const jshort one_jfloat16 = 15360; +const jshort pos_inf_jfloat16 = 31744; +const jshort neg_inf_jfloat16 = -1024; // A named constant for the integral representation of a Java null. const intptr_t NULL_WORD = 0; @@ -918,6 +923,7 @@ class JavaValue { void set_jfloat(jfloat f) { _value.f = f;} void set_jdouble(jdouble d) { _value.d = d;} void set_jint(jint i) { _value.i = i;} + void set_jshort(jshort i) { _value.i = i;} void set_jlong(jlong l) { _value.l = l;} void set_jobject(jobject h) { _value.h = h;} void set_oop(oopDesc* o) { _value.o = o;} diff --git a/src/java.base/share/classes/jdk/internal/vm/vector/Float16Math.java b/src/java.base/share/classes/jdk/internal/vm/vector/Float16Math.java new file mode 100644 index 0000000000000..03ae9226225df --- /dev/null +++ b/src/java.base/share/classes/jdk/internal/vm/vector/Float16Math.java @@ -0,0 +1,65 @@ +/* + * Copyright (c) 2024, Oracle and/or its affiliates. All rights reserved. + * DO NOT ALTER OR REMOVE COPYRIGHT NOTICES OR THIS FILE HEADER. + * + * This code is free software; you can redistribute it and/or modify it + * under the terms of the GNU General Public License version 2 only, as + * published by the Free Software Foundation. Oracle designates this + * particular file as subject to the "Classpath" exception as provided + * by Oracle in the LICENSE file that accompanied this code. + * + * This code is distributed in the hope that it will be useful, but WITHOUT + * ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or + * FITNESS FOR A PARTICULAR PURPOSE. See the GNU General Public License + * version 2 for more details (a copy is included in the LICENSE file that + * accompanied this code). + * + * You should have received a copy of the GNU General Public License version + * 2 along with this work; if not, write to the Free Software Foundation, + * Inc., 51 Franklin St, Fifth Floor, Boston, MA 02110-1301 USA. + * + * Please contact Oracle, 500 Oracle Parkway, Redwood Shores, CA 94065 USA + * or visit www.oracle.com if you need additional information or have any + * questions. + */ +package jdk.internal.vm.vector; + +import jdk.internal.vm.annotation.IntrinsicCandidate; +import jdk.internal.vm.annotation.ForceInline; + +import static java.lang.Float.*; + +/** + * The class {@code Float16Math} constains intrinsic entry points corresponding + * to scalar numeric operations defined in Float16 class. + * @author + * @since 24 + */ +public final class Float16Math { + private Float16Math() { + } + + public interface Float16UnaryMathOp { + short apply(short a); + } + + public interface Float16TernaryMathOp { + short apply(short a, short b, short c); + } + + @IntrinsicCandidate + public static short sqrt(short a, Float16UnaryMathOp defaultImpl) { + assert isNonCapturingLambda(defaultImpl) : defaultImpl; + return defaultImpl.apply(a); + } + + @IntrinsicCandidate + public static short fma(short a, short b, short c, Float16TernaryMathOp defaultImpl) { + assert isNonCapturingLambda(defaultImpl) : defaultImpl; + return defaultImpl.apply(a, b, c); + } + + public static boolean isNonCapturingLambda(Object o) { + return o.getClass().getDeclaredFields().length == 0; + } +} diff --git a/src/jdk.incubator.vector/share/classes/jdk/incubator/vector/Float16.java b/src/jdk.incubator.vector/share/classes/jdk/incubator/vector/Float16.java index f5f5a5a4e7ea7..50f6cc29f3cf3 100644 --- a/src/jdk.incubator.vector/share/classes/jdk/incubator/vector/Float16.java +++ b/src/jdk.incubator.vector/share/classes/jdk/incubator/vector/Float16.java @@ -39,6 +39,8 @@ import static java.lang.Float.floatToFloat16; import static java.lang.Integer.numberOfLeadingZeros; import static java.lang.Math.multiplyHigh; +import jdk.internal.vm.annotation.ForceInline; +import jdk.internal.vm.vector.Float16Math; /** * The {@code Float16} is a class holding 16-bit data @@ -321,8 +323,10 @@ public static Float16 valueOf(long value) { * * @param f a {@code float} */ + @ForceInline public static Float16 valueOf(float f) { - return new Float16(floatToFloat16(f)); + short hf = floatToFloat16(f); + return new Float16(hf); } /** @@ -764,6 +768,7 @@ public static boolean isFinite(Float16 f16) { * @jls 5.1.3 Narrowing Primitive Conversion */ @Override + @ForceInline public byte byteValue() { return (byte)floatValue(); } @@ -785,6 +790,7 @@ public String toString() { * @jls 5.1.3 Narrowing Primitive Conversion */ @Override + @ForceInline public short shortValue() { return (short)floatValue(); } @@ -800,6 +806,7 @@ public short shortValue() { * @jls 5.1.3 Narrowing Primitive Conversion */ @Override + @ForceInline public int intValue() { return (int)floatValue(); } @@ -830,6 +837,7 @@ public long longValue() { * @jls 5.1.2 Widening Primitive Conversion */ @Override + @ForceInline public float floatValue() { return float16ToFloat(value); } @@ -845,6 +853,7 @@ public float floatValue() { * @jls 5.1.2 Widening Primitive Conversion */ @Override + @ForceInline public double doubleValue() { return (double)floatValue(); } @@ -1196,7 +1205,9 @@ public static Float16 sqrt(Float16 radicand) { // Float16 -> double preserves the exact numerical value. The // conversion of double -> Float16 also benefits from the // 2p+2 property of IEEE 754 arithmetic. - return valueOf(Math.sqrt(radicand.doubleValue())); + short res = Float16Math.sqrt(float16ToRawShortBits(radicand), + (f16) -> float16ToRawShortBits(valueOf(Math.sqrt(shortBitsToFloat16(f16).doubleValue())))); + return shortBitsToFloat16(res); } /** @@ -1401,8 +1412,15 @@ public static Float16 fma(Float16 a, Float16 b, Float16 c) { // product is numerically exact in float before the cast to // double; not necessary to widen to double before the // multiply. - double product = (double)(a.floatValue() * b.floatValue()); - return valueOf(product + c.doubleValue()); + short fa = float16ToRawShortBits(a); + short fb = float16ToRawShortBits(b); + short fc = float16ToRawShortBits(c); + short res = Float16Math.fma(fa, fb, fc, + (f16a, f16b, f16c) -> { + double product = (double)(float16ToFloat(f16a) * float16ToFloat(f16b)); + return float16ToRawShortBits(valueOf(product + float16ToFloat(f16c))); + }); + return shortBitsToFloat16(res); } /** diff --git a/src/jdk.internal.vm.ci/share/classes/jdk/vm/ci/amd64/AMD64.java b/src/jdk.internal.vm.ci/share/classes/jdk/vm/ci/amd64/AMD64.java index eda1eada737b3..32600c4c9b5d7 100644 --- a/src/jdk.internal.vm.ci/share/classes/jdk/vm/ci/amd64/AMD64.java +++ b/src/jdk.internal.vm.ci/share/classes/jdk/vm/ci/amd64/AMD64.java @@ -1,5 +1,5 @@ /* - * Copyright (c) 2009, 2023, Oracle and/or its affiliates. All rights reserved. + * Copyright (c) 2009, 2024, Oracle and/or its affiliates. All rights reserved. * DO NOT ALTER OR REMOVE COPYRIGHT NOTICES OR THIS FILE HEADER. * * This code is free software; you can redistribute it and/or modify it @@ -257,6 +257,7 @@ public enum CPUFeature implements CPUFeatureName { AVX_IFMA, APX_F, SHA512, + AVX512_FP16, } private final EnumSet features; diff --git a/test/hotspot/jtreg/compiler/c2/irTests/ConvF2HFIdealizationTests.java b/test/hotspot/jtreg/compiler/c2/irTests/ConvF2HFIdealizationTests.java new file mode 100644 index 0000000000000..8f68751612c34 --- /dev/null +++ b/test/hotspot/jtreg/compiler/c2/irTests/ConvF2HFIdealizationTests.java @@ -0,0 +1,73 @@ +/* + * Copyright (c) 2024, Arm Limited. All rights reserved. + * DO NOT ALTER OR REMOVE COPYRIGHT NOTICES OR THIS FILE HEADER. + * + * This code is free software; you can redistribute it and/or modify it + * under the terms of the GNU General Public License version 2 only, as + * published by the Free Software Foundation. + * + * This code is distributed in the hope that it will be useful, but WITHOUT + * ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or + * FITNESS FOR A PARTICULAR PURPOSE. See the GNU General Public License + * version 2 for more details (a copy is included in the LICENSE file that + * accompanied this code). + * + * You should have received a copy of the GNU General Public License version + * 2 along with this work; if not, write to the Free Software Foundation, + * Inc., 51 Franklin St, Fifth Floor, Boston, MA 02110-1301 USA. + * + * Please contact Oracle, 500 Oracle Parkway, Redwood Shores, CA 94065 USA + * or visit www.oracle.com if you need additional information or have any + * questions. + */ +package compiler.c2.irTests; + +import compiler.lib.ir_framework.*; +import jdk.incubator.vector.Float16; +import static jdk.incubator.vector.Float16.*; +import jdk.test.lib.Asserts; + +/* + * @test + * @bug 8338061 + * @summary Test that Ideal transformations of ConvF2HF are being performed as expected. + * @modules jdk.incubator.vector + * @library /test/lib / + * @run driver compiler.c2.irTests.ConvF2HFIdealizationTests + */ +public class ConvF2HFIdealizationTests { + private short[] sin; + private short[] sout; + private static final int SIZE = 65504; + public ConvF2HFIdealizationTests() { + sin = new short[SIZE]; + sout = new short[SIZE]; + for (int i = 0; i < SIZE; i++) { + sin[i] = Float.floatToFloat16((float)i); + } + } + public static void main(String[] args) { + TestFramework.runWithFlags("--add-modules=jdk.incubator.vector", "-XX:-UseSuperWord"); + } + + @Test + @IR(counts = {IRNode.REINTERPRET_S2HF, ">=1", IRNode.REINTERPRET_HF2S, ">=1", IRNode.ADD_HF, ">=1" }, + failOn = {IRNode.ADD_F, IRNode.CONV_HF2F, IRNode.CONV_F2HF}, + applyIfCPUFeature = {"avx512_fp16", "true"}) + // Test pattern - ConvHF2F -> AddF -> ConvF2HF is optimized to ReinterpretS2HF -> AddHF -> ReinterpretHF2S + public void test1() { + for (int i = 0; i < SIZE; i++) { + sout[i] = Float.floatToFloat16(Float.float16ToFloat(sin[i]) + Float.float16ToFloat(sin[i])); + } + } + + @Check(test="test1") + public void checkResult() { + for (int i = 0; i < SIZE; i++) { + short expected = Float16.float16ToRawShortBits(Float16.add(Float16.shortBitsToFloat16(sin[i]), Float16.shortBitsToFloat16(sin[i]))); + if (expected != sout[i]) { + throw new RuntimeException("Invalid result: sout[" + i + "] = " + sout[i] + " != " + expected); + } + } + } +} diff --git a/test/hotspot/jtreg/compiler/c2/irTests/MulHFNodeIdealizationTests.java b/test/hotspot/jtreg/compiler/c2/irTests/MulHFNodeIdealizationTests.java new file mode 100644 index 0000000000000..36447f7c53f88 --- /dev/null +++ b/test/hotspot/jtreg/compiler/c2/irTests/MulHFNodeIdealizationTests.java @@ -0,0 +1,70 @@ +/* + * Copyright (c) 2024, Arm Limited. All rights reserved. + * DO NOT ALTER OR REMOVE COPYRIGHT NOTICES OR THIS FILE HEADER. + * + * This code is free software; you can redistribute it and/or modify it + * under the terms of the GNU General Public License version 2 only, as + * published by the Free Software Foundation. + * + * This code is distributed in the hope that it will be useful, but WITHOUT + * ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or + * FITNESS FOR A PARTICULAR PURPOSE. See the GNU General Public License + * version 2 for more details (a copy is included in the LICENSE file that + * accompanied this code). + * + * You should have received a copy of the GNU General Public License version + * 2 along with this work; if not, write to the Free Software Foundation, + * Inc., 51 Franklin St, Fifth Floor, Boston, MA 02110-1301 USA. + * + * Please contact Oracle, 500 Oracle Parkway, Redwood Shores, CA 94065 USA + * or visit www.oracle.com if you need additional information or have any + * questions. + */ +package compiler.c2.irTests; + +import compiler.lib.ir_framework.*; +import jdk.incubator.vector.Float16; +import static jdk.incubator.vector.Float16.*; +import java.util.Random; +import jdk.test.lib.Asserts; + +/* + * @test + * @bug 8336406 + * @summary Test that Ideal transformations of MulHFNode are being performed as expected. + * @modules jdk.incubator.vector + * @library /test/lib / + * @run driver compiler.c2.irTests.MulHFNodeIdealizationTests + */ +public class MulHFNodeIdealizationTests { + + private Float16 src; + private Float16 dst; + private Random rng; + + public static void main(String[] args) { + TestFramework.runWithFlags("--add-modules=jdk.incubator.vector"); + } + + public MulHFNodeIdealizationTests() { + rng = new Random(25); + src = valueOf(rng.nextFloat()); + dst = valueOf(rng.nextFloat()); + } + + @Test + @IR(counts = {IRNode.ADD_HF, "1"}, + applyIfCPUFeature = {"avx512_fp16", "true"}, + failOn = {IRNode.MUL_HF}) + public void test1() { + dst = multiply(src, valueOf(2.0f)); + } + + @Check(test="test1") + public void checkTest1() { + Float16 expected = valueOf(src.floatValue() * 2.0f); + if (float16ToRawShortBits(expected) != float16ToRawShortBits(dst)) { + throw new RuntimeException("Invalid result: dst = " + float16ToRawShortBits(dst) + " != " + float16ToRawShortBits(expected)); + } + } +} diff --git a/test/hotspot/jtreg/compiler/c2/irTests/TestFloat16ScalarOperations.java b/test/hotspot/jtreg/compiler/c2/irTests/TestFloat16ScalarOperations.java new file mode 100644 index 0000000000000..7b8afe0a570d1 --- /dev/null +++ b/test/hotspot/jtreg/compiler/c2/irTests/TestFloat16ScalarOperations.java @@ -0,0 +1,450 @@ +/* + * Copyright (c) 2024, Oracle and/or its affiliates. All rights reserved. + * Copyright (c) 2024, Arm Limited. All rights reserved. + * DO NOT ALTER OR REMOVE COPYRIGHT NOTICES OR THIS FILE HEADER. + * + * This code is free software; you can redistribute it and/or modify it + * under the terms of the GNU General Public License version 2 only, as + * published by the Free Software Foundation. + * + * This code is distributed in the hope that it will be useful, but WITHOUT + * ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or + * FITNESS FOR A PARTICULAR PURPOSE. See the GNU General Public License + * version 2 for more details (a copy is included in the LICENSE file that + * accompanied this code). + * + * You should have received a copy of the GNU General Public License version + * 2 along with this work; if not, write to the Free Software Foundation, + * Inc., 51 Franklin St, Fifth Floor, Boston, MA 02110-1301 USA. + * + * Please contact Oracle, 500 Oracle Parkway, Redwood Shores, CA 94065 USA + * or visit www.oracle.com if you need additional information or have any + * questions. + */ + +/** +* @test +* @bug 8308363 8336406 +* @summary Validate compiler IR for various Float16 scalar operations. +* @modules jdk.incubator.vector +* @requires vm.compiler2.enabled +* @library /test/lib / +* @run driver TestFloat16ScalarOperations +*/ +import compiler.lib.ir_framework.*; +import jdk.incubator.vector.Float16; +import static jdk.incubator.vector.Float16.*; +import java.util.Random; + +public class TestFloat16ScalarOperations { + private static final int count = 1024; + + private short[] src; + private short[] dst; + private short res; + + private static final Float16 POSITIVE_ZERO = valueOf(0.0f); + private static final Float16 NEGATIVE_ZERO = valueOf(-0.0f); + private static final Float16 MIN_NORMAL = valueOf(0x1.0P-14f); + private static final Float16 NEGATIVE_MAX_VALUE = valueOf(-0x1.ffcP+15f); + private static final Float16 LT_MAX_HALF_ULP = Float16.valueOf(14.0f); + private static final Float16 MAX_HALF_ULP = Float16.valueOf(16.0f); + + public static void main(String args[]) { + TestFramework.runWithFlags("--add-modules=jdk.incubator.vector"); + } + + public TestFloat16ScalarOperations() { + src = new short[count]; + dst = new short[count]; + Random r = jdk.test.lib.Utils.getRandomInstance(); + for (int i = 0; i < count; i++) { + src[i] = Float.floatToFloat16(r.nextFloat()); + } + } + + static void assertResult(float actual, float expected, String msg) { + if (actual != expected) { + if (!Float.isNaN(actual) || !Float.isNaN(expected)) { + String error = "TEST : " + msg + ": actual(" + actual + ") != expected(" + expected + ")"; + throw new AssertionError(error); + } + } + } + + static void assertResult(float actual, float expected, String msg, int iter) { + if (actual != expected) { + if (!Float.isNaN(actual) || !Float.isNaN(expected)) { + String error = "TEST (" + iter + "): " + msg + ": actual(" + actual + ") != expected(" + expected + ")"; + throw new AssertionError(error); + } + } + } + + @Test + @IR(counts = {IRNode.ADD_HF, " >0 ", IRNode.REINTERPRET_S2HF, " >0 ", IRNode.REINTERPRET_HF2S, " >0 "}, + applyIfCPUFeature = {"avx512_fp16", "true"}) + public void testAdd1() { + Float16 res = shortBitsToFloat16((short)0); + for (int i = 0; i < count; i++) { + res = Float16.add(res, shortBitsToFloat16(src[i])); + dst[i] = float16ToRawShortBits(res); + } + } + + @Test + @IR(failOn = {IRNode.ADD_HF, IRNode.REINTERPRET_S2HF, IRNode.REINTERPRET_HF2S}, + applyIfCPUFeature = {"avx512_fp16", "true"}) + public void testAdd2() { + Float16 hf0 = shortBitsToFloat16((short)0); + Float16 hf1 = shortBitsToFloat16((short)15360); + Float16 hf2 = shortBitsToFloat16((short)16384); + Float16 hf3 = shortBitsToFloat16((short)16896); + Float16 hf4 = shortBitsToFloat16((short)17408); + res = float16ToRawShortBits(Float16.add(Float16.add(Float16.add(Float16.add(hf0, hf1), hf2), hf3), hf4)); + } + + @Test + @IR(counts = {IRNode.SUB_HF, " >0 ", IRNode.REINTERPRET_S2HF, " >0 ", IRNode.REINTERPRET_HF2S, " >0 "}, + applyIfCPUFeature = {"avx512_fp16", "true"}) + public void testSub() { + Float16 res = shortBitsToFloat16((short)0); + for (int i = 0; i < count; i++) { + res = Float16.subtract(res, shortBitsToFloat16(src[i])); + dst[i] = float16ToRawShortBits(res); + } + } + + @Test + @IR(counts = {IRNode.MUL_HF, " >0 ", IRNode.REINTERPRET_S2HF, " >0 ", IRNode.REINTERPRET_HF2S, " >0 "}, + applyIfCPUFeature = {"avx512_fp16", "true"}) + public void testMul() { + Float16 res = shortBitsToFloat16((short)0); + for (int i = 0; i < count; i++) { + res = Float16.multiply(res, shortBitsToFloat16(src[i])); + dst[i] = float16ToRawShortBits(res); + } + } + + @Test + @IR(counts = {IRNode.DIV_HF, " >0 ", IRNode.REINTERPRET_S2HF, " >0 ", IRNode.REINTERPRET_HF2S, " >0 "}, + applyIfCPUFeature = {"avx512_fp16", "true"}) + public void testDiv() { + Float16 res = shortBitsToFloat16((short)0); + for (int i = 0; i < count; i++) { + res = Float16.divide(res, shortBitsToFloat16(src[i])); + dst[i] = float16ToRawShortBits(res); + } + } + + @Test + @IR(counts = {IRNode.MAX_HF, " >0 ", IRNode.REINTERPRET_S2HF, " >0 ", IRNode.REINTERPRET_HF2S, " >0 "}, + applyIfCPUFeature = {"avx512_fp16", "true"}) + public void testMax() { + Float16 res = shortBitsToFloat16((short)0); + for (int i = 0; i < count; i++) { + res = Float16.max(res, shortBitsToFloat16(src[i])); + dst[i] = float16ToRawShortBits(res); + } + } + + @Test + @IR(counts = {IRNode.MIN_HF, " >0 ", IRNode.REINTERPRET_S2HF, " >0 ", IRNode.REINTERPRET_HF2S, " >0 "}, + applyIfCPUFeature = {"avx512_fp16", "true"}) + public void testMin() { + Float16 res = shortBitsToFloat16((short)0); + for (int i = 0; i < count; i++) { + res = Float16.min(res, shortBitsToFloat16(src[i])); + dst[i] = float16ToRawShortBits(res); + } + } + + @Test + @IR(counts = {IRNode.SQRT_HF, " >0 ", IRNode.REINTERPRET_S2HF, " >0 ", IRNode.REINTERPRET_HF2S, " >0 "}, + applyIfCPUFeature = {"avx512_fp16", "true"}) + public void testSqrt() { + Float16 res = shortBitsToFloat16((short)0); + for (int i = 0; i < count; i++) { + res = Float16.sqrt(shortBitsToFloat16(src[i])); + dst[i] = float16ToRawShortBits(res); + } + } + + @Test + @IR(counts = {IRNode.FMA_HF, " >0 ", IRNode.REINTERPRET_S2HF, " >0 ", IRNode.REINTERPRET_HF2S, " >0 "}, + applyIfCPUFeature = {"avx512_fp16", "true"}) + public void testFma() { + Float16 res = shortBitsToFloat16((short)0); + for (int i = 0; i < count; i++) { + Float16 in = shortBitsToFloat16(src[i]); + res = Float16.fma(in, in, in); + dst[i] = float16ToRawShortBits(res); + } + } + + @Test + @IR(counts = {IRNode.MUL_HF, " >0 ", IRNode.REINTERPRET_S2HF, " >0 ", IRNode.REINTERPRET_HF2S, " >0 "}, + applyIfCPUFeature = {"avx512_fp16", "true"}) + public void testDivByPOT() { + Float16 actual = valueOf(0.0f); + for (int i = 0; i < 50; i++) { + Float16 divisor = valueOf(8.0f); + Float16 dividend = shortBitsToFloat16(src[i]); + actual = add(actual, divide(dividend, divisor)); + divisor = valueOf(16.0f); + actual = add(actual, divide(dividend, divisor)); + divisor = valueOf(32.0f); + actual = add(actual, divide(dividend, divisor)); + } + float expected = 0.0f; + for (int i = 0; i < 50; i++) { + float dividend = Float.float16ToFloat(src[i]); + expected += dividend / 8.0f; + expected += dividend / 16.0f; + expected += dividend / 32.0f; + } + assertResult(Math.round(actual.floatValue()), Math.round(expected), "testDivByPOT"); + } + + @Test + @IR(counts = {IRNode.MUL_HF, " 0 ", IRNode.ADD_HF, " >0 ", IRNode.REINTERPRET_S2HF, " >0 ", IRNode.REINTERPRET_HF2S, " >0 "}, + applyIfCPUFeature = {"avx512_fp16", "true"}) + public void testMulByTWO() { + Float16 res = valueOf(0.0f); + Float16 multiplier = valueOf(2.0f); + for (int i = 0; i < 20; i++) { + Float16 multiplicand = valueOf((float)i); + res = add(res, multiply(multiplicand, multiplier)); + } + assertResult(res.floatValue(), (float)((20 * (20 - 1))/2) * 2.0f, "testMulByTWO"); + } + + + // + // Tests points for various Float16 constant folding transforms. Following figure represents various + // special IEEE 754 binary16 values on a number line + // + // -Inf -0.0 Inf + // -------|-----------------------------|----------------------------|------ + // -MAX_VALUE 0.0 MAX_VALUE + // + // Number whose exponent lie between -14 and 15, both values inclusive, belongs to normal value range. + // IEEE 754 binary16 specification allows graceful degradation of numbers with exponents less than -14 + // into a sub-normal value range i.e. their exponents may extend uptill -24, this is because format + // supports 10 mantissa bits which can be used to represent a number with exponents less than -14. + // + // A number below the sub-normal value range is considered as 0.0. With regards to overflowing + // semantics, a value equal to or greater than MAX_VALUE + half ulp (MAX_VALUE) is considered as + // an Infinite value on both side of axis. + // + // In addition, format specifies special bit representation for +Inf, -Inf and NaN values. + // + // Tests also covers special cases for various operations as per Java SE specification. + // + + + @Test + @IR(counts = {IRNode.ADD_HF, " 0 ", IRNode.REINTERPRET_S2HF, " 0 ", IRNode.REINTERPRET_HF2S, " 0 "}, + applyIfCPUFeature = {"avx512_fp16", "true"}) + public void testAddConstantFolding() { + // If either value is NaN, then the result is NaN. + assertResult(add(Float16.NaN, valueOf(2.0f)).floatValue(), Float.NaN, "testAddConstantFolding"); + assertResult(add(Float16.NaN, Float16.NaN).floatValue(), Float.NaN, "testAddConstantFolding"); + assertResult(add(Float16.NaN, Float16.POSITIVE_INFINITY).floatValue(), Float.NaN, "testAddConstantFolding"); + + // The sum of two infinities of opposite sign is NaN. + assertResult(add(Float16.POSITIVE_INFINITY, Float16.NEGATIVE_INFINITY).floatValue(), Float.NaN, "testAddConstantFolding"); + + // The sum of two infinities of the same sign is the infinity of that sign. + assertResult(add(Float16.POSITIVE_INFINITY, Float16.POSITIVE_INFINITY).floatValue(), Float.POSITIVE_INFINITY, "testAddConstantFolding"); + assertResult(add(Float16.NEGATIVE_INFINITY, Float16.NEGATIVE_INFINITY).floatValue(), Float.NEGATIVE_INFINITY, "testAddConstantFolding"); + + // The sum of an infinity and a finite value is equal to the infinite operand. + assertResult(add(Float16.POSITIVE_INFINITY, valueOf(2.0f)).floatValue(), Float.POSITIVE_INFINITY, "testAddConstantFolding"); + assertResult(add(Float16.NEGATIVE_INFINITY, valueOf(2.0f)).floatValue(), Float.NEGATIVE_INFINITY, "testAddConstantFolding"); + + // The sum of two zeros of opposite sign is positive zero. + assertResult(add(NEGATIVE_ZERO, POSITIVE_ZERO).floatValue(), 0.0f, "testAddConstantFolding"); + + // The sum of two zeros of the same sign is the zero of that sign. + assertResult(add(NEGATIVE_ZERO, NEGATIVE_ZERO).floatValue(), -0.0f, "testAddConstantFolding"); + + // The sum of a zero and a nonzero finite value is equal to the nonzero operand. + assertResult(add(POSITIVE_ZERO, valueOf(2.0f)).floatValue(), 2.0f, "testAddConstantFolding"); + assertResult(add(NEGATIVE_ZERO, valueOf(2.0f)).floatValue(), 2.0f, "testAddConstantFolding"); + + // Number equal to MAX_VALUE when added to half upl for MAX_VALUE results into Inf. + assertResult(add(Float16.MAX_VALUE, MAX_HALF_ULP).floatValue(), Float.POSITIVE_INFINITY, "testAddConstantFolding"); + + // If the magnitude of the sum is too large to represent, we say the operation + // overflows; the result is then an infinity of appropriate sign. + assertResult(add(Float16.MAX_VALUE, Float16.MAX_VALUE).floatValue(), Float.POSITIVE_INFINITY, "testAddConstantFolding"); + + // Number equal to MAX_VALUE when added to half upl for MAX_VALUE results into MAX_VALUE. + assertResult(add(Float16.MAX_VALUE, LT_MAX_HALF_ULP).floatValue(), Float16.MAX_VALUE.floatValue(), "testAddConstantFolding"); + + assertResult(add(valueOf(1.0f), valueOf(2.0f)).floatValue(), 3.0f, "testAddConstantFolding"); + } + + @Test + @IR(counts = {IRNode.SUB_HF, " 0 ", IRNode.REINTERPRET_S2HF, " 0 ", IRNode.REINTERPRET_HF2S, " 0 "}, + applyIfCPUFeature = {"avx512_fp16", "true"}) + public void testSubConstantFolding() { + // If either value is NaN, then the result is NaN. + assertResult(subtract(Float16.NaN, valueOf(2.0f)).floatValue(), Float.NaN, "testAddConstantFolding"); + assertResult(subtract(Float16.NaN, Float16.NaN).floatValue(), Float.NaN, "testAddConstantFolding"); + assertResult(subtract(Float16.NaN, Float16.POSITIVE_INFINITY).floatValue(), Float.NaN, "testAddConstantFolding"); + + // The difference of two infinities of opposite sign is NaN. + assertResult(subtract(Float16.POSITIVE_INFINITY, Float16.NEGATIVE_INFINITY).floatValue(), Float.POSITIVE_INFINITY, "testAddConstantFolding"); + + // The difference of two infinities of the same sign is NaN. + assertResult(subtract(Float16.POSITIVE_INFINITY, Float16.POSITIVE_INFINITY).floatValue(), Float.NaN, "testAddConstantFolding"); + assertResult(subtract(Float16.NEGATIVE_INFINITY, Float16.NEGATIVE_INFINITY).floatValue(), Float.NaN, "testAddConstantFolding"); + + // The difference of an infinity and a finite value is equal to the infinite operand. + assertResult(subtract(Float16.POSITIVE_INFINITY, valueOf(2.0f)).floatValue(), Float.POSITIVE_INFINITY, "testAddConstantFolding"); + assertResult(subtract(Float16.NEGATIVE_INFINITY, valueOf(2.0f)).floatValue(), Float.NEGATIVE_INFINITY, "testAddConstantFolding"); + + // The difference of two zeros of opposite sign is positive zero. + assertResult(subtract(NEGATIVE_ZERO, POSITIVE_ZERO).floatValue(), 0.0f, "testAddConstantFolding"); + + // Number equal to -MAX_VALUE when subtracted by half upl of MAX_VALUE results into -Inf. + assertResult(subtract(NEGATIVE_MAX_VALUE, MAX_HALF_ULP).floatValue(), Float.NEGATIVE_INFINITY, "testAddConstantFolding"); + + // Number equal to -MAX_VALUE when subtracted by a number less than half upl for MAX_VALUE results into -MAX_VALUE. + assertResult(subtract(NEGATIVE_MAX_VALUE, LT_MAX_HALF_ULP).floatValue(), NEGATIVE_MAX_VALUE.floatValue(), "testAddConstantFolding"); + + assertResult(subtract(valueOf(1.0f), valueOf(2.0f)).floatValue(), -1.0f, "testAddConstantFolding"); + } + + @Test + @Warmup(value = 10000) + @IR(counts = {IRNode.MAX_HF, " 0 ", IRNode.REINTERPRET_S2HF, " 0 ", IRNode.REINTERPRET_HF2S, " 0 "}, + applyIfCPUFeature = {"avx512_fp16", "true"}) + public void testMaxConstantFolding() { + // If either value is NaN, then the result is NaN. + assertResult(max(valueOf(2.0f), Float16.NaN).floatValue(), Float.NaN, "testMaxConstantFolding"); + assertResult(max(Float16.NaN, Float16.NaN).floatValue(), Float.NaN, "testMaxConstantFolding"); + + // This operation considers negative zero to be strictly smaller than positive zero + assertResult(max(POSITIVE_ZERO, NEGATIVE_ZERO).floatValue(), 0.0f, "testMaxConstantFolding"); + + // Other cases. + assertResult(max(Float16.POSITIVE_INFINITY, Float16.NEGATIVE_INFINITY).floatValue(), Float.POSITIVE_INFINITY, "testMaxConstantFolding"); + assertResult(max(valueOf(1.0f), valueOf(2.0f)).floatValue(), 2.0f, "testMaxConstantFolding"); + assertResult(max(Float16.MAX_VALUE, Float16.MIN_VALUE).floatValue(), Float16.MAX_VALUE.floatValue(), "testMaxConstantFolding"); + } + + + @Test + @IR(counts = {IRNode.MIN_HF, " 0 ", IRNode.REINTERPRET_S2HF, " 0 ", IRNode.REINTERPRET_HF2S, " 0 "}, + applyIfCPUFeature = {"avx512_fp16", "true"}) + public void testMinConstantFolding() { + // If either value is NaN, then the result is NaN. + assertResult(min(valueOf(2.0f), Float16.NaN).floatValue(), Float.NaN, "testMinConstantFolding"); + assertResult(min(Float16.NaN, Float16.NaN).floatValue(), Float.NaN, "testMinConstantFolding"); + + // This operation considers negative zero to be strictly smaller than positive zero + assertResult(min(POSITIVE_ZERO, NEGATIVE_ZERO).floatValue(), -0.0f, "testMinConstantFolding"); + + // Other cases. + assertResult(min(Float16.POSITIVE_INFINITY, Float16.NEGATIVE_INFINITY).floatValue(), Float.NEGATIVE_INFINITY, "testMinConstantFolding"); + assertResult(min(valueOf(1.0f), valueOf(2.0f)).floatValue(), 1.0f, "testMinConstantFolding"); + assertResult(min(Float16.MAX_VALUE, Float16.MIN_VALUE).floatValue(), Float16.MIN_VALUE.floatValue(), "testMinConstantFolding"); + } + + @Test + @IR(counts = {IRNode.DIV_HF, " 0 ", IRNode.REINTERPRET_S2HF, " 0 ", IRNode.REINTERPRET_HF2S, " 0 "}, + applyIfCPUFeature = {"avx512_fp16", "true"}) + public void testDivConstantFolding() { + // If either value is NaN, then the result is NaN. + assertResult(divide(Float16.NaN, POSITIVE_ZERO).floatValue(), Float.NaN, "testDivConstantFolding"); + assertResult(divide(NEGATIVE_ZERO, Float16.NaN).floatValue(), Float.NaN, "testDivConstantFolding"); + + // Division of an infinity by an infinity results in NaN. + assertResult(divide(Float16.NEGATIVE_INFINITY, Float16.POSITIVE_INFINITY).floatValue(), Float.NaN, "testDivConstantFolding"); + + // Division of an infinity by a finite value results in a signed infinity. Sign of the result is positive if both operands have + // the same sign, and negative if the operands have different signs + assertResult(divide(Float16.NEGATIVE_INFINITY, valueOf(2.0f)).floatValue(), Float.NEGATIVE_INFINITY, "testDivConstantFolding"); + assertResult(divide(Float16.POSITIVE_INFINITY, valueOf(2.0f)).floatValue(), Float.POSITIVE_INFINITY, "testDivConstantFolding"); + + // Division of a finite value by an infinity results in a signed zero. The sign is + // determined by the above rule. + assertResult(divide(valueOf(2.0f), Float16.POSITIVE_INFINITY).floatValue(), 0.0f, "testDivConstantFolding"); + assertResult(divide(valueOf(2.0f), Float16.NEGATIVE_INFINITY).floatValue(), -0.0f, "testDivConstantFolding"); + + // Division of a zero by a zero results in NaN; division of zero by any other finite + // value results in a signed zero. The sign is determined by the rule stated above. + assertResult(divide(POSITIVE_ZERO, NEGATIVE_ZERO).floatValue(), Float.NaN, "testDivConstantFolding"); + assertResult(divide(POSITIVE_ZERO, Float16.MAX_VALUE).floatValue(), 0.0f, "testDivConstantFolding"); + assertResult(divide(NEGATIVE_ZERO, Float16.MAX_VALUE).floatValue(), -0.0f, "testDivConstantFolding"); + + // Division of a nonzero finite value by a zero results in a signed infinity. The sign + // is determined by the rule stated above + assertResult(divide(valueOf(2.0f), NEGATIVE_ZERO).floatValue(), Float.NEGATIVE_INFINITY, "testDivConstantFolding"); + assertResult(divide(valueOf(2.0f), POSITIVE_ZERO).floatValue(), Float.POSITIVE_INFINITY, "testDivConstantFolding"); + + // If the magnitude of the quotient is too large to represent, we say the operation + // overflows; the result is then an infinity of appropriate sign. + assertResult(divide(Float16.MAX_VALUE, Float16.MIN_NORMAL).floatValue(), Float.POSITIVE_INFINITY, "testDivConstantFolding"); + assertResult(divide(Float16.MAX_VALUE, valueOf(-0x1.0P-14f)).floatValue(), Float.NEGATIVE_INFINITY, "testDivConstantFolding"); + + assertResult(divide(valueOf(2.0f), valueOf(2.0f)).floatValue(), 1.0f, "testDivConstantFolding"); + } + + @Test + @IR(counts = {IRNode.MUL_HF, " 0 ", IRNode.REINTERPRET_S2HF, " 0 ", IRNode.REINTERPRET_HF2S, " 0 "}, + applyIfCPUFeature = {"avx512_fp16", "true"}) + public void testMulConstantFolding() { + // If any operand is NaN, the result is NaN. + assertResult(multiply(Float16.NaN, valueOf(4.0f)).floatValue(), Float.NaN, "testMulConstantFolding"); + assertResult(multiply(Float16.NaN, Float16.NaN).floatValue(), Float.NaN, "testMulConstantFolding"); + + // Multiplication of an infinity by a zero results in NaN. + assertResult(multiply(Float16.POSITIVE_INFINITY, POSITIVE_ZERO).floatValue(), Float.NaN, "testMulConstantFolding"); + + // Multiplication of an infinity by a finite value results in a signed infinity. + assertResult(multiply(Float16.POSITIVE_INFINITY, valueOf(2.0f)).floatValue(), Float.POSITIVE_INFINITY, "testMulConstantFolding"); + assertResult(multiply(Float16.NEGATIVE_INFINITY, valueOf(2.0f)).floatValue(), Float.NEGATIVE_INFINITY, "testMulConstantFolding"); + + // If the magnitude of the product is too large to represent, we say the operation + // overflows; the result is then an infinity of appropriate sign + assertResult(multiply(Float16.MAX_VALUE, Float16.MAX_VALUE).floatValue(), Float.POSITIVE_INFINITY, "testMulConstantFolding"); + assertResult(multiply(NEGATIVE_MAX_VALUE, Float16.MAX_VALUE).floatValue(), Float.NEGATIVE_INFINITY, "testMulConstantFolding"); + + assertResult(multiply(multiply(multiply(valueOf(1.0f), valueOf(2.0f)), valueOf(3.0f)), valueOf(4.0f)).floatValue(), 1.0f * 2.0f * 3.0f * 4.0f, "testMulConstantFolding"); + } + + @Test + @IR(counts = {IRNode.FMA_HF, " 0 ", IRNode.REINTERPRET_S2HF, " 0 ", IRNode.REINTERPRET_HF2S, " 0 "}, + applyIfCPUFeature = {"avx512_fp16", "true"}) + public void testFMAConstantFolding() { + // If any argument is NaN, the result is NaN. + assertResult(fma(Float16.NaN, valueOf(2.0f), valueOf(3.0f)).floatValue(), Float.NaN, "testFMAConstantFolding"); + assertResult(fma(valueOf(2.0f), Float16.NaN, valueOf(3.0f)).floatValue(), Float.NaN, "testFMAConstantFolding"); + assertResult(fma(valueOf(2.0f), valueOf(3.0f), Float16.NaN).floatValue(), Float.NaN, "testFMAConstantFolding"); + + // If one of the first two arguments is infinite and the other is zero, the result is NaN. + assertResult(fma(Float16.POSITIVE_INFINITY, POSITIVE_ZERO, valueOf(2.0f)).floatValue(), Float.NaN, "testFMAConstantFolding"); + assertResult(fma(Float16.POSITIVE_INFINITY, NEGATIVE_ZERO, valueOf(2.0f)).floatValue(), Float.NaN, "testFMAConstantFolding"); + assertResult(fma(NEGATIVE_ZERO, Float16.POSITIVE_INFINITY, valueOf(2.0f)).floatValue(), Float.NaN, "testFMAConstantFolding"); + assertResult(fma(POSITIVE_ZERO, Float16.POSITIVE_INFINITY, valueOf(2.0f)).floatValue(), Float.NaN, "testFMAConstantFolding"); + + // If the exact product of the first two arguments is infinite (in other words, at least one of the arguments is infinite + // and the other is neither zero nor NaN) and the third argument is an infinity of the opposite sign, the result is NaN. + assertResult(fma(valueOf(2.0f), Float16.POSITIVE_INFINITY, Float16.NEGATIVE_INFINITY).floatValue(), Float.NaN, "testFMAConstantFolding"); + assertResult(fma(valueOf(2.0f), Float16.NEGATIVE_INFINITY, Float16.POSITIVE_INFINITY).floatValue(), Float.NaN, "testFMAConstantFolding"); + assertResult(fma(Float16.POSITIVE_INFINITY, valueOf(2.0f), Float16.NEGATIVE_INFINITY).floatValue(), Float.NaN, "testFMAConstantFolding"); + assertResult(fma(Float16.NEGATIVE_INFINITY, valueOf(2.0f), Float16.POSITIVE_INFINITY).floatValue(), Float.NaN, "testFMAConstantFolding"); + + // Signed bits. + assertResult(fma(NEGATIVE_ZERO, POSITIVE_ZERO, POSITIVE_ZERO).floatValue(), 0.0f, "testFMAConstantFolding"); + assertResult(fma(NEGATIVE_ZERO, POSITIVE_ZERO, NEGATIVE_ZERO).floatValue(), -0.0f, "testFMAConstantFolding"); + + assertResult(fma(Float16.POSITIVE_INFINITY, valueOf(2.0f), valueOf(3.0f)).floatValue(), Float.POSITIVE_INFINITY, "testFMAConstantFolding"); + assertResult(fma(Float16.NEGATIVE_INFINITY, valueOf(2.0f), valueOf(3.0f)).floatValue(), Float.NEGATIVE_INFINITY, "testFMAConstantFolding"); + assertResult(fma(valueOf(1.0f), valueOf(2.0f), valueOf(3.0f)).floatValue(), 1.0f * 2.0f + 3.0f, "testFMAConstantFolding"); + } +} diff --git a/test/hotspot/jtreg/compiler/lib/ir_framework/IRNode.java b/test/hotspot/jtreg/compiler/lib/ir_framework/IRNode.java index 35e10c7ae9f9b..37f49a2ce6bde 100644 --- a/test/hotspot/jtreg/compiler/lib/ir_framework/IRNode.java +++ b/test/hotspot/jtreg/compiler/lib/ir_framework/IRNode.java @@ -209,6 +209,11 @@ public class IRNode { beforeMatchingNameRegex(ADD, "Add(I|L|F|D|P)"); } + public static final String ADD_F = PREFIX + "ADD_F" + POSTFIX; + static { + beforeMatchingNameRegex(ADD_F, "AddF"); + } + public static final String ADD_I = PREFIX + "ADD_I" + POSTFIX; static { beforeMatchingNameRegex(ADD_I, "AddI"); @@ -219,6 +224,11 @@ public class IRNode { beforeMatchingNameRegex(ADD_L, "AddL"); } + public static final String ADD_HF = PREFIX + "ADD_HF" + POSTFIX; + static { + beforeMatchingNameRegex(ADD_HF, "AddHF"); + } + public static final String ADD_P = PREFIX + "ADD_P" + POSTFIX; static { beforeMatchingNameRegex(ADD_P, "AddP"); @@ -239,6 +249,11 @@ public class IRNode { vectorNode(ADD_VF, "AddVF", TYPE_FLOAT); } + public static final String ADD_VHF = PREFIX + "ADD_VHF" + POSTFIX; + static { + beforeMatchingNameRegex(ADD_VHF, "AddVHF"); + } + public static final String ADD_VB = VECTOR_PREFIX + "ADD_VB" + POSTFIX; static { vectorNode(ADD_VB, "AddVB", TYPE_BYTE); @@ -528,6 +543,11 @@ public class IRNode { beforeMatchingNameRegex(CONV, "Conv"); } + public static final String CONV_F2HF = PREFIX + "CONV_F2HF" + POSTFIX; + static { + beforeMatchingNameRegex(CONV_F2HF, "ConvF2HF"); + } + public static final String CONV_I2L = PREFIX + "CONV_I2L" + POSTFIX; static { beforeMatchingNameRegex(CONV_I2L, "ConvI2L"); @@ -538,6 +558,11 @@ public class IRNode { beforeMatchingNameRegex(CONV_L2I, "ConvL2I"); } + public static final String CONV_HF2F = PREFIX + "CONV_HF2F" + POSTFIX; + static { + beforeMatchingNameRegex(CONV_HF2F, "ConvHF2F"); + } + public static final String CON_I = PREFIX + "CON_I" + POSTFIX; static { beforeMatchingNameRegex(CON_I, "ConI"); @@ -637,6 +662,16 @@ public class IRNode { vectorNode(FMA_VD, "FmaVD", TYPE_DOUBLE); } + public static final String FMA_VHF = PREFIX + "FMA_VHF" + POSTFIX; + static { + beforeMatchingNameRegex(FMA_VHF, "FmaVHF"); + } + + public static final String FMA_HF = PREFIX + "FMA_HF" + POSTFIX; + static { + beforeMatchingNameRegex(FMA_HF, "FmaHF"); + } + public static final String G1_COMPARE_AND_EXCHANGE_N_WITH_BARRIER_FLAG = COMPOSITE_PREFIX + "G1_COMPARE_AND_EXCHANGE_N_WITH_BARRIER_FLAG" + POSTFIX; static { String regex = START + "g1CompareAndExchangeN\\S*" + MID + "barrier\\(\\s*" + IS_REPLACED + "\\s*\\)" + END; @@ -1064,6 +1099,11 @@ public class IRNode { vectorNode(MAX_VD, "MaxV", TYPE_DOUBLE); } + public static final String MAX_VHF = PREFIX + "MAX_VHF" + POSTFIX; + static { + beforeMatchingNameRegex(MAX_VHF, "MaxVHF"); + } + public static final String MAX_VL = VECTOR_PREFIX + "MAX_VL" + POSTFIX; static { vectorNode(MAX_VL, "MaxV", TYPE_LONG); @@ -1139,6 +1179,16 @@ public class IRNode { beforeMatchingNameRegex(MIN_L, "MinL"); } + public static final String MIN_HF = PREFIX + "MIN_HF" + POSTFIX; + static { + beforeMatchingNameRegex(MIN_HF, "MinHF"); + } + + public static final String MAX_HF = PREFIX + "MAX_HF" + POSTFIX; + static { + beforeMatchingNameRegex(MAX_HF, "MaxHF"); + } + public static final String MIN_VI = VECTOR_PREFIX + "MIN_VI" + POSTFIX; static { vectorNode(MIN_VI, "MinV", TYPE_INT); @@ -1154,6 +1204,11 @@ public class IRNode { vectorNode(MIN_VD, "MinV", TYPE_DOUBLE); } + public static final String MIN_VHF = PREFIX + "MIN_VHF" + POSTFIX; + static { + beforeMatchingNameRegex(MIN_VHF, "MinVHF"); + } + public static final String MIN_VL = VECTOR_PREFIX + "MIN_VL" + POSTFIX; static { vectorNode(MIN_VL, "MinV", TYPE_LONG); @@ -1220,6 +1275,11 @@ public class IRNode { beforeMatchingNameRegex(MUL_F, "MulF"); } + public static final String MUL_HF = PREFIX + "MUL_HF" + POSTFIX; + static { + beforeMatchingNameRegex(MUL_HF, "MulHF"); + } + public static final String MUL_I = PREFIX + "MUL_I" + POSTFIX; static { beforeMatchingNameRegex(MUL_I, "MulI"); @@ -1245,6 +1305,11 @@ public class IRNode { vectorNode(MUL_VF, "MulVF", TYPE_FLOAT); } + public static final String MUL_VHF = PREFIX + "MUL_VHF" + POSTFIX; + static { + beforeMatchingNameRegex(MUL_VHF, "MulVHF"); + } + public static final String MUL_VD = VECTOR_PREFIX + "MUL_VD" + POSTFIX; static { vectorNode(MUL_VD, "MulVD", TYPE_DOUBLE); @@ -1420,6 +1485,16 @@ public class IRNode { trapNodes(RANGE_CHECK_TRAP, "range_check"); } + public static final String REINTERPRET_S2HF = PREFIX + "REINTERPRET_S2HF" + POSTFIX; + static { + beforeMatchingNameRegex(REINTERPRET_S2HF, "ReinterpretS2HF"); + } + + public static final String REINTERPRET_HF2S = PREFIX + "REINTERPRET_HF2S" + POSTFIX; + static { + beforeMatchingNameRegex(REINTERPRET_HF2S, "ReinterpretHF2S"); + } + public static final String REPLICATE_B = VECTOR_PREFIX + "REPLICATE_B" + POSTFIX; static { vectorNode(REPLICATE_B, "Replicate", TYPE_BYTE); @@ -1586,6 +1661,16 @@ public class IRNode { vectorNode(SIGNUM_VF, "SignumVF", TYPE_FLOAT); } + public static final String SQRT_HF = PREFIX + "SQRT_HF" + POSTFIX; + static { + beforeMatchingNameRegex(SQRT_HF, "SqrtHF"); + } + + public static final String SQRT_F = PREFIX + "SQRT_F" + POSTFIX; + static { + beforeMatchingNameRegex(SQRT_F, "SqrtF"); + } + public static final String SQRT_VF = VECTOR_PREFIX + "SQRT_VF" + POSTFIX; static { vectorNode(SQRT_VF, "SqrtVF", TYPE_FLOAT); @@ -1596,6 +1681,11 @@ public class IRNode { vectorNode(SQRT_VD, "SqrtVD", TYPE_DOUBLE); } + public static final String SQRT_VHF = PREFIX + "SQRT_VHF" + POSTFIX; + static { + beforeMatchingNameRegex(SQRT_VHF, "SqrtVHF"); + } + public static final String STORE = PREFIX + "STORE" + POSTFIX; static { beforeMatchingNameRegex(STORE, "Store(B|C|S|I|L|F|D|P|N)"); @@ -1727,6 +1817,11 @@ public class IRNode { beforeMatchingNameRegex(SUB_F, "SubF"); } + public static final String SUB_HF = PREFIX + "SUB_HF" + POSTFIX; + static { + beforeMatchingNameRegex(SUB_HF, "SubHF"); + } + public static final String SUB_I = PREFIX + "SUB_I" + POSTFIX; static { beforeMatchingNameRegex(SUB_I, "SubI"); @@ -1762,6 +1857,11 @@ public class IRNode { vectorNode(SUB_VF, "SubVF", TYPE_FLOAT); } + public static final String SUB_VHF = PREFIX + "SUB_VHF" + POSTFIX; + static { + beforeMatchingNameRegex(SUB_VHF, "SubVHF"); + } + public static final String SUB_VD = VECTOR_PREFIX + "SUB_VD" + POSTFIX; static { vectorNode(SUB_VD, "SubVD", TYPE_DOUBLE); @@ -1777,6 +1877,16 @@ public class IRNode { trapNodes(TRAP, "reason"); } + public static final String DIV_HF = PREFIX + "DIV_HF" + POSTFIX; + static { + beforeMatchingNameRegex(DIV_HF, "DivHF"); + } + + public static final String DIV_VHF = PREFIX + "DIV_VHF" + POSTFIX; + static { + beforeMatchingNameRegex(DIV_VHF, "DivVHF"); + } + public static final String UDIV_I = PREFIX + "UDIV_I" + POSTFIX; static { beforeMatchingNameRegex(UDIV_I, "UDivI"); diff --git a/test/hotspot/jtreg/compiler/lib/ir_framework/test/IREncodingPrinter.java b/test/hotspot/jtreg/compiler/lib/ir_framework/test/IREncodingPrinter.java index 9f58f7097027c..72e6ce3129a6c 100644 --- a/test/hotspot/jtreg/compiler/lib/ir_framework/test/IREncodingPrinter.java +++ b/test/hotspot/jtreg/compiler/lib/ir_framework/test/IREncodingPrinter.java @@ -102,11 +102,14 @@ public class IREncodingPrinter { "avx512dq", "avx512vl", "avx512f", + "avx512_fp16", "avx512_vnni", // AArch64 "sha3", "asimd", "sve", + "fphp", + "asimdhp", // Riscv64 "rvv", "zvbb", diff --git a/test/hotspot/jtreg/compiler/vectorization/TestFloat16VectorConvChain.java b/test/hotspot/jtreg/compiler/vectorization/TestFloat16VectorConvChain.java index 20401a989383a..174c34b953335 100644 --- a/test/hotspot/jtreg/compiler/vectorization/TestFloat16VectorConvChain.java +++ b/test/hotspot/jtreg/compiler/vectorization/TestFloat16VectorConvChain.java @@ -41,7 +41,12 @@ public class TestFloat16VectorConvChain { @Test - @IR(applyIfCPUFeatureOr = {"f16c", "true", "avx512vl", "true", "zvfh", "true"}, counts = {IRNode.VECTOR_CAST_HF2F, IRNode.VECTOR_SIZE_ANY, ">= 1", IRNode.VECTOR_CAST_F2HF, IRNode.VECTOR_SIZE_ANY, " >= 1"}) + @IR(applyIfCPUFeatureAnd = {"avx512_fp16", "false", "avx512vl", "true"}, + counts = {IRNode.VECTOR_CAST_HF2F, IRNode.VECTOR_SIZE_ANY, ">= 1", IRNode.VECTOR_CAST_F2HF, IRNode.VECTOR_SIZE_ANY, " >= 1"}) + @IR(applyIfCPUFeatureAnd = {"avx512_fp16", "false", "f16c", "true"}, + counts = {IRNode.VECTOR_CAST_HF2F, IRNode.VECTOR_SIZE_ANY, ">= 1", IRNode.VECTOR_CAST_F2HF, IRNode.VECTOR_SIZE_ANY, " >= 1"}) + @IR(applyIfCPUFeatureAnd = {"avx512_fp16", "false", "zvfh", "true"}, + counts = {IRNode.VECTOR_CAST_HF2F, IRNode.VECTOR_SIZE_ANY, ">= 1", IRNode.VECTOR_CAST_F2HF, IRNode.VECTOR_SIZE_ANY, " >= 1"}) public static void test(short [] res, short [] src1, short [] src2) { for (int i = 0; i < res.length; i++) { res[i] = (short)Float.float16ToFloat(Float.floatToFloat16(Float.float16ToFloat(src1[i]) + Float.float16ToFloat(src2[i]))); diff --git a/test/hotspot/jtreg/compiler/vectorization/TestFloat16VectorOperations.java b/test/hotspot/jtreg/compiler/vectorization/TestFloat16VectorOperations.java new file mode 100644 index 0000000000000..725fce95c9dec --- /dev/null +++ b/test/hotspot/jtreg/compiler/vectorization/TestFloat16VectorOperations.java @@ -0,0 +1,288 @@ +/* + * Copyright (c) 2024, Oracle and/or its affiliates. All rights reserved. + * Copyright (c) 2024, Arm Limited. All rights reserved. + * DO NOT ALTER OR REMOVE COPYRIGHT NOTICES OR THIS FILE HEADER. + * + * This code is free software; you can redistribute it and/or modify it + * under the terms of the GNU General Public License version 2 only, as + * published by the Free Software Foundation. + * + * This code is distributed in the hope that it will be useful, but WITHOUT + * ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or + * FITNESS FOR A PARTICULAR PURPOSE. See the GNU General Public License + * version 2 for more details (a copy is included in the LICENSE file that + * accompanied this code). + * + * You should have received a copy of the GNU General Public License version + * 2 along with this work; if not, write to the Free Software Foundation, + * Inc., 51 Franklin St, Fifth Floor, Boston, MA 02110-1301 USA. + * + * Please contact Oracle, 500 Oracle Parkway, Redwood Shores, CA 94065 USA + * or visit www.oracle.com if you need additional information or have any + * questions. + */ + +/** +* @test +* @bug 8342103 +* @summary Auto-vectorization support for various Float16 operations +* @requires vm.compiler2.enabled +* @modules jdk.incubator.vector +* @library /test/lib / +* @compile TestFloat16VectorOperations.java +* @run driver compiler.vectorization.TestFloat16VectorOperations +*/ + +package compiler.vectorization; +import compiler.lib.ir_framework.*; +import java.util.Random; +import static jdk.incubator.vector.Float16.*; +import static java.lang.Float.*; + +public class TestFloat16VectorOperations { + private short[] input1; + private short[] input2; + private short[] input3; + private short[] output; + private static short SCALAR_FP16 = (short)0x7777; + private static final int LEN = 2048; + private Random rng; + + public static void main(String args[]) { + TestFramework.runWithFlags("-XX:-TieredCompilation", "-Xbatch","--add-modules=jdk.incubator.vector"); + } + + public TestFloat16VectorOperations() { + input1 = new short[LEN]; + input2 = new short[LEN]; + input3 = new short[LEN]; + output = new short[LEN]; + rng = new Random(42); + for (int i = 0; i < LEN; ++i) { + input1[i] = floatToFloat16(rng.nextFloat()); + input2[i] = floatToFloat16(rng.nextFloat()); + input3[i] = floatToFloat16(rng.nextFloat()); + } + } + + @Test + @Warmup(10000) + @IR(counts = {IRNode.ADD_VHF, ">= 1"}, + applyIfCPUFeature = {"avx512_fp16", "true"}) + public void vectorAddFloat16() { + for (int i = 0; i < LEN; ++i) { + output[i] = float16ToRawShortBits(add(shortBitsToFloat16(input1[i]), shortBitsToFloat16(input2[i]))); + } + } + + @Check(test="vectorAddFloat16") + public void checkResultAdd() { + for (int i = 0; i < LEN; ++i) { + short expected = floatToFloat16(float16ToFloat(input1[i]) + float16ToFloat(input2[i])); + if (output[i] != expected) { + throw new RuntimeException("Invalid result: output[" + i + "] = " + output[i] + " != " + expected); + } + } + } + + @Test + @Warmup(10000) + @IR(counts = {IRNode.SUB_VHF, ">= 1"}, + applyIfCPUFeature = {"avx512_fp16", "true"}) + public void vectorSubFloat16() { + for (int i = 0; i < LEN; ++i) { + output[i] = float16ToRawShortBits(subtract(shortBitsToFloat16(input1[i]), shortBitsToFloat16(input2[i]))); + } + } + + @Check(test="vectorSubFloat16") + public void checkResultSub() { + for (int i = 0; i < LEN; ++i) { + short expected = floatToFloat16(float16ToFloat(input1[i]) - float16ToFloat(input2[i])); + if (output[i] != expected) { + throw new RuntimeException("Invalid result: output[" + i + "] = " + output[i] + " != " + expected); + } + } + } + + @Test + @Warmup(10000) + @IR(counts = {IRNode.MUL_VHF, ">= 1"}, + applyIfCPUFeature = {"avx512_fp16", "true"}) + public void vectorMulFloat16() { + for (int i = 0; i < LEN; ++i) { + output[i] = float16ToRawShortBits(multiply(shortBitsToFloat16(input1[i]), shortBitsToFloat16(input2[i]))); + } + } + + @Check(test="vectorMulFloat16") + public void checkResultMul() { + for (int i = 0; i < LEN; ++i) { + short expected = floatToFloat16(float16ToFloat(input1[i]) * float16ToFloat(input2[i])); + if (output[i] != expected) { + throw new RuntimeException("Invalid result: output[" + i + "] = " + output[i] + " != " + expected); + } + } + } + + @Test + @Warmup(10000) + @IR(counts = {IRNode.DIV_VHF, ">= 1"}, + applyIfCPUFeature = {"avx512_fp16", "true"}) + public void vectorDivFloat16() { + for (int i = 0; i < LEN; ++i) { + output[i] = float16ToRawShortBits(divide(shortBitsToFloat16(input1[i]), shortBitsToFloat16(input2[i]))); + } + } + + @Check(test="vectorDivFloat16") + public void checkResultDiv() { + for (int i = 0; i < LEN; ++i) { + short expected = floatToFloat16(float16ToFloat(input1[i]) / float16ToFloat(input2[i])); + if (output[i] != expected) { + throw new RuntimeException("Invalid result: output[" + i + "] = " + output[i] + " != " + expected); + } + } + } + + @Test + @Warmup(10000) + @IR(counts = {IRNode.MIN_VHF, ">= 1"}, + applyIfCPUFeature = {"avx512_fp16", "true"}) + public void vectorMinFloat16() { + for (int i = 0; i < LEN; ++i) { + output[i] = float16ToRawShortBits(min(shortBitsToFloat16(input1[i]), shortBitsToFloat16(input2[i]))); + } + } + + @Check(test="vectorMinFloat16") + public void checkResultMin() { + for (int i = 0; i < LEN; ++i) { + short expected = floatToFloat16(Math.min(float16ToFloat(input1[i]), float16ToFloat(input2[i]))); + if (output[i] != expected) { + throw new RuntimeException("Invalid result: output[" + i + "] = " + output[i] + " != " + expected); + } + } + } + + @Test + @Warmup(10000) + @IR(counts = {IRNode.MAX_VHF, ">= 1"}, + applyIfCPUFeature = {"avx512_fp16", "true"}) + public void vectorMaxFloat16() { + for (int i = 0; i < LEN; ++i) { + output[i] = float16ToRawShortBits(max(shortBitsToFloat16(input1[i]), shortBitsToFloat16(input2[i]))); + } + } + + @Check(test="vectorMaxFloat16") + public void checkResultMax() { + for (int i = 0; i < LEN; ++i) { + short expected = floatToFloat16(Math.max(float16ToFloat(input1[i]), float16ToFloat(input2[i]))); + if (output[i] != expected) { + throw new RuntimeException("Invalid result: output[" + i + "] = " + output[i] + " != " + expected); + } + } + } + + @Test + @Warmup(10000) + @IR(counts = {IRNode.SQRT_VHF, ">= 1"}, + applyIfCPUFeature = {"avx512_fp16", "true"}) + public void vectorSqrtFloat16() { + for (int i = 0; i < LEN; ++i) { + output[i] = float16ToRawShortBits(sqrt(shortBitsToFloat16(input1[i]))); + } + } + + @Check(test="vectorSqrtFloat16") + public void checkResultSqrt() { + for (int i = 0; i < LEN; ++i) { + short expected = float16ToRawShortBits(valueOf(Math.sqrt(float16ToFloat(input1[i])))); + if (output[i] != expected) { + throw new RuntimeException("Invalid result: output[" + i + "] = " + output[i] + " != " + expected); + } + } + } + + @Test + @Warmup(10000) + @IR(counts = {IRNode.FMA_VHF, ">= 1"}, + applyIfCPUFeature = {"avx512_fp16", "true"}) + public void vectorFmaFloat16() { + for (int i = 0; i < LEN; ++i) { + output[i] = float16ToRawShortBits(fma(shortBitsToFloat16(input1[i]), shortBitsToFloat16(input2[i]), shortBitsToFloat16(input3[i]))); + } + } + + @Check(test="vectorFmaFloat16") + public void checkResultFma() { + for (int i = 0; i < LEN; ++i) { + short expected = floatToFloat16(Math.fma(float16ToFloat(input1[i]), float16ToFloat(input2[i]), float16ToFloat(input3[i]))); + if (output[i] != expected) { + throw new RuntimeException("Invalid result: output[" + i + "] = " + output[i] + " != " + expected); + } + } + } + + @Test + @Warmup(10000) + @IR(counts = {IRNode.FMA_VHF, " >= 1"}, + applyIfCPUFeature = {"avx512_fp16", "true"}) + public void vectorFmaFloat16ScalarMixedConstants() { + for (int i = 0; i < LEN; ++i) { + output[i] = float16ToRawShortBits(fma(shortBitsToFloat16(input1[i]), shortBitsToFloat16(SCALAR_FP16), shortBitsToFloat16(floatToFloat16(3.0f)))); + } + } + + @Check(test="vectorFmaFloat16ScalarMixedConstants") + public void checkResultFmaScalarMixedConstants() { + for (int i = 0; i < LEN; ++i) { + short expected = floatToFloat16(Math.fma(float16ToFloat(input1[i]), float16ToFloat(SCALAR_FP16), 3.0f)); + if (output[i] != expected) { + throw new RuntimeException("Invalid result: output[" + i + "] = " + output[i] + " != " + expected); + } + } + } + + + @Test + @Warmup(10000) + @IR(counts = {IRNode.FMA_VHF, " >= 1"}, + applyIfCPUFeature = {"avx512_fp16", "true"}) + public void vectorFmaFloat16MixedConstants() { + for (int i = 0; i < LEN; ++i) { + output[i] = float16ToRawShortBits(fma(shortBitsToFloat16(input1[i]), shortBitsToFloat16(input2[i]), shortBitsToFloat16(floatToFloat16(3.0f)))); + } + } + + @Check(test="vectorFmaFloat16MixedConstants") + public void checkResultFmaMixedConstants() { + for (int i = 0; i < LEN; ++i) { + short expected = floatToFloat16(Math.fma(float16ToFloat(input1[i]), float16ToFloat(input2[i]), 3.0f)); + if (output[i] != expected) { + throw new RuntimeException("Invalid result: output[" + i + "] = " + output[i] + " != " + expected); + } + } + } + + @Test + @Warmup(10000) + @IR(counts = {IRNode.FMA_VHF, " 0 "}, + applyIfCPUFeature = {"avx512_fp16", "true"}) + public void vectorFmaFloat16AllConstants() { + for (int i = 0; i < LEN; ++i) { + output[i] = float16ToRawShortBits(fma(shortBitsToFloat16(floatToFloat16(1.0f)), shortBitsToFloat16(floatToFloat16(2.0f)), shortBitsToFloat16(floatToFloat16(3.0f)))); + } + } + + @Check(test="vectorFmaFloat16AllConstants") + public void checkResultFmaAllConstants() { + for (int i = 0; i < LEN; ++i) { + short expected = floatToFloat16(Math.fma(1.0f, 2.0f, 3.0f)); + if (output[i] != expected) { + throw new RuntimeException("Invalid result: output[" + i + "] = " + output[i] + " != " + expected); + } + } + } +} diff --git a/test/jdk/jdk/incubator/vector/ScalarFloat16OperationsTest.java b/test/jdk/jdk/incubator/vector/ScalarFloat16OperationsTest.java new file mode 100644 index 0000000000000..1fc37232cbbd0 --- /dev/null +++ b/test/jdk/jdk/incubator/vector/ScalarFloat16OperationsTest.java @@ -0,0 +1,347 @@ +/* + * Copyright (c) 2024, Oracle and/or its affiliates. All rights reserved. + * DO NOT ALTER OR REMOVE COPYRIGHT NOTICES OR THIS FILE HEADER. + * + * This code is free software; you can redistribute it and/or modify it + * under the terms of the GNU General Public License version 2 only, as + * published by the Free Software Foundation. + * + * This code is distributed in the hope that it will be useful, but WITHOUT + * ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or + * FITNESS FOR A PARTICULAR PURPOSE. See the GNU General Public License + * version 2 for more details (a copy is included in the LICENSE file that + * accompanied this code). + * + * You should have received a copy of the GNU General Public License version + * 2 along with this work; if not, write to the Free Software Foundation, + * Inc., 51 Franklin St, Fifth Floor, Boston, MA 02110-1301 USA. + * + * Please contact Oracle, 500 Oracle Parkway, Redwood Shores, CA 94065 USA + * or visit www.oracle.com if you need additional information or have any + * questions. + */ + +/* + * @test + * @bug 8342103 + * @summary C2 compiler support for Float16 type and associated operations + * @modules jdk.incubator.vector + * @library /test/lib + * @compile ScalarFloat16OperationsTest.java + * @run testng/othervm/timeout=300 -ea -esa -Xbatch -XX:-TieredCompilation -XX:-UseSuperWord ScalarFloat16OperationsTest + * @run testng/othervm/timeout=300 -ea -esa -Xbatch -XX:-TieredCompilation -XX:+UseSuperWord ScalarFloat16OperationsTest + */ + +import org.testng.Assert; +import org.testng.annotations.DataProvider; +import org.testng.annotations.Test; + +import java.util.Random; +import java.util.stream.IntStream; +import jdk.incubator.vector.Float16; +import static jdk.incubator.vector.Float16.*; + +public class ScalarFloat16OperationsTest { + static final int SIZE = 65504; + static Random r = jdk.test.lib.Utils.getRandomInstance(); + static final int INVOC_COUNT = Integer.getInteger("jdk.incubator.vector.test.loop-iterations", 100); + + @DataProvider + public static Object[][] unaryOpProvider() { + Float16 [] input = new Float16[SIZE]; + Float16 [] special_input = { + Float16.MAX_VALUE, Float16.MIN_VALUE, Float16.MIN_NORMAL, Float16.POSITIVE_INFINITY, + Float16.NEGATIVE_INFINITY, Float16.valueOf(0.0f), Float16.valueOf(-0.0f), Float16.NaN + }; + + // Input array covers entire Float16 value range + IntStream.range(0, input.length).forEach(i -> {input[i] = valueOf(i);}); + + return new Object[][] { + {input}, + {special_input} + }; + } + + @DataProvider + public static Object[][] binaryOpProvider() { + Float16 [] input1 = new Float16[SIZE]; + Float16 [] input2 = new Float16[SIZE]; + Float16 [] special_input = { + Float16.MAX_VALUE, Float16.MIN_VALUE, Float16.MIN_NORMAL, Float16.POSITIVE_INFINITY, + Float16.NEGATIVE_INFINITY, Float16.valueOf(0.0f), Float16.valueOf(-0.0f), Float16.NaN + }; + + // Input arrays covers entire Float16 value range interspersed with special values. + IntStream.range(0, input1.length).forEach(i -> {input1[i] = valueOf(i);}); + IntStream.range(0, input2.length).forEach(i -> {input2[i] = valueOf(i);}); + + for (int i = 0; i < special_input.length; i += 256) { + input1[r.nextInt(input1.length)] = special_input[i]; + input2[r.nextInt(input2.length)] = special_input[i]; + } + + return new Object[][] { + {input1, input2}, + {special_input, special_input}, + }; + } + + @DataProvider + public static Object[][] ternaryOpProvider() { + Float16 [] input1 = new Float16[SIZE]; + Float16 [] input2 = new Float16[SIZE]; + Float16 [] input3 = new Float16[SIZE]; + Float16 [] special_input = { + Float16.MAX_VALUE, Float16.MIN_VALUE, Float16.MIN_NORMAL, Float16.POSITIVE_INFINITY, + Float16.NEGATIVE_INFINITY, Float16.valueOf(0.0f), Float16.valueOf(-0.0f), Float16.NaN + }; + + // Input arrays covers entire Float16 value range interspersed with special values. + IntStream.range(0, input1.length).forEach(i -> {input1[i] = valueOf(i);}); + IntStream.range(0, input2.length).forEach(i -> {input2[i] = valueOf(i);}); + IntStream.range(0, input3.length).forEach(i -> {input3[i] = valueOf(i);}); + for (int i = 0; i < special_input.length; i += 256) { + input1[r.nextInt(input1.length)] = special_input[i]; + input2[r.nextInt(input2.length)] = special_input[i]; + input3[r.nextInt(input3.length)] = special_input[i]; + } + + return new Object[][] { + {input1, input2, input3}, + {special_input, special_input, special_input}, + }; + } + + interface FUnOp1 { + Float16 apply(Float16 a); + } + + interface FUnOp2 { + boolean apply(Float16 a); + } + + static void assertArraysEquals(Float16[] r, Float16[] a, FUnOp1 f) { + int i = 0; + try { + for (; i < a.length; i++) { + Assert.assertEquals(r[i], f.apply(a[i])); + } + } catch (AssertionError e) { + Assert.assertEquals(r[i], f.apply(a[i]), "at index #" + i + ", input = " + a[i]); + } + } + + static void assertArraysEquals(boolean[] r, Float16[] a, FUnOp2 f) { + int i = 0; + try { + for (; i < a.length; i++) { + Assert.assertEquals(r[i], f.apply(a[i])); + } + } catch (AssertionError e) { + Assert.assertEquals(r[i], f.apply(a[i]), "at index #" + i + ", input = " + a[i]); + } + } + + interface FBinOp { + Float16 apply(Float16 a, Float16 b); + } + + static void assertArraysEquals(Float16[] r, Float16[] a, Float16[] b, FBinOp f) { + int i = 0; + try { + for (; i < r.length; i++) { + Assert.assertEquals(r[i], f.apply(a[i], b[i])); + } + } catch (AssertionError e) { + Assert.assertEquals(r[i], f.apply(a[i], b[i]), "at index #" + i + ", input1 = " + a[i] + ", input2 = " + b[i]); + } + } + + interface FTernOp { + Float16 apply(Float16 a, Float16 b, Float16 c); + } + + static void assertArraysEquals(Float16[] r, Float16[] a, Float16[] b, Float16[] c, FTernOp f) { + int i = 0; + try { + for (; i < r.length; i++) { + Assert.assertEquals(r[i], f.apply(a[i], b[i], c[i])); + } + } catch (AssertionError e) { + Assert.assertEquals(r[i], f.apply(a[i], b[i], c[i]), "at index #" + i + ", input1 = " + a[i] + ", input2 = " + b[i] + ", input3 = " + c[i]); + } + } + + + @Test(dataProvider = "unaryOpProvider") + public static void absTest(Object input) { + Float16 [] farr = (Float16[])input; + Float16 [] res = new Float16[farr.length]; + for (int ic = 0; ic < INVOC_COUNT; ic++) { + for (int i = 0; i < res.length; i++) { + res[i] = abs(farr[i]); + } + } + assertArraysEquals(res, farr, (fp16) -> valueOf(Math.abs(fp16.floatValue()))); + } + + @Test(dataProvider = "unaryOpProvider") + public static void negTest(Object input) { + Float16 [] farr = (Float16[])input; + Float16 [] res = new Float16[farr.length]; + for (int ic = 0; ic < INVOC_COUNT; ic++) { + for (int i = 0; i < res.length; i++) { + res[i] = negate(farr[i]); + } + } + assertArraysEquals(res, farr, (fp16) -> shortBitsToFloat16((short)(float16ToRawShortBits(fp16) ^ (short)0x0000_8000))); + } + + @Test(dataProvider = "unaryOpProvider") + public static void sqrtTest(Object input) { + Float16 [] farr = (Float16[])input; + Float16 [] res = new Float16[farr.length]; + for (int ic = 0; ic < INVOC_COUNT; ic++) { + for (int i = 0; i < res.length; i++) { + res[i] = sqrt(farr[i]); + } + } + assertArraysEquals(res, farr, (fp16) -> valueOf(Math.sqrt(fp16.floatValue()))); + } + + @Test(dataProvider = "unaryOpProvider") + public static void isInfiniteTest(Object input) { + Float16 [] farr = (Float16[])input; + boolean [] res = new boolean[farr.length]; + for (int ic = 0; ic < INVOC_COUNT; ic++) { + for (int i = 0; i < res.length; i++) { + res[i] = isInfinite(farr[i]); + } + } + assertArraysEquals(res, farr, (fp16) -> Float.isInfinite(fp16.floatValue())); + } + + @Test(dataProvider = "unaryOpProvider") + public static void isFiniteTest(Object input) { + Float16 [] farr = (Float16[])input; + boolean [] res = new boolean[farr.length]; + for (int ic = 0; ic < INVOC_COUNT; ic++) { + for (int i = 0; i < res.length; i++) { + res[i] = isFinite(farr[i]); + } + } + assertArraysEquals(res, farr, (fp16) -> Float.isFinite(fp16.floatValue())); + } + + @Test(dataProvider = "unaryOpProvider") + public static void isNaNTest(Object input) { + Float16 [] farr = (Float16[])input; + boolean [] res = new boolean[farr.length]; + for (int ic = 0; ic < INVOC_COUNT; ic++) { + for (int i = 0; i < res.length; i++) { + res[i] = isNaN(farr[i]); + } + } + assertArraysEquals(res, farr, (fp16) -> Float.isNaN(fp16.floatValue())); + } + + @Test(dataProvider = "binaryOpProvider") + public static void addTest(Object input1, Object input2) { + Float16 [] farr1 = (Float16[])input1; + Float16 [] farr2 = (Float16[])input2; + + Float16 [] res = new Float16[farr1.length]; + for (int ic = 0; ic < INVOC_COUNT; ic++) { + for (int i = 0; i < res.length; i++) { + res[i] = add(farr1[i], farr2[i]); + } + } + assertArraysEquals(res, farr1, farr2, (fp16_val1, fp16_val2) -> valueOf(fp16_val1.floatValue() + fp16_val2.floatValue())); + } + + @Test(dataProvider = "binaryOpProvider") + public static void subtractTest(Object input1, Object input2) { + Float16 [] farr1 = (Float16[])input1; + Float16 [] farr2 = (Float16[])input2; + + Float16 [] res = new Float16[farr1.length]; + for (int ic = 0; ic < INVOC_COUNT; ic++) { + for (int i = 0; i < res.length; i++) { + res[i] = subtract(farr1[i], farr2[i]); + } + } + assertArraysEquals(res, farr1, farr2, (fp16_val1, fp16_val2) -> valueOf(fp16_val1.floatValue() - fp16_val2.floatValue())); + } + + @Test(dataProvider = "binaryOpProvider") + public static void multiplyTest(Object input1, Object input2) { + Float16 [] farr1 = (Float16[])input1; + Float16 [] farr2 = (Float16[])input2; + + Float16 [] res = new Float16[farr1.length]; + for (int ic = 0; ic < INVOC_COUNT; ic++) { + for (int i = 0; i < res.length; i++) { + res[i] = multiply(farr1[i], farr2[i]); + } + } + assertArraysEquals(res, farr1, farr2, (fp16_val1, fp16_val2) -> valueOf(fp16_val1.floatValue() * fp16_val2.floatValue())); + } + + @Test(dataProvider = "binaryOpProvider") + public static void divideTest(Object input1, Object input2) { + Float16 [] farr1 = (Float16[])input1; + Float16 [] farr2 = (Float16[])input2; + + Float16 [] res = new Float16[farr1.length]; + for (int ic = 0; ic < INVOC_COUNT; ic++) { + for (int i = 0; i < res.length; i++) { + res[i] = divide(farr1[i], farr2[i]); + } + } + assertArraysEquals(res, farr1, farr2, (fp16_val1, fp16_val2) -> valueOf(fp16_val1.floatValue() / fp16_val2.floatValue())); + } + + @Test(dataProvider = "binaryOpProvider") + public static void maxTest(Object input1, Object input2) { + Float16 [] farr1 = (Float16[])input1; + Float16 [] farr2 = (Float16[])input2; + + Float16 [] res = new Float16[farr1.length]; + for (int ic = 0; ic < INVOC_COUNT; ic++) { + for (int i = 0; i < res.length; i++) { + res[i] = max(farr1[i], farr2[i]); + } + } + assertArraysEquals(res, farr1, farr2, (fp16_val1, fp16_val2) -> valueOf(Float.max(fp16_val1.floatValue(), fp16_val2.floatValue()))); + } + + @Test(dataProvider = "binaryOpProvider") + public static void minTest(Object input1, Object input2) { + Float16 [] farr1 = (Float16[])input1; + Float16 [] farr2 = (Float16[])input2; + + Float16 [] res = new Float16[farr1.length]; + for (int ic = 0; ic < INVOC_COUNT; ic++) { + for (int i = 0; i < res.length; i++) { + res[i] = min(farr1[i], farr2[i]); + } + } + assertArraysEquals(res, farr1, farr2, (fp16_val1, fp16_val2) -> valueOf(Float.min(fp16_val1.floatValue(), fp16_val2.floatValue()))); + } + + @Test(dataProvider = "ternaryOpProvider") + public static void fmaTest(Object input1, Object input2, Object input3) { + Float16 [] farr1 = (Float16[])input1; + Float16 [] farr2 = (Float16[])input2; + Float16 [] farr3 = (Float16[])input2; + + Float16 [] res = new Float16[farr1.length]; + for (int ic = 0; ic < INVOC_COUNT; ic++) { + for (int i = 0; i < res.length; i++) { + res[i] = fma(farr1[i], farr2[i], farr3[i]); + } + } + assertArraysEquals(res, farr1, farr2, farr3, (fp16_val1, fp16_val2, fp16_val3) -> valueOf(Math.fma(fp16_val1.floatValue(), fp16_val2.floatValue(), fp16_val3.floatValue()))); + } +} diff --git a/test/micro/org/openjdk/bench/jdk/incubator/vector/Float16OperationsBenchmark.java b/test/micro/org/openjdk/bench/jdk/incubator/vector/Float16OperationsBenchmark.java new file mode 100644 index 0000000000000..763c45d3470d2 --- /dev/null +++ b/test/micro/org/openjdk/bench/jdk/incubator/vector/Float16OperationsBenchmark.java @@ -0,0 +1,296 @@ +/* + * Copyright (c) 2024, Oracle and/or its affiliates. All rights vectorReserved. + * DO NOT ALTER OR REMOVE COPYRIGHT NOTICES OR THIS FILE HEADER. + * + * This code is free software; you can redistribute it and/or modify it + * under the terms of the GNU General Public License version 2 only, as + * published by the Free Software Foundation. + * + * This code is distributed in the hope that it will be useful, but WITHOUT + * ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or + * FITNESS FOR A PARTICULAR PURPOSE. See the GNU General Public License + * version 2 for more details (a copy is included in the LICENSE file that + * accompanied this code). + * + * You should have received a copy of the GNU General Public License version + * 2 along with this work; if not, write to the Free Software Foundation, + * Inc., 51 Franklin St, Fifth Floor, Boston, MA 02110-1301 USA. + * + * Please contact Oracle, 500 Oracle Parkway, Redwood ShovectorRes, CA 94065 USA + * or visit www.oracle.com if you need additional information or have any + * questions. + */ +package org.openjdk.bench.java.lang; + +import java.util.stream.IntStream; +import java.util.concurrent.TimeUnit; +import jdk.incubator.vector.*; +import org.openjdk.jmh.annotations.*; +import static jdk.incubator.vector.Float16.*; +import static java.lang.Float.*; + +@OutputTimeUnit(TimeUnit.MILLISECONDS) +@State(Scope.Thread) +@Fork(jvmArgsPrepend = {"--add-modules=jdk.incubator.vector", "-Xbatch", "-XX:-TieredCompilation"}) +public class Float16OperationsBenchmark { + @Param({"256", "512", "1024", "2048"}) + int vectorDim; + + int [] rexp; + short [] vectorRes; + short [] vector1; + short [] vector2; + short [] vector3; + boolean [] vectorPredicate; + + static final short f16_one = Float.floatToFloat16(1.0f); + static final short f16_two = Float.floatToFloat16(2.0f); + + @Setup(Level.Trial) + public void BmSetup() { + rexp = new int[vectorDim]; + vectorRes = new short[vectorDim]; + vector1 = new short[vectorDim]; + vector2 = new short[vectorDim]; + vector3 = new short[vectorDim]; + vectorPredicate = new boolean[vectorDim]; + + IntStream.range(0, vectorDim).forEach(i -> {vector1[i] = Float.floatToFloat16((float)i);}); + IntStream.range(0, vectorDim).forEach(i -> {vector2[i] = Float.floatToFloat16((float)i);}); + IntStream.range(0, vectorDim).forEach(i -> {vector3[i] = Float.floatToFloat16((float)i);}); + + // Special Values + Float16 [] specialValues = {Float16.NaN, Float16.NEGATIVE_INFINITY, Float16.valueOf(0.0), Float16.valueOf(-0.0), Float16.POSITIVE_INFINITY}; + IntStream.range(0, vectorDim).forEach( + i -> { + if ((i % 64) == 0) { + int idx1 = i % specialValues.length; + int idx2 = (i + 1) % specialValues.length; + int idx3 = (i + 2) % specialValues.length; + vector1[i] = float16ToRawShortBits(specialValues[idx1]); + vector2[i] = float16ToRawShortBits(specialValues[idx2]); + vector3[i] = float16ToRawShortBits(specialValues[idx3]); + } + } + ); + } + + @Benchmark + public void addBenchmark() { + for (int i = 0; i < vectorDim; i++) { + vectorRes[i] = float16ToRawShortBits(add(shortBitsToFloat16(vector1[i]), shortBitsToFloat16(vector2[i]))); + } + } + + @Benchmark + public void subBenchmark() { + for (int i = 0; i < vectorDim; i++) { + vectorRes[i] = float16ToRawShortBits(subtract(shortBitsToFloat16(vector1[i]), shortBitsToFloat16(vector2[i]))); + } + } + + @Benchmark + public void mulBenchmark() { + for (int i = 0; i < vectorDim; i++) { + vectorRes[i] = float16ToRawShortBits(multiply(shortBitsToFloat16(vector1[i]), shortBitsToFloat16(vector2[i]))); + } + } + + @Benchmark + public void divBenchmark() { + for (int i = 0; i < vectorDim; i++) { + vectorRes[i] = float16ToRawShortBits(divide(shortBitsToFloat16(vector1[i]), shortBitsToFloat16(vector2[i]))); + } + } + + @Benchmark + public void fmaBenchmark() { + for (int i = 0; i < vectorDim; i++) { + vectorRes[i] = float16ToRawShortBits(fma(shortBitsToFloat16(vector1[i]), shortBitsToFloat16(vector2[i]), shortBitsToFloat16(vector3[i]))); + } + } + + @Benchmark + public boolean isInfiniteBenchmark() { + boolean res = true; + for (int i = 0; i < vectorDim; i++) { + res &= isInfinite(shortBitsToFloat16(vector1[i])); + } + return res; + } + + @Benchmark + public boolean isFiniteBenchmark() { + boolean res = true; + for (int i = 0; i < vectorDim; i++) { + res &= isFinite(shortBitsToFloat16(vector1[i])); + } + return res; + } + + @Benchmark + public boolean isNaNBenchmark() { + boolean res = true; + for (int i = 0; i < vectorDim; i++) { + res &= isNaN(shortBitsToFloat16(vector1[i])); + } + return res; + } + + @Benchmark + public void isNaNStoreBenchmark() { + for (int i = 0; i < vectorDim; i++) { + vectorPredicate[i] = Float16.isNaN(shortBitsToFloat16(vector1[i])); + } + } + + + @Benchmark + public void isNaNCMovBenchmark() { + for (int i = 0; i < vectorDim; i++) { + vectorRes[i] = Float16.isNaN(shortBitsToFloat16(vector1[i])) ? f16_one : f16_two; + } + } + + + @Benchmark + public void isInfiniteStoreBenchmark() { + for (int i = 0; i < vectorDim; i++) { + vectorPredicate[i] = Float16.isInfinite(shortBitsToFloat16(vector1[i])); + } + } + + + @Benchmark + public void isInfiniteCMovBenchmark() { + for (int i = 0; i < vectorDim; i++) { + vectorRes[i] = Float16.isInfinite(shortBitsToFloat16(vector1[i])) ? f16_one : f16_two; + } + } + + + @Benchmark + public void isFiniteStoreBenchmark() { + for (int i = 0; i < vectorDim; i++) { + vectorPredicate[i] = Float16.isFinite(shortBitsToFloat16(vector1[i])); + } + } + + + @Benchmark + public void isFiniteCMovBenchmark() { + for (int i = 0; i < vectorDim; i++) { + vectorRes[i] = Float16.isFinite(shortBitsToFloat16(vector1[i])) ? f16_one : f16_two; + } + } + + @Benchmark + public void maxBenchmark() { + for (int i = 0; i < vectorDim; i++) { + vectorRes[i] = float16ToRawShortBits(max(shortBitsToFloat16(vector1[i]), shortBitsToFloat16(vector2[i]))); + } + } + + @Benchmark + public void minBenchmark() { + for (int i = 0; i < vectorDim; i++) { + vectorRes[i] = float16ToRawShortBits(min(shortBitsToFloat16(vector1[i]), shortBitsToFloat16(vector2[i]))); + } + } + + @Benchmark + public void sqrtBenchmark() { + for (int i = 0; i < vectorDim; i++) { + vectorRes[i] = float16ToRawShortBits(sqrt(shortBitsToFloat16(vector1[i]))); + } + } + + @Benchmark + public void negateBenchmark() { + for (int i = 0; i < vectorDim; i++) { + vectorRes[i] = float16ToRawShortBits(negate(shortBitsToFloat16(vector1[i]))); + } + } + + @Benchmark + public void absBenchmark() { + for (int i = 0; i < vectorDim; i++) { + vectorRes[i] = float16ToRawShortBits(abs(shortBitsToFloat16(vector1[i]))); + } + } + + @Benchmark + public void getExponentBenchmark() { + for (int i = 0; i < vectorDim; i++) { + rexp[i] = getExponent(shortBitsToFloat16(vector1[i])); + } + } + + @Benchmark + public short cosineSimilarityDoubleRoundingFP16() { + short macRes = floatToFloat16(0.0f); + short vector1Square = floatToFloat16(0.0f); + short vector2Square = floatToFloat16(0.0f); + for (int i = 0; i < vectorDim; i++) { + // Explicit add and multiply operation ensures double rounding. + Float16 vec1 = shortBitsToFloat16(vector1[i]); + Float16 vec2 = shortBitsToFloat16(vector2[i]); + macRes = float16ToRawShortBits(add(multiply(vec1, vec2), shortBitsToFloat16(macRes))); + vector1Square = float16ToRawShortBits(add(multiply(vec1, vec1), shortBitsToFloat16(vector1Square))); + vector2Square = float16ToRawShortBits(add(multiply(vec2, vec2), shortBitsToFloat16(vector2Square))); + } + return float16ToRawShortBits(divide(shortBitsToFloat16(macRes), add(shortBitsToFloat16(vector1Square), shortBitsToFloat16(vector2Square)))); + } + + @Benchmark + public short cosineSimilaritySingleRoundingFP16() { + short macRes = floatToFloat16(0.0f); + short vector1Square = floatToFloat16(0.0f); + short vector2Square = floatToFloat16(0.0f); + for (int i = 0; i < vectorDim; i++) { + Float16 vec1 = shortBitsToFloat16(vector1[i]); + Float16 vec2 = shortBitsToFloat16(vector2[i]); + macRes = float16ToRawShortBits(fma(vec1, vec2, shortBitsToFloat16(macRes))); + vector1Square = float16ToRawShortBits(fma(vec1, vec1, shortBitsToFloat16(vector1Square))); + vector2Square = float16ToRawShortBits(fma(vec2, vec2, shortBitsToFloat16(vector2Square))); + } + return float16ToRawShortBits(divide(shortBitsToFloat16(macRes), add(shortBitsToFloat16(vector1Square), shortBitsToFloat16(vector2Square)))); + } + + @Benchmark + public short cosineSimilarityDequantizedFP16() { + float macRes = 0.0f; + float vector1Square = 0.0f; + float vector2Square = 0.0f; + for (int i = 0; i < vectorDim; i++) { + float vec1 = float16ToFloat(vector1[i]); + float vec2 = float16ToFloat(vector2[i]); + macRes = Math.fma(vec1, vec2, macRes); + vector1Square = Math.fma(vec1, vec1, vector1Square); + vector2Square = Math.fma(vec2, vec2, vector2Square); + } + return floatToFloat16(macRes / (vector1Square + vector2Square)); + } + + @Benchmark + public short euclideanDistanceFP16() { + short distRes = floatToFloat16(0.0f); + short squareRes = floatToFloat16(0.0f); + for (int i = 0; i < vectorDim; i++) { + squareRes = float16ToRawShortBits(subtract(shortBitsToFloat16(vector1[i]), shortBitsToFloat16(vector2[i]))); + distRes = float16ToRawShortBits(fma(shortBitsToFloat16(squareRes), shortBitsToFloat16(squareRes), shortBitsToFloat16(distRes))); + } + return float16ToRawShortBits(sqrt(shortBitsToFloat16(distRes))); + } + + @Benchmark + public short euclideanDistanceDequantizedFP16() { + float distRes = 0.0f; + float squareRes = 0.0f; + for (int i = 0; i < vectorDim; i++) { + squareRes = float16ToFloat(vector1[i]) - float16ToFloat(vector2[i]); + distRes = distRes + squareRes * squareRes; + } + return float16ToRawShortBits(sqrt(shortBitsToFloat16(floatToFloat16(distRes)))); + } +}