Skip to content

Refactor BLAS code#2049

Open
jessegrabowski wants to merge 5 commits into
pymc-devs:mainfrom
jessegrabowski:blas-refactor
Open

Refactor BLAS code#2049
jessegrabowski wants to merge 5 commits into
pymc-devs:mainfrom
jessegrabowski:blas-refactor

Conversation

@jessegrabowski

@jessegrabowski jessegrabowski commented Apr 13, 2026

Copy link
Copy Markdown
Member

I was inspired by the linalg refactor so I wanted to make a pass at BLAS. My objectives here is to make the BLAS code more maintainable. I am doing this by:

  • Convert blas.py into a module
  • Split out one file per blas Op (GEMM, GEMV, GER)
  • Move string codegen into dedicated .h or .c files

Point 3 has two levels we could pursue. Level one is to extract all static code into headers that can be pulled into the string codegen. This is done for all three BLAS functions. The second level is to move all string codegen into a helper function, then only codegen the call to that function. I did this only to GER in the last commit. It's significantly more readable, but it also has the overhead of being a function vs 100% inline. We can discuss.

I think one more step I want to explore in this PR is moving all of the c code and potentially the COps to link/c/blas instead of tensor/blas, as a part of this idea that "C should just be another backend" raised in #2006

@jessegrabowski

Copy link
Copy Markdown
Member Author

Some fun facts I didn't know:

  • We link against blas_fortran, not blas_c. Wow!
  • There is a bug reported in 2006 (!!) about Accelerate's fortran blas bindings. Their sdot function is wrong. That has never been patched. We have to run hack code to check if the bug is there and monkeypatch the fortran_blas sdot to the c_blas sdot on mac every time we import pytensor. Who knew?
  • I also learned that pytorch and mlx have specialized GEMM (admm) and BatchedDot (bmm) functions. We should almost certainly kill the BlasOpt rewrite database and just make them normal rewrites. We can write pullbacks for GEMM, GEMV, and GER trivially, and we gain the ability to do nice dispatches.

@ricardoV94

ricardoV94 commented Apr 13, 2026

Copy link
Copy Markdown
Member

We can write pullbacks for GEMM, GEMV, and GER trivially, and we gain the ability to do nice dispatches.

We can, but we don't usually bother with grads for specialized Ops. I wouldn't expose user-facing GEMM but always start with the canonical Dot + alpha forms for which we have other rewrites/batch-rules/etc...

@jessegrabowski

Copy link
Copy Markdown
Member Author

I'm not thinking about something user-facing for sure. The pullbacks are so simple that it might result in better graphs than trying to start from general forms and rewrite both the forward and backward graph. Not sure. I want to modernize the rewrites next.

@ricardoV94

Copy link
Copy Markdown
Member

The pullbacks are so simple that it might result in better graphs than trying to start from general forms and rewrite both the forward and backward graph. Not sure. I want to modernize the rewrites next.

Simple or not is also more code we need to test and maintain. How hard are the BLAS patterns really? Two dots and some scalar multiplications? If we can't handle those we have bigger problems

@jessegrabowski

jessegrabowski commented Apr 13, 2026

Copy link
Copy Markdown
Member Author

If we can't handle those we have bigger problems

Well, empirically...

@ricardoV94

Copy link
Copy Markdown
Member

Ok with dropping GER, there's a rewrite that was opting out of Dot -> Outer Mul, let's not opt out after we drop it

Move the monolithic blas.py / blas_c.py / blas_headers.py into a blas/
package split by Op (gemm, gemv, ger, batched, blas_c, blas_headers,
_core). No behavior change; codegen is unchanged. Update import paths.
Extract the GEMM/GEMV/GER C codegen out of the Op classes into generator
functions in _codegen.py that take the Apply node, and hold the BLAS C
support code as Python strings in _c_code.py (no on-disk .c/.h files).
Behavior-preserving refactor.
Emit only the statically-known precision for Gemm/Dot22/Dot22Scalar/CGemv,
dropping the runtime switch(type_num) and dtype guards.
inplace is a static op prop, so Gemm/CGemv emit only the relevant output
setup arm; drop the runtime params->inplace branch and their params_type.
Rewrite ger_c_code to inline the orchestration (dtype- and destructive-
pinned, validation elided) calling only the leaf loops / BLAS dispatch in
GER_HELPER; drop CGer's params_type.
@jessegrabowski

Copy link
Copy Markdown
Member Author

Punting on dropping GER for now, needs real benchmarking.

I reorganized this PR following a conversation with @ricardoV94 this morning. He didn't like all the .h/.c/.cpp files -- fair enough. Moreover he pointed out that we should be using codegen everywhere because we use node information to create specialized kernels. Actually the BLAS code wasn't doing that -- the codegen path wasn't looking at the nodes at all. It was generating large C functions with a bunch of conditional logic to pick the right thing at runtime. That might be what we have to do (e.g. if static shapes are unknown), but we do know things like dtype at compile time, and we weren't using that information.

So this PR now has 5 commits that I hope we agree are much tighter. blas.py is split into a blas/ package and the codegen is moved into node-taking generators. These are the first 2 commits. The next 3 are adding specializations to the codegen based on information from the nodes. We could drop these if we want, but my minibenching found it's no worse than main, maybe ~10% speedup on small/medium GEMM/Dot22 (my benchmarks are polluted though because I've been running a heavy model for 3 days and I can't stop it now)

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Projects

None yet

Development

Successfully merging this pull request may close these issues.

2 participants