Add MIGraphX execution provider support#2165
Conversation
|
@aditya-dl please read the following Contributor License Agreement(CLA). If you agree with the CLA, please reply with the following information.
Contributor License AgreementContribution License AgreementThis Contribution License Agreement (“Agreement”) is agreed to by the party signing below (“You”),
|
|
@kunal-vaishnavi @baijumeswani could you help review this PR? |
There was a problem hiding this comment.
Pull request overview
Note
Copilot was unable to run its full agentic suite in this review.
Adds MIGraphX (AMD GPU) execution provider integration to ONNX Runtime GenAI, including provider name normalization and runtime behaviors needed to avoid recompilation during prompt processing.
Changes:
- Register MIGraphX in the session-options dispatch table and add a MIGraphX EP implementation (V2 plugin path with V1 fallback).
- Normalize provider names so
"migraphx"and"MIGraphXExecutionProvider"map to"MIGraphX", and enable graph capture for MIGraphX. - Add “static input shape” prompt-time padding and update position/logits shapes to support padded prompt lengths.
Reviewed changes
Copilot reviewed 13 out of 13 changed files in this pull request and generated 11 comments.
Show a summary per file
| File | Description |
|---|---|
| src/models/session_options.cpp | Adds MIGraphX EP registration in the provider dispatch map. |
| src/models/position_inputs.h | Updates CreateAndInitializePositionIDs signature to accept actual seq_length. |
| src/models/position_inputs.cpp | Implements position id initialization w/ padded shapes and safe indexing using seq_length. |
| src/models/model.h | Adds State::prompt_gen_ flag to control prompt-time padding behavior. |
| src/models/model.cpp | Pads input_ids to max_length during prompt generation for static-shape EPs. |
| src/models/logits.cpp | Forces logits output shape to max_length during prompt generation. |
| src/migraphx/session_options.h | Declares MIGraphX EP append entrypoint. |
| src/migraphx/session_options.cpp | Implements MIGraphX EP append with V2 plugin-first and V1 fallback. |
| src/generators.h | Adds use_static_input_shapes to generator params. |
| src/generators.cpp | Toggles prompt_gen_ around prompt vs token-generation runs. |
| src/config.h | Declares NeedsStaticInputShapes. |
| src/config.cpp | Adds MIGraphX name normalization, enables graph capture for MIGraphX, and implements NeedsStaticInputShapes. |
| cmake/global_variables.cmake | Adds MIGraphX sources to the CMake glob. |
c947215 to
c2f4145
Compare
Add MIGraphX (AMD GPU) execution provider support to ONNX Runtime
GenAI. The provider is exposed as "migraphx" to users; OGA also
accepts "MIGraphXExecutionProvider" (the catalog form used by the
AMD-shipped Windows ML EP MSIX) so test harnesses that match
config strings against WinML-discovered names work without bypass
hacks.
Changes:
- Create src/migraphx/session_options.{h,cpp} with
AppendExecutionProvider that tries V2 plugin path
("MIGraphXExecutionProvider") then falls back to V1 legacy API
("MIGraphX")
- Add provider name normalization: both "migraphx" and
"MIGraphXExecutionProvider" map to "MIGraphX"; register in the
dispatch table
- Enable graph capture for MIGraphX to allow compiled graph reuse
during token generation
- Add static input shape padding (prompt_gen_ flag) so MIGraphX
avoids recompilation on varying prompt lengths. Gated behind
NeedsStaticInputShapes() to only activate for MIGraphX, not
other EPs
- Fix input_ids padding to handle both int32 and int64 element
types, and correct per-row copy for batch_size > 1
- Fix position_ids padding to prevent out-of-bounds read on
next_tokens when tensor shape is padded to max_length for
batch_size > 1
Configuration:
"provider_options": [{ "migraphx": {} }]
or: config.append_provider("migraphx")
also accepted: "MIGraphXExecutionProvider" (catalog form)
Known limitations:
- Beam search not supported (requires past_present_share_buffer=true
which requires num_beams=1)
- Inputs allocated on CPU; the MIGraphX EP handles CPU<->GPU
transfers internally
c2f4145 to
d563801
Compare
Add MIGraphX (AMD GPU) execution provider support to ONNX Runtime GenAI. The provider is exposed as
"migraphx"to users (with"MIGraphXExecutionProvider"accepted as a catalog-name alias for compatibility with Windows ML EP discovery flows).This PR supersedes #2093 (which used the now-reverted "AMDGPU" naming).
Changes:
src/migraphx/session_options.{h,cpp}withAppendExecutionProviderthat tries V2 plugin path ("MIGraphXExecutionProvider") then falls back to V1 legacy API ("MIGraphX")"migraphx"and"MIGraphXExecutionProvider"map to"MIGraphX"; register"MIGraphX"in the dispatch tableprompt_gen_flag) so MIGraphX avoids recompilation on varying prompt lengths. Gated behindNeedsStaticInputShapes()so it only activates for MIGraphX, not other EPsinput_idspadding to handle both int32 and int64 element types, and correct per-row copy forbatch_size > 1position_idspadding to prevent out-of-bounds read onnext_tokenswhen tensor shape is padded tomax_lengthforbatch_size > 1Configuration:
or
Also accepted:
"MIGraphXExecutionProvider"(catalog form, what Windows ML returns from the EP discovery API).Known limitations:
past_present_share_buffer=truewhich requiresnum_beams=1)