Rewrite det(inv(X)) → 1/det(X)#2102
Rewrite det(inv(X)) → 1/det(X)#2102alessandrogentili001 wants to merge 2 commits intopymc-devs:mainfrom
Conversation
| @register_stabilize | ||
| @register_specialize | ||
| @node_rewriter([Elemwise]) | ||
| def local_reciprocal_linalg_special_cases(fgraph, node): |
There was a problem hiding this comment.
These aren't related to linalg, the name is wrong
Each of these should be a separate rewrite, and we should be using the new scalar Op properties. We have monotonic_increasing and monotonic_decreasing, which imply sign(f(x)) -> sign(x) and sign(f(x)) -> sign(-x) if the op is also zero_preserving.
We just have to be careful about strict montonicity vs non-strict. I can't remember if ceil(x) is marked as monotonic for example. We might need a separate flag for the strict variety and check it in this rewrite.
| def det_of_inv(fgraph, node): | ||
| """Replace det(matrix_inverse(X)) with reciprocal(det(X)). | ||
|
|
||
| Since det(inv(X)) = 1/det(X), we avoid computing the inverse. |
There was a problem hiding this comment.
| Since det(inv(X)) = 1/det(X), we avoid computing the inverse. |
|
Hi jesse, thanks for the feedback! I've refactored the implementation to address your points regarding modularity and property-based rewriting.
Let me know if you have any further suggestions! |
| class Erf(UnaryScalarOp): | ||
| preserves_zero = True | ||
| monotonic_increasing = True | ||
| strictly_monotonic_increasing = True |
There was a problem hiding this comment.
don't love this, what cases disagree right now between the two?
There was a problem hiding this comment.
ceil and floor, for example
There was a problem hiding this comment.
Also I don't get it, any discrete input version to these ops is also not strictly monotonic.
nvm
There was a problem hiding this comment.
Don't mark those instead?
I'm not against that, but then we should use the strictly_ language everywhere (drop the shorter one) to be clear what is going on
There was a problem hiding this comment.
Circling back to the sign thing, if that's the motivation, I don't you can apply it based on monotonicity, strict or not. sign(exp(x)) is obviously not sign(x).
We are adding these properties for specific uses, not for mathematical idealism, so they need not be verbose nor geberalized besides the problems we want to solve with them. sctrict vs non strict is more a question of invertible 1-1 map not the direction. wouldn't a combination ot those 2 poperties + zero preserving be a better way to achieve the goal?
There was a problem hiding this comment.
exp isn't zero preserving, so the rule doesn't apply. Both things are important. It has to be strictly monotonic increasing and zero preserving. I think taking strict monotonicity as our canonical form is nice, because who cares about ceil/floor anyway. But I also think it's important to be clear in language, otherwise someone can come along in a few years and say "Well technicaly BitwiseInverse is monotonic_increasing, why isn't it marked" and the answer is "because we define monotonicity as strict monotonicity but it isn't written anywhere". We lose nothing by just writing it.
There was a problem hiding this comment.
Yeah I know sign thing doesn't apply to zero, did I say so?
Otherwise okay, we can go with verbose, don't love it but it's strictly more precise.
Can we stop there and not at strictly_monotonic_increasing_over_defined_domain?
There was a problem hiding this comment.
I was reacting to this: sign(exp(x)) is obviously not sign(x).
Maybe i misunderstood your point.
I'm not being dogmatic about the name change, if we want to just define monotonic to mean "strictly monotonic" and put it in the docs somewhere, I have no objection. I just want it to be written down, and I like self-documenting code. Agreed that there is a limit.
| @register_specialize | ||
| @node_rewriter([log]) | ||
| def local_log_reciprocal(fgraph, node): | ||
| """Rewrite log(reciprocal(x)) -> -log(x).""" |
There was a problem hiding this comment.
we should do the more general as well (reciprocal is fine): log(a/b), where a or b is a non-negative constant -> log(a) - log(b) (the constant constant-folded already).
| @register_specialize | ||
| @node_rewriter([sign]) | ||
| def local_sign_reciprocal(fgraph, node): | ||
| """Rewrite sign(reciprocal(x)) -> sign(x).""" |
There was a problem hiding this comment.
same here sign a/b, where one is a positive constant -> sign of the other term. If the constant is negative, 1-sign of the other. If it's mixed, can't do anything
Description
Related Issue
Checklist
Type of change