Speedup from switch to +=
115 points by j0e1 2 years ago | 79 comments
- chillee 2 years ago
Ok, I work on PyTorch, so I should probably clear up some misconceptions in this thread.
1. In PyTorch (and other array programming libraries like Numpy), the operations being passed around are tensors/arrays (i.e. large chunks of memory). Thus, += is overloaded to mean "in-place write" to the arrays.
So, `+` vs `+=` is the equivalent of

  a: float[1000]
  b: float[1000]
  for i in range(1000):
      b[i] = a[i] + 2

vs.

  a: float[1000]
  for i in range(1000):
      a[i] = a[i] + 2

The main performance advantage comes from (1) not needing to allocate an extra array and (2) using less memory overall, so the various cache levels can work better. It has nothing to do with Python bytecodes.
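To make the two loops concrete without needing PyTorch installed, here is a minimal sketch using a hypothetical pure-Python `Vec` class as a stand-in for a tensor (an assumption for illustration; real tensors dispatch to C++/CUDA kernels). `__add__` allocates a fresh buffer like the first loop, while `__iadd__` writes into the existing one like the second:

```python
class Vec:
    """Hypothetical stand-in for a 1-D float tensor."""
    allocations = 0  # how many buffers have been created so far

    def __init__(self, data):
        Vec.allocations += 1
        self.data = list(data)

    def __add__(self, scalar):
        # Out-of-place, like `b[i] = a[i] + 2`: needs a second buffer.
        return Vec(v + scalar for v in self.data)

    def __iadd__(self, scalar):
        # In-place, like `a[i] = a[i] + 2`: reuses the existing buffer.
        for i in range(len(self.data)):
            self.data[i] += scalar
        return self


def count_allocations(in_place):
    a = Vec([0.0] * 1000)
    Vec.allocations = 0  # only count what the addition itself does
    if in_place:
        a += 2
    else:
        a = a + 2
    return Vec.allocations, a.data[0]


print(count_allocations(in_place=False))  # (1, 2.0)
print(count_allocations(in_place=True))   # (0, 2.0)
```

The extra buffer is the whole difference here. In pure Python the interpreter overhead dominates, so this says nothing about real-world timings.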
2. As for whether it generally makes sense to do this optimization manually... Usually, PyTorch users don't use in-place operations, as they're a bit uglier mathematically and have various foot-guns/restrictions that users find confusing. Generally, it's best to have this optimization done automatically by an optimizing compiler.
3. PyTorch in general does support using in-place operations during training, albeit with some caveats.
(PS) 4. Putting everything on one line (as some folks suggest) is almost certainly not going to help performance - the primary performance bottlenecks here have almost nothing to do with CPU perf.
- teruakohatu 2 years ago
Thanks for the input. Before I start throwing += into my PyTorch code, can you explain what you mean here:
> Generally, it's best to have this optimization be done automatically by an optimizing compiler.
What compiler should be optimizing this operation?
There are comments on the commit reporting errors under certain conditions.
- chillee 2 years ago
To clarify, by "compilers" I mean "deep learning compilers".
There are many different paths to optimizing compilers that folks use with PyTorch. One with close integration is NVFuser (see https://www.reddit.com/r/MachineLearning/comments/xa75km/p_p...), although there are other compilers like ONNXRuntime.
Yes, handling autograd (during training) is a whole different thing, and not all compilers support that.
- FabHK 2 years ago
Plot twist: it breaks the code...?
> Changing this back to the original implementation fixed an error I was getting when doing textual inversion on Windows
https://github.com/lstein/stable-diffusion/commit/62863ac586...
- staticassertion 2 years ago
Love to see it. A perfect example of why this optimization can't be done automatically: in the `else` case you're working with a mutable reference to the `x` that was passed in, which means your function is now mutating something it didn't use to mutate.
A "safe" way to do this is still straightforward, I think.
  from copy import copy

  def _forward(self, x, context=None):
      x = x.contiguous() if x.device.type == 'mps' else x
      x = copy(x)
      x += self.attn1(self.norm1(x))
      x += self.attn2(self.norm2(x), context=context)
      x += self.ff(self.norm3(x))
      return x

It could be faster, but I don't know what `x` is and I'm not going to guess. Also, `copy` may not be sufficient and `deepcopy` may be necessary; again, I don't know what `x` is, so I can't figure that out. Pls use type annotations :)
- FabHK 2 years ago
How about this? (As the copy operation implicit in x = x + y seemed OK.)

  def _forward(self, x, context=None):
      x = x.contiguous() if x.device.type == 'mps' else x
      x = x + self.attn1(self.norm1(x))
      x += self.attn2(self.norm2(x), context=context)
      x += self.ff(self.norm3(x))
      return x
- Someone 2 years ago
Alternatively, keep the first line as is. That gives you a copy that’s only known to the function, so you can change the later ones.
I would only do that if I had seen it to be faster, though, and add a comment on why the first line couldn’t do +=.
- mgraczyk 2 years ago
That's not safe if the problem is the in-place mutation. You will still mutate x while reading from it.
- Arnavion 2 years ago
staticassertion's point is that the current code's usage of `+=` mutates the x that was passed in by the caller, and their suggestion is to copy x into a function local before mutating it, which is similar to how the original `+` code also worked on a function local x (the result of `attn1() + x`).
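A small sketch of the aliasing issue being discussed, using Python lists as a stand-in for tensors (lists concatenate rather than add elementwise, but `+` vs `+=` has the same allocate-vs-mutate behaviour). The function names are hypothetical; the second version follows the idea of making only the first step out-of-place:

```python
def forward_mutating(x, a, b, c):
    # Every step is in place, so the caller's object is modified.
    x += a
    x += b
    x += c
    return x


def forward_local_copy(x, a, b, c):
    # First step is out-of-place: `x` now names a new, function-local object.
    x = x + a
    # Later in-place steps only touch that local object.
    x += b
    x += c
    return x


inp = [0]
out = forward_mutating(inp, [1], [2], [3])
print(inp, out is inp)   # [0, 1, 2, 3] True

inp = [0]
out = forward_local_copy(inp, [1], [2], [3])
print(inp, out)          # [0] [0, 1, 2, 3]
```

As mgraczyk notes, this only fixes the caller-visible mutation. Inside the function, an operation may still have saved the local `x` (for a backward pass, say) before a later `+=` overwrites it.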
- desmond373 2 years ago
Might be platform-dependent whether that first line counts as a mutation or not, seeing as it can be converted to not do anything in some cases.
- chrismorgan 2 years ago
This, incidentally, demonstrates why I love the ownership model (single ownership and aliasing-xor-mutability referencing) seen in Rust (and a few other languages are poking around with similar concepts). When working in Python or JavaScript as I sometimes do, it’s generally the feature I miss the most.
The problem here comes down to not knowing whether you’re allowed to modify a value in-place or not, because it’s not clear who owns it: it wasn’t written down anywhere, and in stable-diffusion alone it was fine to mutate it, but textual-inversion did something so it wasn’t (perhaps passing it something it expected to not be mutated). This is a moderately common type of bug that can be extraordinarily difficult to diagnose—it’s unusually easy to pinpoint here because it promptly raises a RuntimeError—and which is statically impossible in Rust, because the whole “am I allowed to mutate it” thing is resolved in the type system.
- nodja 2 years ago
I see lots of people answering why it's faster, but not many saying why the engineers chose the slower version.
As everyone said, this is more performant because x is modified in place. The reason it wasn't done in place originally is that you can't safely train a neural network if an operation overwrites values that training still needs. During training, the network goes back through all the operations that were performed (the backward pass) to see how each one contributed, so the weights can be adjusted using a secondary value called a gradient. If you replace something in place, you essentially overwrite the input values that were passed to that function and, by extension, the output values of the function called before it, breaking the network chain, unless you also copy the inputs along with the gradients, which would cause an even worse performance hit and be a memory hog.
The breakage reported later in the comments is proof of this: when sampling to generate an image, only the forward pass is run on the network, but textual inversion requires you to train the network and therefore run the backward pass, which triggers the error since the dependency graph is broken. I should also note that technically the add operation should be safe to do in place, as it's reversible, but I'm not a PyTorch expert, so I'm not sure exactly what's going on in there.
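The kind of check that produces this error can be sketched with a version counter: every in-place write bumps the counter, and an operation that saved its input for the backward pass refuses to run if the counter changed. This is a hypothetical toy loosely modelled on the mechanism PyTorch's autograd uses, not its actual API. It also suggests why the add itself may not be the culprit: in `_forward`, `x` is also passed to `self.norm1(x)`, and if that layer saves `x` for its gradient, the following `x += ...` overwrites a saved value. That reading of the textual-inversion error is plausible but unverified.

```python
class Tracked:
    """Hypothetical value with a version counter bumped by in-place writes."""
    def __init__(self, value):
        self.value = value
        self.version = 0

    def add_(self, other):
        # In-place add: mutate the value and record that it changed.
        self.value += other
        self.version += 1
        return self


class Square:
    """Toy op whose gradient needs its saved input: d(x**2)/dx = 2x."""
    def forward(self, x):
        self.saved = x
        self.saved_version = x.version
        return Tracked(x.value ** 2)

    def backward(self, grad_out):
        if self.saved.version != self.saved_version:
            raise RuntimeError("saved input was modified by an in-place operation")
        return grad_out * 2 * self.saved.value


def try_backward(mutate_after_forward):
    x = Tracked(3.0)
    op = Square()
    op.forward(x)
    if mutate_after_forward:
        x.add_(1.0)  # like `x += ...` after x was saved
    try:
        return op.backward(1.0)
    except RuntimeError as err:
        return str(err)


print(try_backward(False))  # 6.0
print(try_backward(True))   # saved input was modified by an in-place operation
```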
- umvi 2 years ago
See, this is a great example of where a comment needed to be added, but wasn't.
If the engineers that originally implemented the function intentionally chose the slower version, a quick comment as to why would have prevented this from happening in the first place.
- nodja 2 years ago
This is common knowledge, so common that someone like me, who hasn't coded anything besides some basic linear regression models, knows about it. It's like commenting on why you'd put parentheses in some formula: it's just going to say "parentheses here because this operation takes priority". Similarly, if a PyTorch model were written by those standards, the code would be filled with "operation not done in place because it would break the network graph". You're more likely to encounter the opposite comment, "doing this operation in place because it'll be discarded later" or something along those lines.
One of the first things you're taught when learning PyTorch is that you're not really coding in Python, but creating a network graph that is loaded and executed on a GPU. Other common-sense things include knowing that you shouldn't use stuff from the stdlib or from NumPy and should use the torch.* variants instead; not doing so will cause undefined behavior, massive memory copies between the CPU and GPU, or, most likely, an error at runtime.
Note that this repo is a fork of the official repo. It's a community repo focused on inference and thus doesn't care about training, so it has completely different considerations than the original code.
- crabbycarrot 2 years ago
On ML teams, a comment like this would not get past code review because it's obvious: avoiding in-place operations in PyTorch is the standard.
- hbogert 2 years ago
Will you be my colleague, please?
This is indeed the time to place a comment, yet so many people don't do that.
- ironhaven 2 years ago
Because of operator overloading, "+=" can call a more optimized method than "+". If this code were written in a language without operator overloading, I don't think this would be a very interesting pull request. This could be an example of why some people don't like operator overloading and why some programming languages (Java, Zig, etc.) do not implement the feature.
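The dispatch described here can be shown in plain Python: `+=` first looks for `__iadd__`, and only if the type doesn't define one does it fall back to `x = x + y` via `__add__`. So whether `+=` avoids a new object depends entirely on what the library overloads. A minimal sketch with two hypothetical classes:

```python
class OnlyAdd:
    """Hypothetical type that overloads + but not +=."""
    def __init__(self, v):
        self.v = v

    def __add__(self, other):
        return OnlyAdd(self.v + other)   # always builds a new object


class WithIAdd(OnlyAdd):
    """Same type plus a dedicated in-place method."""
    def __iadd__(self, other):
        self.v += other                  # update this object's state
        return self                      # += rebinds the name to this result


def same_object_after_iadd(cls):
    x = cls(1)
    before = x
    x += 2
    return x is before, x.v


print(same_object_after_iadd(OnlyAdd))   # (False, 3)
print(same_object_after_iadd(WithIAdd))  # (True, 3)
```

Built-in types show the same split: `tuple` has no `__iadd__`, so `t += (1,)` rebinds `t` to a new tuple, while `lst += [1]` extends the existing list.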
- staticassertion 2 years ago
I don't think this is an operator overloading thing? It's just that `x = y + x` is equivalent to

  z = y + x
  x = z

Basically, creating an object `z` just to throw it away.
`x += y` just adds y to x directly without any intermediary.
You could write this in any language pretty easily. For example, in Rust:
  let x = "abc".to_string();
  let y = "123".to_string();
  let x = x + &y;

as opposed to the more efficient:

  let mut x = "abc".to_string();
  let y = "123".to_string();
  x.push_str(&y);

It's just using an operation to mutate in place vs. an immutable operation.
- masklinn 2 years ago
> I don't think this is an operator overloading thing?
It’s the confusion (the idea that this is a trivial change) that is the overloading thing.
- noobermin 2 years ago
If Python did not have operator overloading, it would not be used for numeric programming to the extent it is. Overloading is key to its success in that field.
The problem is thinking `+` and `+=` are the same. They are not, and `+` should not be used when `+=` can be used.
- NavinF 2 years ago
Operator overloading is a major reason why libraries like PyTorch exist, so IMO that's a moot point.
Btw there's ongoing work to automatically optimize expressions like this. See the XLA compiler for example. Right now deep learning has a ton of seemingly obvious compute/memory optimisations that are not done automatically.
- JonathonW 2 years ago
If they're seeing these kinds of gains from relatively minor changes to their Python code, I can't help but wonder how much faster the model would run in a compiled language or a language with a good JIT (way more optimization work's gone into the mainstream Javascript runtimes than CPython).
I'd assumed that overall performance in Stable Diffusion was limited by the code running on the GPU, with Python performance being a fairly minor factor-- but I guess that's not the case?
- jonas21 2 years ago
This is PyTorch code, so the Python is setting up a bunch of kernels that are executed on the GPU. The switch from + to += might allow two of those kernels to be fused together or something, and that could lead to the large performance gain.
The Python part only runs a handful of times so JIT vs. non-JIT doesn't really make a difference.
- masklinn 2 years ago
Nah, it’s because PyTorch has a different implementation for __iadd__. It’s saving a copy by mutating the LHS in place, and the semantics possibly diverge further, as comments report broken code.
- sdenton4 2 years ago
I haven't played much with torch, but the game is generally that you have a graph of computations which gets JIT-compiled into GPU ops. The compiler may be more or less competent at finding modifications (e.g. 'fusions') that reduce the number of GPU ops required to perform the computation.
See, for example, XLA: https://www.tensorflow.org/xla
It looks like maybe nvFuser is an equivalent library for pytorch? https://pytorch.org/blog/introducing-nvfuser-a-deep-learning...
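For readers unfamiliar with fusion, here is a pure-Python illustration of the idea: an unfused computation of `a * b + c` runs two passes and materializes an intermediate buffer, while a fused one does the same arithmetic in a single pass. Real fusers such as XLA or nvFuser generate GPU kernels; nothing here reflects their APIs.

```python
def unfused(a, b, c):
    # Two separate "kernels": the first pass materializes a full intermediate.
    t = [x * y for x, y in zip(a, b)]      # kernel 1 writes t
    return [x + z for x, z in zip(t, c)]   # kernel 2 reads t, writes the result


def fused(a, b, c):
    # One "kernel": a single pass, with no intermediate buffer for a * b.
    return [x * y + z for x, y, z in zip(a, b, c)]


print(unfused([1, 2], [3, 4], [5, 6]))  # [8, 14]
print(fused([1, 2], [3, 4], [5, 6]))    # [8, 14]
```

On a GPU the saving comes mainly from not writing `t` out to memory and reading it back, which is the same memory-traffic argument made for `+=` above.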
- brrrrrm 2 years ago
The Python code is run every time.
- stingraycharles 2 years ago
Yes, but it doesn't take nearly as much time as the GPU code, which I think is what the parent is saying; it’s typically not the bottleneck.
- smhx 2 years ago
It can run much faster. For example, using the PyTorch nvFuser JIT gives a 50% speedup:
https://old.reddit.com/r/MachineLearning/comments/xa75km/p_p...
- thomasahle 2 years ago
In PyTorch, `x = y + x` is actually semantically different from `x += y`, so you can't easily make the switch with a compiler.
The difference is that `x += y` modifies `x` in place, whereas `x = x + y` creates a new object. In other words, if anybody had a reference to `x` before the update, the "optimized" code would break things.
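This can be checked with a deliberately naive source-to-source rewrite. The sketch uses Python's standard `ast` module to turn `x = x + e` into `x += e`; the transformer is hypothetical and only handles the simple `Name = Name + expr` shape, and lists stand in for tensors. Once another name aliases `x`, the rewritten program behaves differently, which is why a compiler needs aliasing information before making this switch:

```python
import ast


class PlusToAugAssign(ast.NodeTransformer):
    """Naive 'optimizer': rewrites `x = x + e` into `x += e`."""
    def visit_Assign(self, node):
        if (len(node.targets) == 1
                and isinstance(node.targets[0], ast.Name)
                and isinstance(node.value, ast.BinOp)
                and isinstance(node.value.op, ast.Add)
                and isinstance(node.value.left, ast.Name)
                and node.value.left.id == node.targets[0].id):
            new = ast.AugAssign(
                target=ast.Name(id=node.targets[0].id, ctx=ast.Store()),
                op=ast.Add(),
                value=node.value.right,
            )
            return ast.copy_location(new, node)
        return node


SOURCE = """
x = [1, 2]
alias = x          # someone else holds a reference to the same list
x = x + [3]
"""


def run(source, optimize):
    tree = ast.parse(source)
    if optimize:
        tree = ast.fix_missing_locations(PlusToAugAssign().visit(tree))
    env = {}
    exec(compile(tree, "<demo>", "exec"), env)
    return env["alias"]


print(run(SOURCE, optimize=False))  # [1, 2]
print(run(SOURCE, optimize=True))   # [1, 2, 3]
```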
- WirelessGigabit 2 years ago
The compiler could use a pointer to a pointer.
I guess this is the kind of stuff that drew me to Rust. This kind of behavior gives me the creeps, just like Ruby’s conventions.
- fragmede 2 years ago
Rust has the same behavior. https://news.ycombinator.com/item?id=32805756
- sshine 2 years ago
I don't know anything about Stable Diffusion, but I've been optimizing a lot of prime-field arithmetic in Rust lately, and we experienced a similar speedup going from `+ x` to `+= x` (for scalars and especially for composite structures like vectors and polynomials).
- thayne 2 years ago
For composite structures that isn't too surprising, but for scalars, I would have expected LLVM to optimize the addition and assignment into a single in-place addition.
- codeflo 2 years ago
The short answer is yes.
The long answer is that it’s not so clear what an “in place addition” even means at the level of CPU instructions once you consider register allocation. For example, if you have

  v = x;
  v += y;
  f(v);

and then never mention v again, the whole operation is performed directly in the register that is specified to receive the first argument in a function call, not in whatever register might have been allocated for v.
That’s because, with some complications I don’t want to go into, compilers look at the dependency graph of values rather than at the variable names.
- eyelidlessness 2 years ago
> way more optimization work's gone into the mainstream Javascript runtimes than CPython
Even so, there are absolutely silly things which can hint JS JITs to optimize (or to not deoptimize). Like defining and instantiating a class rather than just creating POJOs with the same values, or assigning NaN instead of null to uninitialized numeric variables/properties. Conditional control flow can deopt, but generally performs better around different function calls than within a single function. Even creating and throwing errors for control flow (which is generally expensive, and terrible for maintenance) can be optimal if your try/catch is the whole body of the function it resides in. And all of those might vary between JITs.
- luizfzs 2 years ago
I'd always assumed Python was only interpreted until I heard of Nuitka [1].
It would be interesting to get a benchmark using CPython vs Nuitka related to this change.
- joelgibson 2 years ago
This change isn't a matter of Python being slower than a compiled language; it's changing the meaning of the code. The line

  x = x + y

creates a copy of the array x, adds y to it, and then sets the variable x to that new array. In contrast, the line

  x += y

adds the array y in place into the array x (and so hopefully no other piece of code is relying on x being immutable). This kind of trade-off occurs in pretty much all programming; for instance, you see it whenever big-integer libraries are used in C++ or Rust.
- bee_rider 2 years ago
I don't think this is necessarily a minor change: += and + are different operators. I have no familiarity with this library, but I think _forward(...) works on tensors or something like that, which are probably big chunky data structures. += probably saves a copy or whatever.
- thrown_22 2 years ago
This is like saying that passing a struct vs. a pointer to a struct is a minor change for C code. I mean, it's just one extra *!
- staticassertion 2 years ago
This isn't a Python issue; this is an "I'm copying when I don't need to" issue. As I mention elsewhere, you can write this sort of "bug" in almost any language pretty easily (as I demonstrate with Rust).
This isn't a case of "The Python interpreter is bad" it's just that the code is doing what the user asked it to do - create a completely new copy of the data, then overwrite the old copy with it. Immutable operations like this are slow, mutating the value (what += does) is fast.
Granted, a compiled language could recognize that you're doing this, but it also might not: are `+` and `+=` semantically identical, such that the compiler can replace one with the other? Maybe? Probably not, if I had to guess. The correct answer is to just use the faster operation, as it is in all languages.
I don't know the type of `x`, but I'd suggest another optimization here would be to:
a) Preallocate the buffer rather than mutating it 3x (which is still likely forcing some allocations)
b) Reuse that buffer if it's so important, store it in `self` and clear it before use.
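A rough sketch of suggestion (b): keep one scratch buffer on `self`, overwrite it with the input on each call, and accumulate the residual terms into it in place. `Block`, `forward`, and the list-based buffer are hypothetical stand-ins; with real tensors you would use the library's own in-place operations (e.g. PyTorch's `Tensor.copy_` and `Tensor.add_`) rather than Python loops.

```python
class Block:
    """Hypothetical module that reuses a scratch buffer stored on `self`."""
    def __init__(self, size):
        self._scratch = [0.0] * size
        self.allocations = 1

    def forward(self, x, residuals):
        buf = self._scratch
        if len(buf) != len(x):
            # Input size changed: allocate once more and keep the new buffer.
            buf = self._scratch = [0.0] * len(x)
            self.allocations += 1
        buf[:] = x                # overwrite old contents; caller's x is untouched
        for r in residuals:       # each residual term is accumulated in place
            for i, v in enumerate(r):
                buf[i] += v
        return buf


block = Block(3)
out1 = block.forward([1.0, 2.0, 3.0], [[1.0, 1.0, 1.0], [0.5, 0.5, 0.5]])
print(out1, block.allocations)   # [2.5, 3.5, 4.5] 1
```

The catch is that the returned list is the shared buffer, so a caller that keeps one result sees it overwritten by the next call: the same aliasing hazard discussed above. It also only works where nothing (such as autograd) still needs the old values.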
- datalopers 2 years ago
This StackOverflow answer [1] goes into performance details of INPLACE_ADD versus BINARY_ADD.
- teruakohatu 2 years ago
I guess this is the beauty of making a model open source.
- myrryr 2 years ago
It is a hell of a good case study, that's for sure.
- eru 2 years ago
I wonder what version of Python they were using?
I'm wondering because recent versions have improved performance a lot: 3.11 is much faster than 3.10, and what's in 3.12 is already much faster than 3.11.
- eminence32 2 years ago
The upstream Stable Diffusion uses Python 3.8:
https://github.com/CompVis/stable-diffusion/blob/69ae4b35e0a...
- eru 2 years ago
Thanks.
- brrrrrm 2 years ago
It’s not clear a JIT-compiled language would help much here unless the operations were baked into the JIT itself (which would have to identify the memory savings of an in-place call).
- Waterluvian 2 years ago
One comment asks about putting it all on one line, and this is where interpreted languages without a JIT kinda blow.
Many times I have had to decide if my Python code would be more legible or get free performance.
The thing I like about JavaScript is that I can _usually_ trust the JIT to make my code faster than I could, meaning I can focus entirely on writing clean code.
P.S. you can always hand optimize. If you do, just comment the heck out of it.
- NavinF 2 years ago
This has nothing to do with Python. A JITed/AoT-compiled version of the old code should do exactly the same thing because it would build the same PyTorch graph.
- nodja 2 years ago
> Many times I have had to decide if my Python code would be more legible or get free performance.
This is rarely an option that has presented itself to me. If there's a clear performance issue in my code then I probably picked the wrong algorithm or my code has a bug, unless you decided for some reason to do heavy calculations in raw python. If you're doing operations on big chunks of data you should always use something like numpy or jax.
Even in the OP's case, the clear reason is that it's doing an operation in place instead of creating a copy. For ML models this can only be done at inference time and not at training time, since you need to keep track of the whole network, which is why the code was in its unoptimized state.
- eesmith 2 years ago
Lincoln Stein. Now that's a name I've not heard in a long time. A long time.
He's the author of the essay "How Perl Saved the Genome Project", the books "Network Programming with Perl" and "Writing Apache Modules with Perl and C", and a number of Perl packages including CGI.pm - which helped power the dot-com era - and GD.pm.
- teo_zero 2 years ago
But wait... x += y is equivalent to x = x + y, not to x = y + x. Only if + is commutative are the three equivalent. Are we sure the + operation is commutative for this type of data? And does the compiler know it?
It would be interesting to check whether changing every expression to x=x+y has a performance more similar to += or to ...+x
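The distinction is easy to check with a non-commutative `+` such as string concatenation: `x += y` matches `x = x + y`, while `y + x` can differ. The sketch uses immutable values (str, float) so that `a += y` cannot change `x` as a side effect. For floats the three forms agree, since IEEE 754 addition of two operands is commutative (though not associative). One further wrinkle for array libraries: an in-place result must keep `x`'s shape, so `x += y` can fail in cases where broadcasting would let `y + x` produce a larger array.

```python
def three_forms(x, y):
    a = x
    a += y         # augmented assignment
    b = x + y      # x = x + y
    c = y + x      # x = y + x, the form in the original _forward
    return a, b, c


print(three_forms("ab", "cd"))   # ('abcd', 'abcd', 'cdab')
print(three_forms(1.5, 2.0))     # (3.5, 3.5, 3.5)
```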
- dahfizz 2 years ago
Is Python in the fast path? Why not rewrite it in a performant language for an XXX% speedup?
- savant_penguin 2 years ago
In this case I believe Python is faster by a few months.
Jokes aside, this is PyTorch, so the real work runs in compiled C++ or CUDA; the problem likely comes from the different functions that are called for += vs. +.
- bee_rider 2 years ago
The += operator is almost certainly calling some method that sends out the real work to some tuned, hardware-specific framework written in a fast language.
- pclmulqdq 2 years ago
Not exactly: most of these frameworks essentially JIT-compile the entire operation graph so that it can be executed, and the Python code only touches the data at the endpoints of the full computation. I don't know why the JIT compiler doesn't optimize a = b + a to a += b, but I guess they assumed that the JIT-ed code path would only be used once, so the compiler has to be fast.
- dahfizz 2 years ago
So Python is marshalling data to and from an FFI in the fast path? That sounds even worse.
- bee_rider 2 years ago
I'm not sure this is the conventional use of the phrase "fast path."
But anyway, the idea is usually that the Python code calls out to the framework with operations that are in some way "large," and so the overhead is not so significant.
Python probably doesn't have to do any marshaling, hypothetically the framework could just return an object that represents a pointer. Then the python code sends that pointer to another framework method.
- kjeetgill 2 years ago
I think you mean critical path (which is ironically usually the slowest path). A fast path is usually a hardcoded shortcut you can take for select cases.
- olliej 2 years ago
Is this a lookup overhead thing or a memcpy-based overhead regression? In the case of the latter, it seems like this may result in an unexpected mutation of the source data?
- thweorui23432 2 years ago
The speedup likely won't work for training the model.
- NavinF 2 years ago
Yep, intermediate results (activations) are kept in memory during training.
- MaXtreeM 2 years ago
There is a case in C# where using compound assignment is actually slower [0]. Based on the comments, this should be fixed in .NET 7; I haven't checked it myself.
[0]: https://mobile.twitter.com/badamczewski01/status/15618171584...
- mhzsh 2 years ago
But why is it faster? A non-associative translation to bytecode (or however Python works)?
- lnyan 2 years ago
For PyTorch, `+=` is interpreted as an in-place operation.
- onedognight 2 years ago
My guess is that it operates in place with no memory allocations or copying.
- actually_a_dog 2 years ago
Not exactly:

  >>> def f(x): x += 1
  ...
  >>> def g(x): x = x + 1
  ...
  >>> dis.dis(f)
    1           0 LOAD_FAST                0 (x)
                3 LOAD_CONST               1 (1)
                6 INPLACE_ADD
                7 STORE_FAST               0 (x)
               10 LOAD_CONST               0 (None)
               13 RETURN_VALUE
  >>> dis.dis(g)
    1           0 LOAD_FAST                0 (x)
                3 LOAD_CONST               1 (1)
                6 BINARY_ADD
                7 STORE_FAST               0 (x)
               10 LOAD_CONST               0 (None)
               13 RETURN_VALUE
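The listing above comes from an older CPython (the 0/3/6 offsets predate the 2-byte wordcode introduced in 3.6). Since 3.11 both cases compile to a single `BINARY_OP` instruction whose argument distinguishes `+` from `+=`. A small helper to check on whatever interpreter you have, assuming only the standard `dis` module:

```python
import dis


def f(x):
    x += 1


def g(x):
    x = x + 1


def add_instruction(fn):
    """Describe how `fn`'s addition is encoded in bytecode."""
    for ins in dis.get_instructions(fn):
        if ins.opname in ("INPLACE_ADD", "BINARY_ADD"):  # CPython <= 3.10
            return ins.opname
        if ins.opname == "BINARY_OP":                     # CPython >= 3.11
            return ins.argrepr                            # "+=" or "+"
    return None


print(add_instruction(f), add_instruction(g))
```

Either way the bytecode only chooses between calling `__iadd__` and `__add__`; for tensors the cost lives in what those methods do, not in the interpreter.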
- spullara 2 years ago
Mutation is faster than making a new object.
- noobermin 2 years ago
Whenever I see things like this in highly visible code that people exclaim about across the internet, it makes me really take a moment to absorb how much time I spend agonizing over minutiae in my daily work and how people who really are just lucky can get away with much worse. Just a reminder that the idea that "tech" is a meritocracy was never really true.
- WatchDog 2 years ago
I assume that you don't have thousands of people looking over your code; how can you know that it doesn't have similar or greater room for optimization?