Refactor BLAS code#2049
Conversation
|
Some fun facts I didn't know:
|
We can, but we don't usually bother with grads for specialized Ops. I wouldn't expose user-facing GEMM but always start with the canonical Dot + alpha forms for which we have other rewrites/batch-rules/etc... |
|
I'm not thinking about something user-facing for sure. The pullbacks are so simple that it might result in better graphs than trying to start from general forms and rewrite both the forward and backward graph. Not sure. I want to modernize the rewrites next. |
Simple or not is also more code we need to test and maintain. How hard are the BLAS patterns really? Two dots and some scalar multiplications? If we can't handle those we have bigger problems |
Well, empirically... |
|
Ok with dropping GER, there's a rewrite that was opting out of Dot -> Outer Mul, let's not opt out after we drop it |
Move the monolithic blas.py / blas_c.py / blas_headers.py into a blas/ package split by Op (gemm, gemv, ger, batched, blas_c, blas_headers, _core). No behavior change; codegen is unchanged. Update import paths.
Extract the GEMM/GEMV/GER C codegen out of the Op classes into generator functions in _codegen.py that take the Apply node, and hold the BLAS C support code as Python strings in _c_code.py (no on-disk .c/.h files). Behavior-preserving refactor.
Emit only the statically-known precision for Gemm/Dot22/Dot22Scalar/CGemv, dropping the runtime switch(type_num) and dtype guards.
inplace is a static op prop, so Gemm/CGemv emit only the relevant output setup arm; drop the runtime params->inplace branch and their params_type.
Rewrite ger_c_code to inline the orchestration (dtype- and destructive- pinned, validation elided) calling only the leaf loops / BLAS dispatch in GER_HELPER; drop CGer's params_type.
8b69f18 to
6cab685
Compare
|
Punting on dropping GER for now, needs real benchmarking. I reorganized this PR following a conversation with @ricardoV94 this morning. He didn't like all the .h/.c/.cpp files -- fair enough. Moreover he pointed out that we should be using codegen everywhere because we use node information to create specialized kernels. Actually the BLAS code wasn't doing that -- the codegen path wasn't looking at the nodes at all. It was generating large C functions with a bunch of conditional logic to pick the right thing at runtime. That might be what we have to do (e.g. if static shapes are unknown), but we do know things like dtype at compile time, and we weren't using that information. So this PR now has 5 commits that I hope we agree are much tighter. |
I was inspired by the linalg refactor so I wanted to make a pass at BLAS. My objectives here is to make the BLAS code more maintainable. I am doing this by:
blas.pyinto a module.hor.cfilesPoint 3 has two levels we could pursue. Level one is to extract all static code into headers that can be pulled into the string codegen. This is done for all three BLAS functions. The second level is to move all string codegen into a helper function, then only codegen the call to that function. I did this only to GER in the last commit. It's significantly more readable, but it also has the overhead of being a function vs 100% inline. We can discuss.
I think one more step I want to explore in this PR is moving all of the c code and potentially the COps to
link/c/blasinstead oftensor/blas, as a part of this idea that "C should just be another backend" raised in #2006